/* viterbi.c - routines for running the Viterbi algorithm on an HMM */

#include "general.h"

/*
 * find_vit_path - Viterbi search for the best-scoring state path through the
 * HMM reachable from `initial`, for the observation string held in `data`.
 *
 *   data         - observation string; a NULL `data` or NULL `data->string`
 *                  yields NULL.
 *   initial      - state whose outgoing transitions seed step 1.
 *   details      - non-zero: write a trace of the search to stderr.
 *   punc_trans   - non-zero: a change of state is only considered when
 *                  ends_in_punc() is true for the previous word.
 *   read_label   - non-zero: a destination state is only considered when its
 *                  label equals the label of the current data item.
 *   trans_weight - multiplier applied to each transition log-probability when
 *                  scoring; the unweighted value is what is stored in ltprob.
 *
 * Returns a newly allocated path_head with num_steps+1 path nodes (the
 * `initial` state through the end state), or NULL when no path reaches a
 * state labelled "end" or "e".  The nodes, their strdup'ed words and the head
 * are laid out the way free_path() in this file releases them.
 *
 * Calls quit() if the step counter disagrees with count_num_symbols() or if
 * backtracking cannot find a matching delta.
 */
path_head *find_vit_path(shead *data, state *initial, int details, int punc_trans, int read_label, float trans_weight) {
    static char rname[]="find_vit_path";
    state *dest, *source, *max_state, **sequence, *index;
    trans *current_trans;
    sdata *current_data;
    delta **deltas, *new_delta, *last_delta, *current_delta, *current_max, *next_delta;
    path *last_path_state, *current_path_state;
    path_head *max_path;
    char word[1000], prev_word[1000];
    static char null_word[] = "(null)";
    int num_steps, found, word_id, allow_transition;
    int i,j;  /* counters for states */
    int current_step, t;    /* counter for steps */
    double trans_lprob, emission_lprob, max_lprob, value, *ltprob, *leprob;

    if (data == NULL) return(NULL);
    if (data->string == NULL) return(NULL);

    max_path = (path_head *) kmalloc(sizeof(path_head));

    /* Find the number of symbols in the data string (count includes the end state) */
    num_steps = count_num_symbols(data);

    /* Initialize data structures with enough space for number of states and steps */
    /* At each step we'll keep a linked list of possible state transitions */
    /* Index 0 is reserved for the initial state, hence num_steps + 1 slots */
    deltas = (delta **) kmalloc((num_steps + 1) * sizeof(delta *));
    sequence = (state **) kmalloc((num_steps + 1) * sizeof(state *));
    ltprob = (double *) kmalloc((num_steps+1) * sizeof(double));
    leprob = (double *) kmalloc((num_steps+1) * sizeof(double));
    for (i = 0; i <= num_steps; i++) {
        deltas[i] = NULL;
        sequence[i] = NULL;
        ltprob[i] = 0.0;
        leprob[i] = 0.0;
    }

    /* Step 1 - Initialization */
    /* ----------------------- */

    /* Get first word in data string */
    /* NOTE(review): strcpy is unchecked against the 1000-byte buffers - confirm
       that input words are bounded elsewhere */
    /* get_word() is defined outside this file; presumably it derives the lookup
       form of the raw word - TODO confirm */
    current_step = 1;
    current_data = data->string;
    strcpy(prev_word,current_data->word);
    get_word(prev_word, word);

    /* Find non-zero transitions out of initial state, and which states could emit word */
    current_trans = initial->out;
    while (current_trans != NULL) {
        dest = current_trans->dest;

        /* If constrained search, make sure labels match */
        if (read_label) {
            if (strcmp(dest->label, current_data->label)) {
                current_trans = current_trans->next_source;
                continue;
            }
        }

        trans_lprob = log(current_trans->prob);
        emission_lprob = get_emission_lprob(dest, word);

        if (details) fprintf(stderr, "emission lprob of word %s from state %d = %f\n", word, dest->id, emission_lprob); 

        /* get_emission_lprob() returns no_prob when dest cannot emit the word */
        if (emission_lprob < no_prob) { /* Hence prob exists */

            /* Add another instance to delta structure */
            /* The `value` field is left unset here; only deltas of later steps
               have their `value` compared */
            new_delta = (delta *) kmalloc(sizeof(delta));
            new_delta->dest = dest;
            new_delta->lprob = (trans_weight * trans_lprob) + emission_lprob;
            new_delta->ltprob = trans_lprob;
            new_delta->leprob = emission_lprob;
            new_delta->source = initial;
            new_delta->duration = 1;
            new_delta->next = NULL;
            /* last_delta is only dereferenced once the list has a head, i.e.
               after an earlier append at this step has assigned it */
            if (deltas[current_step] == NULL) 
	      deltas[current_step] = new_delta;
            else { 
                last_delta->next = new_delta;
            }
            last_delta = new_delta;
        }
        current_trans = current_trans->next_source;
    }

    if (details) {
        fprintf(stderr, "Current step = %d\n", current_step);
        last_delta = deltas[current_step];
        while (last_delta != NULL) {
            fprintf(stderr, "dest = %d, source = %d, lprob = %f\n", last_delta->dest->id, last_delta->source->id, last_delta->lprob);
            last_delta = last_delta->next;
        }
    } 

    /* Step 2 - Recursion */
    /* ------------------ */

    /* Get next word in data string */
    current_data = current_data->next;
    while (current_data != NULL) {
        current_step++;
        allow_transition = 1;
        if (punc_trans) { 
            /* Decided from the previous raw word, before it is overwritten below */
            allow_transition = ends_in_punc(prev_word);
            if (DEBUG) fprintf(stderr, "Prev word = %s, allow_transitions = %d\n", prev_word, allow_transition);
        }
        strcpy(prev_word,current_data->word);
        get_word(prev_word, word);

        if (details) fprintf(stderr, "Current step (1) = %d, word = %s, allow_transition = %d\n", current_step, word, allow_transition);

        /* Get states from each delta from previous step, and look at their successors */
        current_delta = deltas[current_step-1];
        while (current_delta != NULL) {
            source = current_delta->dest;
            current_trans = source->out;
            while (current_trans != NULL) {
                dest = current_trans->dest;

                /* If constrained search, make sure labels match */
                if (read_label) {
                    if (strcmp(dest->label, current_data->label)) {
                        current_trans = current_trans->next_source;
                        continue;
                    }
                }

                /* If state transitions are only allowed after words ending in punctuation (.,:!)
                   and the previous word does not end in punctuation, only consider transitions
                   back to the current state */
                if (punc_trans && !allow_transition) {
                    if (dest->id != source->id) {
                        current_trans = current_trans->next_source;
                        continue;
                    }
                }

                /* If path has not remained in the source state as long as required by the state
                   duration setting, only consider transitions back to the current state */
                if (current_delta->duration < source->duration) {
                    if (dest->id != source->id) {
                        current_trans = current_trans->next_source;
                        continue;
                    }
                }

                trans_lprob = log(current_trans->prob);
                emission_lprob = get_emission_lprob(dest, word);

                if (details) fprintf(stderr, "Looking at transition from %d to %d, trans lprob = %f, emission lprob = %f\n", source->id, dest->id, trans_lprob, emission_lprob);

                if (emission_lprob < no_prob) { /* hence, prob exists */

                   /* Want to maximize this value */
                   /* (score of the best path into source plus the weighted transition;
                      the emission term is added separately into lprob) */
                   value = current_delta->lprob + (trans_weight * trans_lprob);

                   /* See if a value has already been recorded for the destination state */
                   /* We only want to keep one possible path into the destination state -
                        the one with the highest probability */

                   current_max = deltas[current_step]; 
                   found = 0;
                   while (current_max != NULL && !found) {
                       if (current_max->dest == dest) {
                           found = 1;
                           if (value > current_max->value) {
                               current_max->value = value;
                               current_max->source = source;
                               current_max->lprob = value + emission_lprob;
                               current_max->ltprob = trans_lprob;
                               current_max->leprob = emission_lprob;
                               /* duration counts consecutive steps spent in the same state */
                               if (dest->id == source->id) current_max->duration = current_delta->duration + 1;
                               else current_max->duration = 1;
                           }
                       }
                       current_max = current_max->next;
                   }
                   if (!found) {
                        /* Add another instance to delta structure */
                        new_delta = (delta *) kmalloc(sizeof(delta));
                        new_delta->dest = dest;
                        if (dest->id == source->id) new_delta->duration = current_delta->duration + 1;
                        else new_delta->duration = 1;
                        new_delta->value = value;
                        new_delta->lprob = value + emission_lprob;
                        new_delta->ltprob = trans_lprob;
                        new_delta->leprob = emission_lprob;
                        new_delta->source = source;
                        new_delta->next = NULL;
                        if (deltas[current_step] == NULL) 
			  deltas[current_step] = new_delta;
                        else {
                            last_delta->next = new_delta;
                        }
                        last_delta = new_delta;
                   }
                }
                /* else, current dest state cannot emit current word */
                current_trans = current_trans->next_source;
            }
            current_delta = current_delta->next;
        }

        if (details) {
            last_delta = deltas[current_step];
            while (last_delta != NULL) {
                fprintf(stderr, "dest = %d, source = %d, lprob = %f\n", last_delta->dest->id, last_delta->source->id, last_delta->lprob);
                last_delta = last_delta->next;
            }
        }

        current_data = current_data->next;
    }

    if (details) fprintf(stderr, "Finished reading in data string, and current step = %d\n", current_step);

    /* Need to compute deltas into end state - no emission probs */

    if (details) fprintf(stderr, "Transitioning into end state...\n");

    current_step++;
    current_delta = deltas[current_step-1];
    while (current_delta != NULL) {
        source = current_delta->dest;
        current_trans = source->out;  
        while (current_trans != NULL) {
            dest = current_trans->dest;

            /* Only consider transitions into the end state */
            /* (label, punctuation and duration constraints are not applied here) */
            if (!strcmp(dest->label,"end") || !strcmp(dest->label,"e")) { 
                trans_lprob = log(current_trans->prob);

                /* Want to maximize this value */
                value = current_delta->lprob + (trans_weight * trans_lprob);

                /* See if a value has already been recorded for the destination state */
                /* We only want to keep one possible path into the destination state -
                   the one with the highest probability */
                current_max = deltas[current_step];
                found = 0;
                while (current_max != NULL && !found) {
                    if (current_max->dest == dest) {
                        found = 1;
                        /* duration is left as first recorded when a better source replaces it */
                        if (value > current_max->value) {
                            current_max->value = value;
                            current_max->source = source;
                            current_max->lprob = value;
                            current_max->ltprob = trans_lprob;
                            current_max->leprob = 0.0;
                        }
                    }
                    current_max = current_max->next;
                }
                if (!found) {
                    /* Add another instance to delta structure */
                    new_delta = (delta *) kmalloc(sizeof(delta));
                    new_delta->dest = dest;
                    if (dest->id == source->id) new_delta->duration = current_delta->duration + 1;
                    else new_delta->duration = 1;
                    new_delta->value = value;
                    new_delta->lprob = value;
                    new_delta->ltprob = trans_lprob;
                    new_delta->leprob = 0.0;
                    new_delta->source = source;
                    new_delta->next = NULL;
                    if (deltas[current_step] == NULL) deltas[current_step] = new_delta;
                    else {
                        last_delta->next = new_delta;
                    }
                    last_delta = new_delta;
                }
            }
            current_trans = current_trans->next_source;
        }
        current_delta = current_delta->next;
    }
    if (details) {
        fprintf(stderr, "Current step (2) = %d\n", current_step);
        last_delta = deltas[current_step];
        while (last_delta != NULL) {
            fprintf(stderr, "dest = %d, source = %d, lprob = %f\n", last_delta->dest->id, last_delta->source->id, last_delta->lprob);
            last_delta = last_delta->next;
        }
    }

    /* Sanity check: one step per symbol plus the end-state step */
    if (current_step != num_steps) 
        quit(-1,"%s: current step (%d) does not equal total number of steps (%d)!\n", rname, current_step, num_steps);

    /* Step 3 - Termination */
    /* -------------------- */

    /* Find the maximum of all the deltas -
       there should actually be only one delta for the end state */

    current_delta = deltas[current_step];

    /* No delta reached the final step: release everything and report failure */
    if (current_delta == NULL) {
        free_deltas(deltas, num_steps);
        free(sequence);
        free(ltprob);
        free(leprob);
        free(max_path);
        return(NULL);
    }

    /* A positive max_lprob acts as the "nothing chosen yet" marker, so the
       first delta examined is always taken */
    max_lprob = 1.0;
    max_state = NULL;
    while (current_delta != NULL) {
        if (max_lprob > 0) {
            max_lprob = current_delta->lprob;
            max_state = current_delta->dest;
        }
        else if (current_delta->lprob > max_lprob) {
            max_lprob = current_delta->lprob;
            max_state = current_delta->dest;
        }
        current_delta = current_delta->next;
    }

    /* See if search failed - return NULL path if so */
    if (max_state == NULL) {
        free_deltas(deltas, num_steps);
        free(sequence);
        free(ltprob);
        free(leprob);
        free(max_path);
        return(NULL);
    }
    /* The best final state must be labelled "end" or "e" */
    else if (!(!strcmp(max_state->label,"end") || !strcmp(max_state->label,"e"))) {
        free_deltas(deltas, num_steps);
        free(sequence);
        free(ltprob);
        free(leprob);
        free(max_path);
        return(NULL);
    }

    max_path->lprob = max_lprob;
    sequence[num_steps] = max_state;

    /* Step 4 - Backtracking */
    /* --------------------- */
    /* For each step, find the delta whose dest is the state already chosen for
       step t+1; its source becomes the state for step t */
    for (t = num_steps - 1; t >= 0; t--) {
        index = sequence[t+1];
        current_delta = deltas[t+1];
        found = 0;
        while (current_delta != NULL) {
            if (current_delta->dest == index) {
                sequence[t] = current_delta->source;
                ltprob[t+1] = current_delta->ltprob;
                leprob[t+1] = current_delta->leprob;
                if (details) fprintf(stderr, "backtrack vit path: from %d to %d\n", current_delta->dest->id, current_delta->source->id);
                found = 1;
                break;
            }
            current_delta = current_delta->next;
        }
        if (!found) quit(-1, "%s: did not find matching state in Viterbi path!\n", rname);
    }
    if (details) fprintf(stderr, "Finished backtracking...\n");

    /* Copy state sequence to path structure */
    /* The first node is the initial state: it carries the "(null)" word and the
       zero-initialised ltprob[0]/leprob[0] */
    last_path_state = (path *) kmalloc(sizeof(path));
    max_path->first = last_path_state;
    last_path_state->current = sequence[0];
    last_path_state->word = strdup (null_word); 
    last_path_state->ltprob = ltprob[0];
    last_path_state->leprob = leprob[0];
    current_data = data->string;
    for (t = 1; t <= num_steps; t++) {
       current_path_state = (path *) kmalloc(sizeof(path));
       current_path_state->current = sequence[t];
       current_path_state->ltprob = ltprob[t];
       current_path_state->leprob = leprob[t];
       /* The end-state node has no observation, so it gets the "(null)" word */
       if (t == num_steps || current_data == NULL) {
           current_path_state->word = strdup (null_word);
       }
       else {
           current_path_state->word = strdup (current_data->word);
           current_data = current_data->next;
       }
       last_path_state->next = current_path_state;
       last_path_state = current_path_state;
    }
    last_path_state->next = NULL;

    /* Free delta and sequence memory */
    if (details) fprintf(stderr, "Freeing memory...");
    free_deltas(deltas, num_steps);
    free(sequence);
    free(ltprob);
    free(leprob);
    if (details) fprintf(stderr, "done\n");

    return(max_path);
}

/*
 * ends_in_punc - report whether a word's final character is one of . , : !
 *
 *   prev_word - NUL-terminated word to inspect; it is not modified.
 *
 * Returns 1 when the last character is '.', ',', ':' or '!', otherwise 0.
 * An empty string has no last character and yields 0.
 */
int ends_in_punc(char *prev_word) {
    static char rname[]="ends_in_punc";
    size_t length;
    char last;

    length = strlen(prev_word);

    /* Guard: without this, the index below would be prev_word[-1] */
    if (length == 0) return(0);

    last = prev_word[length-1];
    if (last == '.' || last == ',' || last == ':' || last == '!')
        return(1);
    else
        return(0);
}

/*
 * remove_punc - strip trailing punctuation from a word, in place.
 *
 * Trailing ',', ':' and '!' are always removed.  A trailing '.' is removed
 * only while more than two characters remain, so a one-character word
 * followed by a period keeps its period.
 *
 * Calls quit() when the word is empty on entry or is reduced to nothing.
 */
void remove_punc(char *prev_word) {
    static char rname[]="remove_punc";
    int remaining;
    char tail;

    remaining = strlen(prev_word);
    if (remaining == 0) quit(-1, "%s: word has no length!!\n", rname);

    for (;;) {
        tail = prev_word[remaining-1];

        /* Stop at the first trailing character that is not strippable */
        if (!(tail == ',' ||
              tail == ':' ||
              tail == '!' ||
              (tail == '.' && remaining > 2))) {
            break;
        }

        prev_word[--remaining] = '\0';
        if (remaining == 0) quit(-1, "%s: word is only made up of punctuation!!\n", rname);
    }
}

/*
 * free_deltas - release every per-step delta list and then the array itself.
 *
 *   deltas    - array of list heads, indexed 0..num_steps inclusive.
 *   num_steps - highest valid index in `deltas`.
 */
void free_deltas(delta **deltas, int num_steps) {
    static char rname[]="free_deltas";
    int step;
    delta *node, *following;

    for (step = 0; step <= num_steps; step++) {
        /* Save the successor before the node is released */
        for (node = deltas[step]; node != NULL; node = following) {
            following = node->next;
            free(node);
        }
    }
    free(deltas);
}

/*
 * free_alphas - release every per-step alpha list and then the array itself.
 *
 *   alphas    - array of list heads, indexed 0..num_steps inclusive.
 *   num_steps - highest valid index in `alphas`.
 */
void free_alphas(alpha **alphas, int num_steps) {
    static char rname[]="free_alphas";
    int step;
    alpha *node, *following;

    for (step = 0; step <= num_steps; step++) {
        /* Save the successor before the node is released */
        for (node = alphas[step]; node != NULL; node = following) {
            following = node->next;
            free(node);
        }
    }
    free(alphas);
}

/*
 * print_path - write a human-readable dump of a Viterbi path to stderr.
 *
 *   max_path - path to print; NULL prints a "no path" message instead.
 *
 * Prints the path log-probability (and its exponential), then one line per
 * path node with a 1-based step number, state id, state label and word.
 */
void print_path(path_head *max_path) {
    static char rname[]="print_path";
    path *node;
    double prob;
    int step;

    if (DEBUG) fprintf(stderr, "Printing out Viterbi path for observation sequence:\n");

    if (max_path == NULL) {
        fprintf(stderr, "No Viterbi path found...\n");
        return;
    }

    prob = exp(max_path->lprob);
    fprintf(stderr, "Viterbi path log probability = %f (prob = %f)\n", max_path->lprob, prob);

    step = 0;
    for (node = max_path->first; node != NULL; node = node->next) {
        step++;
        fprintf(stderr, "Step %d:\tState: %d\tTag: %s\tWord: %s\n", step, node->current->id, node->current->label, node->word);
    }

    if (DEBUG) fprintf(stderr, "Path string contains %d symbols\n", step);
}


/*
 * print_path_to_file - write one line describing a Viterbi path.
 *
 *   max_path       - path to print; NULL prints "(null)".
 *   output_file    - destination stream; flushed before returning.
 *   start          - observation string header; only its `count` is read,
 *                    and only when obs_with_id is non-zero.
 *   obs_with_id    - non-zero: prefix the line with "<count>: ".
 *   print_state_id - non-zero: print each state's id after its label.
 *   print_probs    - non-zero: print log-probabilities after each state.
 *
 * Nodes whose state is labelled "start" are skipped.  For every other node
 * the word and state label are printed.  With print_probs set, a node
 * labelled "end" is followed by its transition log-probability only (no
 * trailing space); any other node is followed by both its transition and
 * emission log-probabilities.
 */
void print_path_to_file(path_head *max_path, FILE *output_file, shead *start, int obs_with_id, int print_state_id, int print_probs) {
    static char rname[]="print_path_to_file";
    path *current_path;

    /* start is dereferenced only when the id is actually printed */
    if (obs_with_id) fprintf(output_file, "%d: ", start->count);

    if (max_path == NULL) {
        fprintf(output_file, "(null)\n");
        fflush(output_file);
        return;
    }

    for (current_path = max_path->first; current_path != NULL; current_path = current_path->next) {
        const char *label = current_path->current->label;
        int is_end = (strcmp(label, "end") == 0);

        if (strcmp(label, "start") == 0) continue;

        fprintf(output_file, "%s %s ", current_path->word, label);
        if (print_state_id) fprintf(output_file, "%d ", current_path->current->id);
        if (print_probs) {
            /* The end state emits nothing, so only its transition term is shown */
            if (is_end) fprintf(output_file, "%f", current_path->ltprob);
            else fprintf(output_file, "%f %f ", current_path->ltprob, current_path->leprob);
        }
    }
    fprintf(output_file, "\n");
    fflush(output_file);
}

/*
 * collect_pp_stats - add one path's log-probability totals to running sums.
 *
 *   max_path              - path to summarise; NULL leaves the sums untouched.
 *   ptr_total_obs         - incremented by the number of counted nodes.
 *   ptr_total_obs_logprob - incremented by the sum of (ltprob + leprob).
 *   ptr_total_t_logprob   - incremented by the sum of ltprob.
 *   ptr_total_e_logprob   - incremented by the sum of leprob.
 *
 * Every node is counted except those whose state is labelled "start".
 * Totals for this path are accumulated locally and added to the caller's
 * sums once at the end.
 */
void collect_pp_stats(path_head *max_path, int *ptr_total_obs, double *ptr_total_obs_logprob, double *ptr_total_t_logprob, double *ptr_total_e_logprob) {
    static char rname[]="calc_pp_stats";
    path *node;
    int counted = 0;
    double sum_obs = 0.0, sum_trans = 0.0, sum_emit = 0.0;

    if (max_path == NULL) return;

    for (node = max_path->first; node != NULL; node = node->next) {
        if (!strcmp(node->current->label,"start")) continue;

        sum_obs += node->ltprob + node->leprob;
        sum_emit += node->leprob;
        sum_trans += node->ltprob;
        counted++;
    }

    *ptr_total_obs += counted;
    *ptr_total_obs_logprob += sum_obs;
    *ptr_total_t_logprob += sum_trans;
    *ptr_total_e_logprob += sum_emit;
}



/*
 * print_forward_prob_to_file - write one line with a total log-probability.
 *
 *   lprob       - log-probability (base e); a value equal to no_prob is
 *                 printed as "-inf".
 *   output_file - destination stream; flushed before returning.
 *   start       - observation string header; only its `count` is read, and
 *                 only when obs_with_id is non-zero.
 *   obs_with_id - non-zero: prefix the line with "<count>: ".
 */
void print_forward_prob_to_file(double lprob, FILE *output_file, shead *start, int obs_with_id) {
    static char rname[]="print_forward_prob_to_file";

    if (obs_with_id) fprintf(output_file, "%d: ", start->count);

    if (lprob == no_prob) fprintf(output_file, "-inf (total logprob, base e)\n");
    else fprintf(output_file, "%f (total logprob, base e)\n", lprob);

    /* Flush on every path, not only when an id was printed */
    fflush(output_file);
}

/*
 * free_path - release a path_head, every node on its list, and each node's
 * word string.  A NULL argument is ignored.
 */
void free_path(path_head *vit_path) {
    static char rname[]="free_path";
    path *node, *following;

    if (vit_path == NULL) return;

    /* Grab the list head before the header itself is released */
    node = vit_path->first;
    free(vit_path);

    while (node != NULL) {
        following = node->next;
        free(node->word);
        free(node);
        node = following;
    }
}


/*
 * get_emission_lprob - log-probability of state `dest` emitting `word`.
 *
 * The word is mapped to an id with bow_word2int_no_add(); if that fails or
 * the id exceeds the state's vocab_size, the id of unk_word is tried instead.
 * The value comes from the state's lprobs table when one exists, otherwise
 * from calc_lprob() (called with mode 1).
 *
 * Returns no_prob when neither the word nor unk_word has a usable id, when
 * the distribution has no table and a vocab_size of -1, or when the
 * log-probability found is below -98.0.
 *
 * NOTE(review): an id equal to vocab_size passes the range check - confirm
 * that lprobs is sized to allow that index.
 */
double get_emission_lprob(state *dest, char *word) {
    static char rname[]="get_emission_lprob";
    double lp;
    int id, mode;

    mode = 1;

    id = bow_word2int_no_add(word);
    if (id == -1 || id > dest->O->vocab_size) {
        /* Word does not exist in distribution: fall back to the unknown word */
        id = bow_word2int_no_add(unk_word);
        if (id == -1 || id > dest->O->vocab_size) return(no_prob);
    }

    if (dest->O->lprobs == NULL) {
        /* exclude null dists */
        if (dest->O->vocab_size == -1) return(no_prob);
        lp = calc_lprob(dest->O, id, mode);
    }
    else {
        lp = dest->O->lprobs[id];
    }

    if (lp < -98.0) return(no_prob);
    return(lp);
}


/*
 * count_num_symbols - number of steps needed to walk the data string.
 *
 * Returns the number of items on data->string plus one for the end state
 * (nothing is added for the start state).
 */
int count_num_symbols(shead *data) {
    static char rname[]="count_num_symbols";
    int total;
    sdata *item;

    total = 1;  /* the end state */
    for (item = data->string; item != NULL; item = item->next) {
        total++;
    }

    if (DEBUG) fprintf(stderr, "Data string contains %d symbols (including end state)\n", total);

    return(total);
}

/*
 * create_alpha - allocate and fill a new alpha node for state `dest`.
 *
 *   prob  = alpha_value  * (weighted transition) * emission_prob
 *   tprob = alpha_tvalue * trans_prob
 *   eprob = alpha_evalue * emission_prob
 *
 * The weighted transition is pow(trans_prob, trans_weight) when trans_weight
 * is below or above 1.0, and trans_prob itself otherwise.  The node's `next`
 * is set to NULL.
 */
alpha *create_alpha(state *dest, double trans_prob, float trans_weight, double emission_prob, double alpha_value, double alpha_tvalue, double alpha_evalue) {
    alpha *node;

    node = (alpha *) kmalloc(sizeof(alpha));
    node->current = dest;

    if (!(trans_weight < 1.0) && !(trans_weight > 1.0)) {
        node->prob = alpha_value * trans_prob * emission_prob;
    }
    else {
        node->prob = alpha_value * pow(trans_prob, trans_weight) * emission_prob;
    }

    node->tprob = alpha_tvalue * trans_prob;
    node->eprob = alpha_evalue * emission_prob;
    node->next = NULL;

    return(node);
}

/*
 * update_alpha - add one more incoming contribution to an existing alpha.
 *
 *   prob  += alpha_value  * (weighted transition)
 *   tprob += alpha_tvalue * trans_prob
 *   eprob += alpha_evalue
 *
 * The weighted transition is pow(trans_prob, trans_weight) when trans_weight
 * is below or above 1.0, and trans_prob itself otherwise.
 */
void update_alpha(alpha *temp_alpha, double trans_prob, float trans_weight, double alpha_value, double alpha_tvalue, double alpha_evalue) {

    if (!(trans_weight < 1.0) && !(trans_weight > 1.0)) {
        temp_alpha->prob += (alpha_value * trans_prob);
    }
    else {
        temp_alpha->prob += (alpha_value * pow(trans_prob, trans_weight));
    }

    temp_alpha->tprob += (alpha_tvalue * trans_prob);
    temp_alpha->eprob += alpha_evalue;
}


/*
 * scale_alphas - normalise the alphas of one step by their maxima.
 *
 * Finds the largest prob, eprob and tprob on the list alphas[current_step]
 * (starting from 0.0), records them in scale[current_step],
 * escale[current_step] and tscale[current_step], then divides each alpha's
 * field by the corresponding recorded value.
 */
void scale_alphas(int current_step, alpha **alphas, double *scale, double *escale, double *tscale) {
    alpha *node;
    double top_prob = 0.0, top_eprob = 0.0, top_tprob = 0.0;

    /* First pass: find the per-field maxima */
    for (node = alphas[current_step]; node != NULL; node = node->next) {
        if (node->prob > top_prob) top_prob = node->prob;
        if (node->eprob > top_eprob) top_eprob = node->eprob;
        if (node->tprob > top_tprob) top_tprob = node->tprob;
    }

    scale[current_step] = top_prob;
    escale[current_step] = top_eprob;
    tscale[current_step] = top_tprob;

    /* Second pass: divide by the values just stored */
    for (node = alphas[current_step]; node != NULL; node = node->next) {
        node->prob /= scale[current_step];
        node->eprob /= escale[current_step];
        node->tprob /= tscale[current_step];
    }
}

/*
 * calc_forward_prob - scaled forward pass over a data string.
 *
 * Parameters:
 *   data         - data string; its `string` member is the list of words
 *   initial      - initial state whose outgoing transitions start the search
 *   details      - non-zero to trace each step on stderr
 *   punc_trans   - non-zero to allow a change of state only after a word for
 *                  which ends_in_punc() is true; otherwise only transitions
 *                  back to the same state are followed
 *   read_label   - non-zero to constrain the search to states whose label
 *                  matches the label stored with each word
 *   trans_weight - exponent applied to transition probabilities (see
 *                  create_alpha / update_alpha)
 *   elprob       - out: sum of log emission scale factors
 *   tlprob       - out: sum of log transition scale factors
 *
 * Returns the sum of the log scale factors over all steps, or no_prob when
 * there is no data or no alpha reaches the end state ("end" or "e").
 * *elprob and *tlprob are written only on the successful path.
 *
 * Emission log probs >= no_prob mean the state cannot emit the word.
 */
double calc_forward_prob(shead *data, state *initial, int details, int punc_trans, int read_label, float trans_weight, double *elprob, double *tlprob) {
    static char rname[]="calc_forward_prob";
    sdata *current_data;
    state *dest, *source;
    trans *current_trans;
    alpha **alphas, *new_alpha, *last_alpha = NULL, *current_alpha, *temp_alpha, *prev_alpha;
    char word[1000], prev_word[1000];
    int num_steps, found, i, current_step;
    int allow_transition = 0;   /* only consulted when punc_trans is set */
    double trans_prob, emission_prob, total_lprob, alpha_value, *scale;
    double alpha_tvalue, alpha_evalue, total_elprob, total_tlprob, *escale, *tscale;
    float emission_lprob;

    /* Same guards as the other path routines in this file */
    if (data == NULL || data->string == NULL) {
        return(no_prob);
    }

    /* Find the number of symbols in the data string */
    num_steps = count_num_symbols(data);

    /* Initialize data structures with enough space for number of states and steps */
    /* At each step we'll keep a linked list of possible state transitions */
    alphas = (alpha **) kmalloc((num_steps + 1) * sizeof(alpha *));
    scale = (double *) kmalloc((num_steps + 1) * sizeof(double));
    escale = (double *) kmalloc((num_steps + 1) * sizeof(double));
    tscale = (double *) kmalloc((num_steps + 1) * sizeof(double));
    for (i = 0; i <= num_steps; i++) {
        alphas[i] = NULL;
        scale[i] = 1.0;
        escale[i] = 1.0;
        tscale[i] = 1.0;
    }

    /* Step 1 - Initialization */
    /* ----------------------- */

    /* Get first word in data string */
    /* NOTE(review): strcpy into a 1000-byte buffer assumes words are shorter
       than that - confirm against whatever builds the data string. */
    current_step = 1;
    current_data = data->string;
    strcpy(prev_word,current_data->word);
    get_word(prev_word, word);

    /* Find non-zero transitions out of initial state, and which states could emit word */
    current_trans = initial->out;
    while (current_trans != NULL) {
        dest = current_trans->dest;

        /* If constrained search, make sure labels match */
        if (read_label && strcmp(dest->label, current_data->label)) {
            current_trans = current_trans->next_source;
            continue;
        }

        trans_prob = current_trans->prob;
        emission_lprob = get_emission_lprob(dest, word);

        /* Compute the emission prob before tracing it; report 0 when the
           state cannot emit the word */
        emission_prob = (emission_lprob < no_prob) ? exp(emission_lprob) : 0.0;

        /* Details */
        if (details) fprintf(stderr, "Step: %d, source = %s, dest = %s, tprob = %f, eprob = %f\n",
            current_step, current_trans->source->label,
            current_trans->dest->label, trans_prob, emission_prob);

        /* Proceed if state can emit current word */
        if (emission_lprob < no_prob) { /* hence, prob exists */
            /* Add another instance to alpha structure */
            new_alpha = create_alpha(dest, trans_prob, trans_weight, emission_prob, 1.0, 1.0, 1.0);

            if (alphas[current_step] == NULL) alphas[current_step] = new_alpha;
            else {
                last_alpha->next = new_alpha;
            }
            last_alpha = new_alpha;
        }
        current_trans = current_trans->next_source;
    }

    /* Scale alphas */
    scale_alphas(current_step, alphas, scale, escale, tscale);

    /* Details */
    if (details) {
        fprintf(stderr, "Alphas of current step = %d, word = %s, scale = %f\n", current_step, word, scale[current_step]);
        last_alpha = alphas[current_step];
        while (last_alpha != NULL) {
            fprintf(stderr, "   state = %s, prob = %.20f\n", last_alpha->current->label, last_alpha->prob);
            last_alpha = last_alpha->next;
        }
    }

    /* Step 2 - Recursion */
    /* ------------------ */

    /* Get next word in data string */
    current_data = current_data->next;
    while (current_data != NULL) {
        current_step++;
        if (punc_trans) {
            /* prev_word still holds the previous step's original word here */
            allow_transition = ends_in_punc(prev_word);
            if (DEBUG) fprintf(stderr, "Prev word = %s, allow_transitions = %d\n", prev_word, allow_transition);
        }
        strcpy(prev_word,current_data->word);
        get_word(prev_word, word);

        /* Get states from each alpha record from previous step, and look at all the states
           they can transition into */
        current_alpha = alphas[current_step-1];
        while (current_alpha != NULL) {
            alpha_value = current_alpha->prob;
            alpha_tvalue = current_alpha->tprob;
            alpha_evalue = current_alpha->eprob;

            source = current_alpha->current;
            current_trans = source->out;
            while (current_trans != NULL) {
                dest = current_trans->dest;

                /* If constrained search, make sure labels match */
                if (read_label && strcmp(dest->label, current_data->label)) {
                    current_trans = current_trans->next_source;
                    continue;
                }

                trans_prob = current_trans->prob;

                /* If state transitions are only allowed after words ending in punctuation (.,:!),
                   and the previous word does not end in punctuation, only consider transitions
                   back to the current state */
                if (punc_trans && !allow_transition) {
                    if (dest != current_trans->source) {
                        current_trans = current_trans->next_source;
                        continue;
                    }
                }

                /* See if destination state already exists on linked list -
                   if it does, add in this term. If not, create a new alpha record. */
                temp_alpha = alphas[current_step];
                found = 0;
                while (temp_alpha != NULL && !found) {
                    if (temp_alpha->current == dest) {
                        found = 1;
                        update_alpha(temp_alpha, trans_prob, trans_weight, alpha_value, alpha_tvalue, alpha_evalue);
                        break;
                    }
                    temp_alpha = temp_alpha->next;
                }
                if (!found) {
                    /* Add another instance to alpha structure; the emission
                       factor (1.0 here) is multiplied in after all sources
                       have been summed */
                    new_alpha = create_alpha(dest, trans_prob, trans_weight, 1.0, alpha_value, alpha_tvalue, alpha_evalue);

                    if (alphas[current_step] == NULL) alphas[current_step] = new_alpha;
                    else {
                        last_alpha->next = new_alpha;
                    }
                    last_alpha = new_alpha;
                }
                current_trans = current_trans->next_source;
            }
            current_alpha = current_alpha->next;
        }

        /* Now that we've gone through all the alphas and created new alpha records for
           the current time step, we need to add in the emission probs */

        temp_alpha = alphas[current_step];
        prev_alpha = NULL;
        while (temp_alpha != NULL) {
            emission_lprob = get_emission_lprob(temp_alpha->current, word);

            if (emission_lprob < no_prob) {
                emission_prob = exp(emission_lprob);
                temp_alpha->prob *= emission_prob;
                temp_alpha->eprob *= emission_prob;
                prev_alpha = temp_alpha;
                temp_alpha = temp_alpha->next;
            }
            else {
                /* This state cannot emit this word... Remove alpha record for this state */
                if (prev_alpha != NULL) {
                    prev_alpha->next = temp_alpha->next;
                    free(temp_alpha);
                    temp_alpha = prev_alpha->next;
                }
                else {
                    alphas[current_step] = temp_alpha->next;
                    free(temp_alpha);
                    temp_alpha = alphas[current_step];
                }
            }
        }

        /* Scale alphas */
        scale_alphas(current_step, alphas, scale, escale, tscale);

        if (details) {
            fprintf(stderr, "Current step = %d, word = %s, scale = %f\n", current_step, word, scale[current_step]);
            last_alpha = alphas[current_step];
            while (last_alpha != NULL) {
                fprintf(stderr, "   state = %s, prob = %.20f\n", last_alpha->current->label, last_alpha->prob);
                last_alpha = last_alpha->next;
            }
        }
        current_data = current_data->next;
    }

    /* Get states from each alpha record from previous step, and see if they can transition into end state */
    current_step++;

    /* Check the step count before alphas[current_step] is indexed below */
    if (current_step != num_steps) quit(-1, "%s: current step (%d) does not equal total number of steps (%d)!\n", rname, current_step, num_steps);

    current_alpha = alphas[current_step-1];
    while (current_alpha != NULL) {
        alpha_value = current_alpha->prob;
        alpha_tvalue = current_alpha->tprob;
        alpha_evalue = current_alpha->eprob;

        source = current_alpha->current;
        current_trans = source->out;
        while (current_trans != NULL) {
            dest = current_trans->dest;

            /* Only consider transitions into the end state */
            if (!strcmp(dest->label,"end") || !strcmp(dest->label,"e")) {
                trans_prob = current_trans->prob;

                /* There will only be an alpha for the transition into the end state */
                if (alphas[current_step] != NULL)
                    update_alpha(alphas[current_step], trans_prob, trans_weight, alpha_value, alpha_tvalue, alpha_evalue);
                else {
                    /* First transition into the end state: start the list */
                    alphas[current_step] = create_alpha(dest, trans_prob, trans_weight, 1.0, alpha_value, alpha_tvalue, alpha_evalue);
                }
            }
            current_trans = current_trans->next_source;
        }
        current_alpha = current_alpha->next;
    }

    /* Scale alphas */
    scale_alphas(current_step, alphas, scale, escale, tscale);

    /* Don't need to add in emission prob for last step... */

    /* Step 3 - Termination */
    /* -------------------- */

    /* Sum up all the alphas in the end state -
       there should only be one alpha at this point */

    if (alphas[current_step] == NULL) {
        /* Nothing reached the end state: release everything we allocated */
        free_alphas(alphas, num_steps);
        free(scale);
        free(escale);
        free(tscale);
        return(no_prob);
    }

    if (alphas[current_step]->next != NULL) quit(-1, "%s: in forward search, there was more than one alpha at end state???\n", rname);

    /* Calculate total observation prob */
    total_lprob = 0.0;
    total_elprob = 0.0;
    total_tlprob = 0.0;
    for (current_step = 1; current_step <= num_steps; current_step++) {
        total_lprob += log(scale[current_step]);
        total_elprob += log(escale[current_step]);
        total_tlprob += log(tscale[current_step]);
    }

    *elprob = total_elprob;
    *tlprob = total_tlprob;

    free_alphas(alphas, num_steps);
    free(scale);
    free(escale);
    free(tscale);

    return(total_lprob);
}


/*
 * find_max_path - label each word independently with the state that gives
 * it the highest emission log probability (transitions are ignored).
 *
 * Parameters:
 *   data       - data string; its `string` member is the list of words
 *   states     - array of candidate states
 *   num_states - number of entries in `states`
 *   read_label - non-zero to consider only states whose label equals the
 *                label stored with the word
 *
 * Returns a newly allocated path_head whose `first` list holds one path
 * record per word (ltprob = 0.0, leprob = the winning log prob), or NULL
 * when `data` or `data->string` is NULL.  States labelled "start" or "end"
 * are never chosen.  If no state beats the -1000000.0 floor, the record's
 * `current` is NULL and its leprob is that floor.
 */
path_head *find_max_path(shead *data, state **states, int num_states, int read_label) {
    path_head *result;
    path *tail = NULL, *node;
    sdata *item;
    state *best_state, *candidate;
    char word[1000], orig_word[1000];
    double best_lprob, lprob;
    int s;

    if (data == NULL || data->string == NULL) return(NULL);

    result = (path_head *) kmalloc(sizeof(path_head));

    /* Walk the data string one word at a time */
    for (item = data->string; item != NULL; item = item->next) {
        strcpy(orig_word, item->word);
        get_word(orig_word, word);

        /* Scan every state for the best emission log prob of this word */
        /* NOTE(review): values >= no_prob are treated elsewhere in this file
           as "cannot emit"; they are not filtered out here - confirm that
           this is intended. */
        best_lprob = -1000000.0;
        best_state = NULL;
        for (s = 0; s < num_states; s++) {
            candidate = states[s];

            /* Start and end states never emit a word */
            if (strcmp(candidate->label, "start") == 0) continue;
            if (strcmp(candidate->label, "end") == 0) continue;

            /* Constrained search: state label must match the word's label */
            if (read_label && strcmp(candidate->label, item->label) != 0) continue;

            lprob = get_emission_lprob(candidate, word);

            if (DEBUG) fprintf(stderr, "word = %s, state = %s, lprob = %f\n", word, candidate->label, lprob);

            if (lprob > best_lprob) {
                best_lprob = lprob;
                best_state = candidate;
            }
        }

        /* Record the winner for this word and append it to the path */
        node = (path *) kmalloc(sizeof(path));
        node->next = NULL;
        node->current = best_state;
        node->word = strdup (orig_word);
        node->ltprob = 0.0;
        node->leprob = best_lprob;

        if (tail != NULL) tail->next = node;
        else result->first = node;
        tail = node;
    }

    return(result);
}

