/* model.c - routines for creating and manipulating HMMs */

#include "general.h"

/*
 * read_strings - read every line of string_file into an array of strings.
 *
 * Each line is tokenised on spaces/newlines. When read_count or read_id is
 * set, the first token is a leading integer field: with read_count it is
 * parsed as the string's count (the program quits if it is not an integer);
 * with only read_id it is skipped and the count is 1. When read_label is set,
 * every word token must be followed by a label token; otherwise every word
 * gets the label "none".
 *
 * Each word has remove_punc() applied and is looked up with
 * bow_word2int_no_add(); words not found are replaced by unk_word, and the
 * program quits if unk_word is not found either.
 *
 * string_file   - open stream; it is read to the end, rewound, and read again
 * p_num_strings - receives the number of lines (strings) read
 * Returns a kmalloc'ed array of num_strings entries followed by a NULL
 * sentinel entry.
 */
shead **read_strings(FILE *string_file, int *p_num_strings, int read_count, int read_label, int read_id) {
    static char rname[]="read_strings";
    int i, count, num_strings, word_id;
    shead **strings;
    sdata *temp, *prev;
    char line[30000], *word, *label, *temp_word;

    /* Words are mapped to the vocabulary - unknown words are mapped to unk_word if included in vocabulary */

    label = NULL;
    if (!read_label) label = strdup("none");

    /* Count the number of strings that are in the file */
    num_strings = 0;
    while (fgets(line, sizeof line, string_file)) {
        num_strings++;
    }
    fprintf(stderr, "%d lines read from data file\n", num_strings);

    /* Initialize string data structure; the extra slot is a NULL sentinel so
       code that scans for NULL never reads an uninitialised pointer */
    strings = (shead **) kmalloc((num_strings+1)*sizeof(shead *));
    for (i = 0; i < num_strings; i++) {
        strings[i] = (shead *) kmalloc(sizeof(shead));
        strings[i]->string = NULL;
        strings[i]->count = 0;
    }
    strings[num_strings] = NULL;

    /* Read in strings */
    rewind(string_file);
    for (i = 0; i < num_strings; i++) {

        if (fgets(line, sizeof line, string_file) == NULL)
            quit(-1, "%s: wasn't able to read as many strings as expected...\n", rname);
        word = strtok(line, " \n");

        if (word == NULL) quit(-1,"%s: first word of line is null...\n", rname);

        if (read_count || read_id) {
            if (read_count) {
                /* A count that fails to parse would otherwise leave count uninitialised */
                if (sscanf(word, "%d", &count) != 1)
                    quit(-1, "%s: could not read a count from token %s\n", rname, word);
                strings[i]->count = count;
            }
            else strings[i]->count = 1;   /* read_id only: the id token is skipped */

            word = strtok(NULL, " \n");
        }
        else strings[i]->count = 1;

        prev = NULL;
        while (word != NULL) {

            temp_word = strdup(word);
            remove_punc(temp_word);

            if (read_label) {
                label = strtok(NULL, " \n");
                if (label == NULL)
                    quit(-1, "%s: could not read a label after word %s\n", rname, word);
            }

            temp = (sdata *) kmalloc(sizeof(sdata));
            word_id = bow_word2int_no_add(temp_word);
            if (word_id == -1) {
                word_id = bow_word2int_no_add(unk_word);
                if (word_id == -1) quit(-1, "%s: word %s is OOV, and the unknown word isn't in vocabulary either!\n", rname, temp_word);
                else temp->word = strdup (unk_word);
            }
            else temp->word = strdup (temp_word);

            temp->label = strdup(label);
            temp->next = NULL;

            /* Append to this string's word list */
            if (strings[i]->string == NULL) strings[i]->string = temp;
            if (prev != NULL) prev->next = temp;
            prev = temp;

            free(temp_word);
            word = strtok(NULL, " \n");
        }
        if (DEBUG) fprintf(stderr, "Reached end of one line\n");
    }
    if (DEBUG) fprintf(stderr, "Read in %d strings successfully\n", num_strings);

    *p_num_strings = num_strings;
    if (!read_label) free(label);

    return(strings);
}

/*
 * read_one_string - read the next line of string_file as one string.
 *
 * The line is tokenised on spaces/newlines. When read_id or read_count is
 * set, the first token is a leading integer field. With read_id the count is
 * always 1. With only read_count the token is parsed as the count, and the
 * program quits if it is not an integer. Otherwise the count is 1.
 * When read_label is set, every word must be followed by a label token;
 * otherwise each word gets the label "none". Words are stored verbatim
 * (no vocabulary mapping).
 *
 * Returns a kmalloc'ed string, or NULL at end of file.
 */
shead *read_one_string(FILE *string_file, int read_id, int read_label, int read_count) {
    static char rname[]="read_one_string";
    int string_id;
    shead *start;
    sdata *temp, *prev;
    char *word, *label, line[50000];

    /* Words are not mapped to a vocabulary - no words are mapped to unk_word here */

    /* Use the whole buffer so long lines are not split across calls */
    if (fgets(line, sizeof line, string_file) == NULL) return(NULL);

    label = NULL;
    if (!read_label) label = strdup("none");

    /* Initialize string data structure */
    start = (shead *) kmalloc(sizeof(shead));
    start->string = NULL;
    prev = NULL;

    word = strtok(line, " \n");
    if (word == NULL) quit(-1,"%s: first word of line is null... what's going on?\n", rname);

    /* Read in string */
    if (read_id || read_count) {
        /* NOTE(review): read_id takes precedence over read_count here, whereas
           read_strings() gives read_count precedence - confirm which is intended */
        if (read_id) start->count = 1;
        else if (sscanf(word, "%d", &string_id) == 1) start->count = string_id;
        else quit(-1, "%s: could not read a count from token %s\n", rname, word);
        word = strtok(NULL, " \n");
    }
    else start->count = 1;    /* bogus count - just for initialization */

    while (word != NULL) {

        if (read_label) label = strtok(NULL, " \n");

        if (label == NULL)
            quit(-1, "%s: could not read a label after word %s\n", rname, word);

        if (DEBUG) fprintf(stderr, "word = %s, label = %s\n", word, label);

        temp = (sdata *) kmalloc(sizeof(sdata));
        temp->word = strdup(word);
        temp->label = strdup(label);
        temp->next = NULL;

        /* Append to the word list */
        if (start->string == NULL) start->string = temp;
        if (prev != NULL) prev->next = temp;
        prev = temp;

        word = strtok(NULL, " \n");
    }
    if (!read_label) free(label);

    return(start);
}

/*
 * count_sample_strings - count the samples in string_file.
 *
 * Without read_count this is the number of lines. With read_count it is the
 * sum of the integer count that begins each line; the program quits if a
 * line is blank or its first token is not an integer.
 * The stream is rewound before returning in both cases.
 */
int count_sample_strings(FILE *string_file, int read_count) {
    static char rname[]="count_sample_strings";
    int num_strings, num_samples, i, count;
    char line[30000], *word;

    /* Count the number of strings are in the file */
    num_strings = 0;
    while (fgets(line, sizeof line, string_file)) {
        num_strings++;
    }
    rewind(string_file);
    if (!read_count) return(num_strings);

    /* Read in string counts */
    num_samples = 0;
    for (i = 0; i < num_strings; i++) {

        if (fgets(line, sizeof line, string_file) == NULL)
            quit(-1, "%s: wasn't able to read as many strings as expected...\n", rname);

        /* A blank line yields a NULL token, and sscanf on NULL is undefined */
        word = strtok(line, " \n");
        if (word == NULL) quit(-1, "%s: first word of line is null...\n", rname);
        if (sscanf(word, "%d", &count) != 1)
            quit(-1, "%s: could not read a count from token %s\n", rname, word);

        num_samples += count;
    }
    rewind(string_file);

    return(num_samples);
}


/*
 * map_words_to_unk - normalise the words of new_string against the vocabulary.
 *
 * remove_punc() is applied to every word in place. A word that
 * bow_word2int_no_add() does not find is freed and replaced by a copy of
 * unk_word; the program quits if unk_word itself is not found.
 */
void map_words_to_unk(shead *new_string) {
    static char rname[]="map_words_to_unk";
    sdata *node;

    for (node = new_string->string; node != NULL; node = node->next) {
        remove_punc(node->word);

        if (bow_word2int_no_add(node->word) != -1) continue;   /* known word: keep it */

        if (bow_word2int_no_add(unk_word) == -1) {
            quit(-1, "%s:  word %s is OOV, and the unknown word isn't in vocabulary either!\n", rname, node->word);
        }
        else {
            free(node->word);
            node->word = strdup(unk_word);
        }
    }
}

/*
 * free_string - release a string built by read_strings()/read_one_string():
 * every node's word and label, the nodes, and the head itself.
 * A NULL start (what read_one_string() returns at end of file) is ignored.
 */
void free_string(shead *start) {
    static char rname[]="free_string";
    sdata *node, *next;

    if (start == NULL) return;

    for (node = start->string; node != NULL; node = next) {
        next = node->next;
        free(node->word);
        free(node->label);
        free(node);
    }
    free(start);
}


/*
 * create_initial_model_from_data - build the initial HMM: a start state, an
 * end state, and one linear path of states between them for each string not
 * already represented in the model.
 *
 * strings          - num_strings strings; scanning also stops at the first
 *                    NULL entry
 * num_of_states    - receives the number of states, including start and end
 * trans_uni_alpha  - passed to add_trans() for every transition created
 * obs_uni_alpha    - passed to create_distribution() for every word state
 * p_initial, p_end - receive the start and end states
 * priors           - optional (may be NULL) list of num_prior_dists priors,
 *                    matched against each word's label via matching_prior()
 * vocab_size       - passed to create_distribution()
 *
 * If path_already_exists() reports a string as already present, only its
 * count is incremented and no states are added.
 */
void create_initial_model_from_data(shead **strings, int *num_of_states, int num_strings, double trans_uni_alpha, double obs_uni_alpha, state **p_initial, state **p_end, count_dist **priors, int num_prior_dists, int vocab_size) {
    static char rname[]="create_initial_model_from_data";
    int num_states, i, n;
    double prob;
    char dist_label[200], prior_label[200], word[200];
    state *start, *end, *temp, *prev;
    sdata *current;
    multinomial *start_null_dist, *end_null_dist, *new_dist;
    count_dist *prior_dist;
    static char null_word[] = { "null" };

    /* Create start and end states */
    start_null_dist = create_distribution(null_word, NULL, NULL, 0, NULL, vocab_size, 0, 0);
    end_null_dist = create_distribution(null_word, NULL, NULL, 0, NULL, vocab_size, 0, 0);
    start = create_state(1, "start", start_null_dist, 1);
    end = create_state(0, "end", end_null_dist, 1);
    *p_initial = start;
    *p_end = end;
    if (DEBUG) fprintf(stderr, "Created initial and final states...\n");

    /* Read strings from string data structure */
    num_states = 1;
    for (i = 0; i < num_strings; i++) {
        if (strings[i] == NULL) { break; }

        if (DEBUG) print_observation_string(strings[i]);

        /* See if string already exists in model */
        if (path_already_exists(start, strings[i])) {
            if (DEBUG) fprintf(stderr, "Duplicate line detected...\n");
            strings[i]->count++;
            continue;
        }

        if (DEBUG) fprintf(stderr, "Adding line to model\n");

        /* Add one state per word; prev is always the last state on this path */
        prev = start;
        for (current = strings[i]->string; current != NULL; current = current->next) {
            num_states++;

            /* Bounded copies: words and labels come from arbitrarily long input tokens */
            n = snprintf(word, sizeof word, "%s", current->word);
            if (n < 0 || (size_t) n >= sizeof word)
                quit(-1, "%s: word %s is longer than %d characters\n", rname, current->word, (int) sizeof word - 1);

            snprintf(dist_label, sizeof dist_label, "d%d", num_states);
            n = snprintf(prior_label, sizeof prior_label, "%s", current->label);
            if (n < 0 || (size_t) n >= sizeof prior_label)
                quit(-1, "%s: label %s is longer than %d characters\n", rname, current->label, (int) sizeof prior_label - 1);

            if (priors != NULL) prior_dist = matching_prior(prior_label, priors, num_prior_dists);
            else prior_dist = NULL;

            if (DEBUG) fprintf(stderr, "Creating state %d with label %s, dist label %s, prior label %s\n", num_states, current->label, dist_label, prior_label);

            new_dist = create_distribution(dist_label, NULL, word, strings[i]->count, prior_dist, vocab_size, obs_uni_alpha, 0);

            temp = create_state(num_states, current->label, new_dist, 1);

            /* Set transition count */
            prob = 0.0;
            add_trans(prev, temp, strings[i]->count, trans_uni_alpha, prob);

            prev = temp;
        }

        if (DEBUG) fprintf(stderr, "Reached end of one string\n");

        /* Close the path from its last state (start itself for an empty string) */
        prob = 0.0;
        add_trans(prev, end, strings[i]->count, trans_uni_alpha, prob);
    }
    num_states++; /* For end state */
    fprintf(stderr, "Initial model consists of %d states, from %d strings\n", num_states, num_strings);

    if (DEBUG) fprintf(stderr, "Exiting create_initial_model routine\n");
    *num_of_states = num_states;
}

/*
 * add_string_to_model - append a new linear path for new_string between the
 * existing initial and end states.
 *
 * One state is created per word. State ids (and dist labels "d<id>") continue
 * from *ptr_max_state_label + 1. remove_punc() is applied to a copy of each
 * word before it seeds the state's emission distribution. The word's label
 * selects a prior from priors, which may be NULL. Every transition on the
 * path uses new_string->count.
 *
 * On return, *ptr_max_state_label holds the largest id used. *ptr_num_states
 * has grown by the number of words plus one.
 * NOTE(review): that "+1 for end state" is counted on every call although
 * the end state already exists - confirm callers expect this.
 */
void add_string_to_model(shead *new_string, state *initial, state *end, int *ptr_num_states, int *ptr_max_state_label, count_dist **priors, int num_prior_dists, int vocab_size, double trans_uni_alpha, double obs_uni_alpha) {
    static char rname[]="add_string_to_model";
    int num_new_states, state_id, max_state_label, num_states, n;
    double prob;
    char dist_label[200], prior_label[200], word[200];
    state *temp, *prev;
    sdata *current;
    multinomial *new_dist;
    count_dist *prior_dist;

    num_new_states = 0;
    num_states = *ptr_num_states;
    max_state_label = *ptr_max_state_label;

    /* Add one state per word; prev is always the last state on the new path */
    prev = initial;
    for (current = new_string->string; current != NULL; current = current->next) {
        num_new_states++;

        max_state_label++;
        state_id = max_state_label;

        /* Bounded copies: words and labels come from arbitrarily long input tokens */
        n = snprintf(word, sizeof word, "%s", current->word);
        if (n < 0 || (size_t) n >= sizeof word)
            quit(-1, "%s: word %s is longer than %d characters\n", rname, current->word, (int) sizeof word - 1);
        remove_punc(word);

        snprintf(dist_label, sizeof dist_label, "d%d", state_id);
        n = snprintf(prior_label, sizeof prior_label, "%s", current->label);
        if (n < 0 || (size_t) n >= sizeof prior_label)
            quit(-1, "%s: label %s is longer than %d characters\n", rname, current->label, (int) sizeof prior_label - 1);

        if (priors != NULL) prior_dist = matching_prior(prior_label, priors, num_prior_dists);
        else prior_dist = NULL;

        new_dist = create_distribution(dist_label, NULL, word, new_string->count, prior_dist, vocab_size, obs_uni_alpha, 0);

        if (DEBUG) fprintf(stderr, "Creating state %d with label %s, dist label %s, prior label %s\n", state_id, current->label, dist_label, prior_label);

        temp = create_state(state_id, current->label, new_dist, 1);

        /* Set transition count */
        prob = 0.0;
        add_trans(prev, temp, new_string->count, trans_uni_alpha, prob);

        prev = temp;
    }

    /* Close the path from its last state (initial itself for an empty string) */
    prob = 0.0;
    add_trans(prev, end, new_string->count, trans_uni_alpha, prob);
    num_new_states++; /* For end state */

    num_states += num_new_states;
    if (DEBUG) fprintf(stderr, "Added %d states, now have %d states\n", num_new_states, num_states);

    *ptr_num_states = num_states;
    *ptr_max_state_label = max_state_label;
}

/*
 * read_model_from_file - load a model written by print_model_to_file().
 *
 * Expected layout:
 *   <num_dists>
 *   <dist_label> <dist_path>                      (num_dists lines)
 *   <num_states>
 *   <state_id> <state_label> <duration> <dist_label>   (num_states lines)
 *   <from_state_id> <to_state_id> <prob>          (until end of file)
 * A built-in "null" distribution is always available to state lines. The
 * program quits on a malformed or truncated file, on an unknown distribution
 * or state reference, or when no "start" state or no "end"/"e" state exists.
 *
 * p_num_states - receives the number of states read
 * Returns the state labelled "start" (the last one, if several are).
 */
state *read_model_from_file(char *model_file_path, int *p_num_states) { 
    static char rname[]="read_model_from_file";
    FILE *model_file;
    int num_dists, num_states, i, state_id, dist_index, start, found_e, n_read;
    int from_state_id, from_state_index, to_state_id, to_state_index, dur_value;
    float prob;
    char dist_label[200], dist_path[500], state_label[200];
    multinomial **dists;
    state **states, *start_state;
    static char null_word[] = { "null" };

    start = -1;
    found_e = 0;

    model_file = kopen_r(model_file_path);

    /* Read in distribution info */
    if (fscanf(model_file, "%d\n", &num_dists) != 1 || num_dists < 0)
        quit(-1, "%s: could not read the number of distributions from %s\n", rname, model_file_path);
    fprintf(stderr, "There are %d distributions to be read in:\n", num_dists);

    /* Add one to num_dists for the null distribution */
    num_dists++;
    dists = (multinomial **) kmalloc((num_dists+1)*sizeof(multinomial *));
    if (dists == NULL)
        quit(-1, "%s: could not allocate memory for dists array\n",rname);

    /* Initialize distribution pointers */
    for (i= 0; i < num_dists; i++) {
        dists[i] = NULL;
    }

    /* Create the null distribution */
    dists[0] = create_distribution(null_word, NULL, NULL, 0, NULL, 0, 0, 0);

    /* Read in distribution file name and load distribution; the field widths
       keep each %s conversion inside its buffer */
    for (i = 1; i < num_dists; i++) {
        if (fscanf(model_file, "%199s %499s\n", dist_label, dist_path) != 2)
            quit(-1, "%s: could not read distribution %d of %d from %s\n", rname, i, num_dists - 1, model_file_path);
        fprintf(stderr, "i = %d, %s:\t%s\n", i, dist_label, dist_path);
        dists[i] = create_distribution(dist_label, dist_path, NULL, 0, NULL, 0, 0, 0);
    }

    /* Read in state information */
    if (fscanf(model_file, "%d\n", &num_states) != 1 || num_states < 0)
        quit(-1, "%s: could not read the number of states from %s\n", rname, model_file_path);
    fprintf(stderr, "There are %d states to be created:\n", num_states);
    states = (state **) kmalloc((num_states+1)*sizeof(state *));
    for (i = 0; i < num_states; i++) {
        states[i] = NULL;
    }

    /* Read in list of states and create each */
    for (i = 0; i < num_states; i++) {
        if (fscanf(model_file, "%d %199s %d %199s\n", &state_id, state_label, &dur_value, dist_label) != 4)
            quit(-1, "%s: could not read state %d of %d from %s\n", rname, i + 1, num_states, model_file_path);
        if (DEBUG) fprintf(stderr, "State %d:\tlabel = %s, distribution = %s, duration value = %d\n", state_id, state_label, dist_label, dur_value);
        dist_index = get_dist_index(dists, dist_label, num_dists);
        if (dist_index == -1)
            quit(-1, "%s: no valid distribution was found for label %s\n", rname, dist_label);
        states[i] = create_state(state_id, state_label, dists[dist_index], dur_value);
        if (!strcmp(state_label,"start")) {
            start = i;
            if (DEBUG) fprintf(stderr, "Set start state...\n");
        }
        if (!strcmp(state_label,"end") || !strcmp(state_label,"e")) {
            found_e = 1;
        }
    }
    if (start < 0) quit(-1, "%s: model does not contain a start state with label start!!\n", rname);
    if (!found_e) quit(-1, "%s: model does not contain an end state with label end!!\n",rname);

    /* Create transitions between states. Loop on fscanf's result rather than
       feof(): a line that fails to parse would otherwise be re-added with the
       previous line's values (or loop forever). */
    fprintf(stderr, "Reading in transitions...\n");
    while ((n_read = fscanf(model_file, "%d %d %f\n", &from_state_id, &to_state_id, &prob)) == 3) {
        if (DEBUG) fprintf(stderr, "From %d to %d with prob %f\n", from_state_id, to_state_id, prob);
        from_state_index = get_state_index(states, from_state_id, num_states);
        to_state_index = get_state_index(states, to_state_id, num_states);
        if (from_state_index < 0 || from_state_index >= num_states ||
            to_state_index < 0 || to_state_index >= num_states)
            quit(-1, "%s: transition %d -> %d refers to a state that is not in the model\n", rname, from_state_id, to_state_id);
        if (DEBUG) fprintf(stderr, "add_trans: %d, %d, %f\n", from_state_index, to_state_index, prob);
        add_trans(states[from_state_index], states[to_state_index], 0, 0, (double) prob);
    }
    if (n_read != EOF)
        quit(-1, "%s: malformed transition line in %s\n", rname, model_file_path);
    fclose(model_file);
    fprintf(stderr, "Finished reading in model\n");

    *p_num_states = num_states;

    /* Only the start state is handed back, so the local lookup arrays can go */
    start_state = states[start];
    free(states);
    free(dists);

    return(start_state); 
}

/*
 * print_state_info - dump a human-readable description of one state to stderr.
 *
 * The output contains, in order:
 *  - the state's id, label, minimum duration, seen count, log prior and log
 *    likelihood;
 *  - if the state has an observation distribution, that distribution's
 *    fields (plus its prior's, when present) and one line per entry in its
 *    count list;
 *  - one line per outgoing transition.
 * mode is forwarded to calc_lprob().
 */
void print_state_info(state *current_s, int mode) {
    static char rname[]="print_state_info";
    multinomial *dist;
    tc *entry;
    trans *edge;
    double alpha;
    float prob;

    fprintf(stderr, "State %i:\n", current_s->id);
    fprintf(stderr, "    Label: %s\n", current_s->label);
    fprintf(stderr, "    Minimum duration: %d\n", current_s->duration);
    fprintf(stderr, "    Seen: %d\n", current_s->seen);
    fprintf(stderr, "    Log prior: %f\n", current_s->lprior);
    fprintf(stderr, "    Log likelihood: %f\n", current_s->llikelihood);

    dist = current_s->O;
    if (dist != NULL) {
        fprintf(stderr, "    Observation distribution:\n");
        fprintf(stderr, "       Label: %s\n", dist->label);
        fprintf(stderr, "       Path: %s\n", dist->path);
        fprintf(stderr, "       Num types: %d\n", dist->num_types);
        fprintf(stderr, "       Num tokens: %d\n", dist->num_tokens);
        fprintf(stderr, "       Vocab size: %d\n", dist->vocab_size+1);
        fprintf(stderr, "       Uniform alpha count: %f\n", dist->uni_alpha);
        fprintf(stderr, "       Total count: %f\n", dist->total_count);
        fprintf(stderr, "       Prior weight adjustment: %f\n", dist->prior_weight_adjustment);

        if (dist->prior != NULL) {
            fprintf(stderr, "           Prior label: %s\n", dist->prior->label);
            fprintf(stderr, "           Prior path: %s\n", dist->prior->path);
            fprintf(stderr, "           Prior num types: %d\n", dist->prior->num_types);
            fprintf(stderr, "           Prior num tokens: %f\n", dist->prior->num_tokens);
            fprintf(stderr, "           Prior vocab size: %d\n", dist->prior->vocab_size);
            fprintf(stderr, "           Prior log beta: %f\n", dist->prior->log_beta);
        }

        for (entry = dist->counts; entry != NULL; entry = entry->next) {
            int count = entry->count;

            /* No probability mass at all: report 0 instead of calling calc_lprob */
            if (dist->total_count > 0) prob = exp(calc_lprob(dist, entry->id, mode));
            else prob = 0.0;

            alpha = calc_emis_alpha(dist, entry->id, dist->prior_weight_adjustment);
            fprintf(stderr, "      count = %d\talpha = %f\tprob = %f\n", count, alpha, prob);
        }
    }

    edge = current_s->out;
    if (edge != NULL) fprintf(stderr, "    Transitions to: \n");
    for (; edge != NULL; edge = edge->next_source) {
        alpha = calc_trans_alpha(current_s, edge->alpha);
        fprintf(stderr, "       State: %i\tcount: %d\talpha: %.3f\t prob: %f\n", edge->dest->id, edge->count, alpha, edge->prob);
    }

}

/* ORPHAN - NOTE(review): presumably marks this routine as currently unused; confirm before removing */
/*
 * print_children_of - depth-first walk from current_s that prints every
 * successor not yet marked in printed[] using print_state_info().
 *
 * printed is indexed by state id. Each state is marked before its own
 * children are visited, so cycles terminate. current_s itself is not printed.
 * Assumes printed[] has an entry for every reachable state id - TODO confirm.
 */
void print_children_of(state *current_s, int *printed, int mode) {
    static char rname[]="print_children_of";
    trans *edge;
    state *child;

    for (edge = current_s->out; edge != NULL; edge = edge->next_source) {
        child = edge->dest;
        if (printed[child->id]) continue;

        print_state_info(child, mode);
        printed[child->id] = 1;
        print_children_of(child, printed, mode);
    }
}

/*
 * print_model - print_state_info() for each non-NULL entry of states[0..num_states-1].
 */
void print_model(state **states, int num_states, int mode) {
    static char rname[]="print_model";
    int idx;

    for (idx = 0; idx < num_states; idx++) {
        if (states[idx] == NULL) continue;   /* empty slot */
        print_state_info(states[idx], mode);
    }
}

/*
 * print_model_to_file - write the model as a distribution list, a state list
 * and a transition list, in the layout read_model_from_file() parses.
 *
 * outfile       - stream to write to. If NULL,
 *                 <output_dir>/iter_<iteration>_<num_states>.model.gz is
 *                 opened with kopen_wgz(). The stream is closed before
 *                 returning in both cases.
 * vocab_file, smooth, mode
 *               - passed to print_distribution() for each emission file
 * states        - the num_states states of the model
 * emissions_dir - if non-NULL, each distribution with num_types > 0 is
 *                 written to
 *                 <emissions_dir>/iter_<iteration>_<num_states>_<label>.arpa.gz
 *                 and listed as "<label> <path>" in the model file.
 *                 NOTE(review): when NULL, the distribution count is still
 *                 written but no "<label> <path>" lines follow it; confirm
 *                 that read_model_from_file() is never given such a file.
 */
void print_model_to_file(FILE *outfile, char *output_dir, char *vocab_file, state **states, int num_states, char *emissions_dir, int smooth, int mode, int iteration) {
    static char rname[]="print_model_to_file";
    int i, num_dists, n;
    ltrans *first_trans, *current_trans;
    multinomial **dists;
    char dist_file_path[1000], output_file_path[1000];

    if (outfile == NULL) {
        n = snprintf(output_file_path, sizeof output_file_path, "%s/iter_%d_%d.model.gz", output_dir, iteration, num_states);
        if (n < 0 || (size_t) n >= sizeof output_file_path)
            quit(-1, "%s: model file path under %s is too long\n", rname, output_dir);
        outfile = kopen_wgz(output_file_path);
    }
    else {
        /* No path is known for a caller-supplied stream; this text only feeds the final log line */
        strcpy(output_file_path, "(caller-supplied stream)");
    }

    /* Initialize distribution array - at most there are as many distributions
	as there are states */
    dists = (multinomial **) kmalloc((num_states+1) * sizeof(multinomial *));
    for (i = 0; i < num_states; i++) {
        dists[i] = NULL;
    }

    /* Initialize transitions - don't know how many there will be */
    first_trans = NULL;

    /* Go through model and collect pointers to the states, distributions and transitions */
    for (i = 0; i < num_states; i++) {
        collect_state_info(states[i], num_states, dists, &first_trans);
    }

    /* collect_state_info() fills dists front to back, so count up to the first NULL */
    num_dists = 0;
    for (i = 0; i < num_states; i++) {
        if (dists[i] != NULL) num_dists++;
        else break;
    }
    if (DEBUG) fprintf(stderr, "%s: Found %d distributions\n", rname, num_dists);

    /* Print distribution information to output file */
    fprintf(outfile, "%d\n", num_dists);
    for (i = 0; i < num_dists; i++) {
        if (emissions_dir != NULL) {
            n = snprintf(dist_file_path, sizeof dist_file_path, "%s/iter_%d_%d_%s.arpa.gz", emissions_dir, iteration, num_states, dists[i]->label);
            if (n < 0 || (size_t) n >= sizeof dist_file_path)
                quit(-1, "%s: emission file path under %s is too long\n", rname, emissions_dir);
            fprintf(outfile, "%s %s\n", dists[i]->label, dist_file_path);

            fprintf(stderr, "Printing %s\n", dist_file_path);

            print_distribution(dist_file_path, vocab_file, dists[i], smooth, mode);
        }
    }

    /* Print state information to output file */
    fprintf(outfile, "%d\n", num_states);
    for (i = 0; i < num_states; i++) {
        fprintf(outfile, "%d %s %d %s\n", states[i]->id, states[i]->label, states[i]->duration, states[i]->O->label);
    }

    /* Print transition information to output file */
    for (current_trans = first_trans; current_trans != NULL; current_trans = current_trans->next) {
        fprintf(outfile, "%d %d %.20f\n", current_trans->trans->source->id, current_trans->trans->dest->id, current_trans->trans->prob);
    }

    fclose(outfile);
    free(dists);
    free_linked_transitions(first_trans);

    fprintf(stderr, "Printed model to file %s\n", output_file_path);

}

/*
 * print_distribution - smooth O's counts with smooth_counts() and write the
 * result to dist_file_path with print_abbrev_arpa_unigram().
 * The file is opened with kopen_wgz(). Nothing is written when O has no
 * counts.
 */
void print_distribution(char *dist_file_path, char *vocab_file, multinomial *O, int smooth, int mode) {
    static char rname[]="print_distribution";
    FILE *out;
    double *probs;
    double zeroton_lprob;

    if (O->counts == NULL) return;   /* empty distribution: nothing to print */

    probs = smooth_counts(O, O->vocab_size, smooth, mode, &zeroton_lprob);

    out = kopen_wgz(dist_file_path);
    print_abbrev_arpa_unigram(out, vocab_file, probs, O->vocab_size, zeroton_lprob);
    fclose(out);

    free(probs);
}


/*
 * smooth_counts - turn O's counts into a probability vector.
 *
 * smooth selects the method: 1 = absolute_discounting(),
 * 2 = three_way_linear_interpolation(), anything else = maximum_likelihood().
 * The chosen routine's result is returned unchanged. Each routine is also
 * given ptr_zeroton_lprob as an output parameter.
 */
double *smooth_counts(multinomial *O, int vocab_size, int smooth, int mode, double *ptr_zeroton_lprob) {
    static char rname[]="smooth_counts";

    switch (smooth) {
    case 1:
        return absolute_discounting(O, vocab_size, mode, ptr_zeroton_lprob);
    case 2:
        return three_way_linear_interpolation(O, vocab_size, mode, ptr_zeroton_lprob);
    default:
        return maximum_likelihood(O, vocab_size, mode, ptr_zeroton_lprob);
    }
}



/*
 * collect_state_info - gather one state's output distribution and outgoing
 * transitions for print_model_to_file().
 *
 * If current_s->O has num_types > 0, it is stored in the first free slot of
 * dists, unless a distribution with the same label is already there. dists
 * holds num_states slots, and the program quits if none is free. Every
 * outgoing transition of current_s is then appended, in order, to the linked
 * list *first_trans.
 */
void collect_state_info(state *current_s, int num_states, multinomial **dists, ltrans **first_trans) {
    static char rname[]="collect_state_info";
    trans *edge;
    ltrans *node, **tail;
    int slot;

    if (DEBUG) fprintf(stderr, "Collecting state info for state %d\n", current_s->id);

    if (current_s->O->num_types > 0) {
        /* Slots are filled front to back, so stop at the first empty slot or a label match */
        for (slot = 0; slot < num_states; slot++) {
            if (dists[slot] == NULL) {
                dists[slot] = current_s->O;
                if (DEBUG) fprintf(stderr, "Added distribution to array with index %d\n", slot);
                break;
            }
            if (!strcmp(dists[slot]->label, current_s->O->label)) {
                if (DEBUG) fprintf(stderr, "Already processed distribution - returning...\n");
                break;
            }
        }
        if (slot >= num_states) quit(-1, "%s: not enough space in distribution array...\n",rname);
    }

    /* Walk to the link field that terminates the list */
    tail = first_trans;
    while (*tail != NULL) tail = &(*tail)->next;

    /* Append all transitions out of this state */
    for (edge = current_s->out; edge != NULL; edge = edge->next_source) {
        node = (ltrans *) kmalloc(sizeof(ltrans));
        node->trans = edge;
        node->next = NULL;

        *tail = node;
        tail = &node->next;

        if (DEBUG) fprintf(stderr, "Added transition from %d to %d\n", node->trans->source->id, node->trans->dest->id);
    }
}


/*
 * free_linked_transitions - free every node of an ltrans list. The trans
 * objects the nodes point to are not freed.
 */
void free_linked_transitions(ltrans *first_trans) {
    static char rname[]="free_linked_transitions";
    ltrans *victim;

    while (first_trans != NULL) {
        victim = first_trans;
        first_trans = first_trans->next;
        free(victim);
    }
}

/*
 * set_state_prior_adjustment - rescale a state's Dirichlet prior weights.
 *
 * Sets s1->trans_prior_weight_adjustment to (sum of outgoing transition
 * alphas) / dir_prior_weight. Unless trans_only is set, it also sets
 * s1->O->prior_weight_adjustment to (sum of emission alphas) /
 * dir_prior_weight. The emission sum runs over the count-list entries when
 * narrow_emis is set, otherwise over ids 0..vocab_size inclusive.
 * Does nothing when dir_prior_weight is 0. The program quits if either sum
 * is 0.
 */
void set_state_prior_adjustment(state *s1, int trans_only, int narrow_emis, double dir_prior_weight) {
    static char rname[]="set_state_prior_adjustment";
    trans *edge;
    tc *entry;
    int w;
    double sum;

    if (dir_prior_weight == 0) return;

    /* Transition parameters - Narrow */
    sum = 0.0;
    for (edge = s1->out; edge != NULL; edge = edge->next_source)
        sum += edge->alpha;

    if (sum == 0) quit(-1, "%s: error - no transition counts out of state %d (%s)\n", rname, s1->id, s1->label);

    s1->trans_prior_weight_adjustment = sum / dir_prior_weight;
    if (DEBUG) fprintf(stderr, "%s: state %d (%s) trans prior weight adjustment = %f\n", rname, s1->id, s1->label, s1->trans_prior_weight_adjustment);

    if (trans_only) return;

    sum = 0.0;
    if (narrow_emis) {
        /* Narrow: only ids present in the count list */
        for (entry = s1->O->counts; entry != NULL; entry = entry->next)
            sum += calc_emis_alpha(s1->O, entry->id, 1.0);
    }
    else {
        /* Broad: every id 0..vocab_size */
        for (w = 0; w <= s1->O->vocab_size; w++)
            sum += calc_emis_alpha(s1->O, w, 1.0);
    }

    if (sum == 0) quit(-1, "%s: error - no emission counts out of state %d (%s)\n", rname, s1->id, s1->label);

    s1->O->prior_weight_adjustment = sum / dir_prior_weight;
}


/*
 * set_state_parameters - recompute the transition probabilities out of s1
 * and, unless trans_only is set, the emission normaliser s1->O->total_count.
 *
 * mode selects the estimator applied to each count with its alpha:
 *   1 = MAP  : count + alpha - 1
 *   2 = MEAN : count + alpha
 *   3 = ML   : count
 *   4 = integration over all parameters - currently estimated exactly like
 *       ML (mode 3)
 * The program quits on any other mode when s1 has outgoing transitions, and
 * quits if a transition probability comes out negative.
 *
 * Transition alphas come from calc_trans_alpha(). Emission alphas come from
 * calc_emis_alpha() with s1->O->prior_weight_adjustment. narrow_emis sums
 * emissions over the count-list entries only; otherwise it sums over ids
 * 0..vocab_size inclusive.
 */
void set_state_parameters(state *s1, int mode, int trans_only, int narrow_emis) {
    static char rname[]="set_state_parameters";
    int j, count;
    trans *current_t;
    tc *temp;
    double alpha, total, prob;

    /* Transition parameters - Narrow: normalise over the existing outgoing transitions */
    total = 0.0;
    for (current_t = s1->out; current_t != NULL; current_t = current_t->next_source) {
        alpha = calc_trans_alpha(s1, current_t->alpha);
        if (mode == 1) total += (current_t->count + alpha - 1);
        else if (mode == 2) total += (current_t->count + alpha);
        else if (mode == 3 || mode == 4) total += current_t->count;
    }

    /* NOTE(review): if total is 0 the divisions below divide by zero and the
       result is stored unchecked - confirm callers never reach that case. */
    prob = 0.0;
    for (current_t = s1->out; current_t != NULL; current_t = current_t->next_source) {
        alpha = calc_trans_alpha(s1, current_t->alpha);
        if (mode == 1) prob = ((double) current_t->count + alpha - 1) / total;
        else if (mode == 2) prob = ((double) current_t->count + alpha) / total;
        else if (mode == 3 || mode == 4) prob = ((double) current_t->count) / total;
        else quit(-1, "%s: error - unknown estimation mode %d\n", rname, mode);

        if (prob < 0) quit(-1, "%s: error - transition prob is < 0 (prob = %f) - try a larger alpha\n", rname, prob);
        else current_t->prob = prob;
    }

    if (trans_only) return;

    total = 0;
    if (narrow_emis) { /* Narrow emission parameters */
        for (temp = s1->O->counts; temp != NULL; temp = temp->next) {
            count = temp->count;
            alpha = calc_emis_alpha(s1->O, temp->id, s1->O->prior_weight_adjustment);

            if (DEBUG) fprintf(stderr, "dist label = %s, mode = %d, total count = %f, count = %d, alpha = %f\n", s1->O->label, mode, total, count, alpha);

            if (mode == 1) total += count + alpha - 1;
            else if (mode == 2) total += count + alpha;
            else if (mode == 3 || mode == 4) total += count;
        }
    }
    else { /* Broad emission parameters */
        for (j = 0; j <= s1->O->vocab_size; j++) {
            count = retrieve_count(s1->O,j);
            alpha = calc_emis_alpha(s1->O, j, s1->O->prior_weight_adjustment);
            if (mode == 1) total += count + alpha - 1;
            else if (mode == 2) total += count + alpha;
            else if (mode == 3 || mode == 4) total += count;
        }
    }
    s1->O->total_count = total;
}


/*
 * set_model_parameters - estimate parameters for every state in the model.
 *
 * States labelled "end" are skipped. States labelled "start" get transition
 * parameters only. When dir_prior_weight > 0, each state's prior weights are
 * first rescaled with set_state_prior_adjustment(). Prints a warning in
 * mode 4, which is estimated like ML.
 */
void set_model_parameters(state **states, int num_states, int mode, int narrow_emis, double dir_prior_weight) {
    static char rname[]="set_model_parameters";
    state *s;
    int idx, trans_only;

    if (mode == 4) fprintf(stderr, "%s: Warning - using ML estimates despite being in mode 4...\n", rname);

    for (idx = 0; idx < num_states; idx++) {
        s = states[idx];

        if (DEBUG) fprintf(stderr, "%s: State = %d (%s), dir prior weight = %f\n", rname, s->id, s->label, dir_prior_weight);

        if (strcmp(s->label, "end") == 0) continue;
        trans_only = (strcmp(s->label, "start") == 0);

        if (dir_prior_weight > 0) set_state_prior_adjustment(s, trans_only, narrow_emis, dir_prior_weight);
        set_state_parameters(s, mode, trans_only, narrow_emis);
    }
}

/*
 * set_model_priors_and_likelihoods - refresh lprior and llikelihood on every
 * state.
 *
 * lprior comes from compute_state_prior() when use_prior > 0 and is 0
 * otherwise. llikelihood always comes from compute_state_likelihood().
 */
void set_model_priors_and_likelihoods(state **states, int num_states, int mode, int use_prior, int narrow_emis) {
    static char rname[]="set_model_priors_and_likelihoods";
    state *s;
    int idx;

    for (idx = 0; idx < num_states; idx++) {
        s = states[idx];

        if (use_prior > 0) {
            s->lprior = compute_state_prior(s, num_states, mode, use_prior, narrow_emis);
        } else {
            s->lprior = 0.0;
        }

        s->llikelihood = compute_state_likelihood(s, num_states, mode, narrow_emis);

        if (DEBUG) fprintf(stderr, "%s: state %d (%s): lprior = %f, llikelihood = %f\n", rname, s->id, s->label, s->lprior, s->llikelihood);
    }
}

/*
 * update_model_parameters - re-estimate s1 and the transition parameters of
 * every state with a transition into s1.
 *
 * Nothing happens for the "end" state. The "start" state gets transition
 * parameters only. Parents (other than s1 itself) are always re-estimated
 * transition-only. When dir_prior_weight > 0, prior weights are rescaled
 * before each re-estimation.
 */
void update_model_parameters(state *s1, int mode, int narrow_emis, double dir_prior_weight) {
    static char rname[]="update_model_parameters";
    trans *in_edge;
    state *parent;
    int trans_only;

    if (strcmp(s1->label, "end") == 0) return;
    trans_only = (strcmp(s1->label, "start") == 0);

    if (dir_prior_weight > 0) set_state_prior_adjustment(s1, trans_only, narrow_emis, dir_prior_weight);
    set_state_parameters(s1, mode, trans_only, narrow_emis);

    /* Update parent transition settings into state */
    for (in_edge = s1->in; in_edge != NULL; in_edge = in_edge->next_dest) {
        parent = in_edge->source;
        if (parent == s1) continue;   /* self-loop: s1 was already updated above */

        if (dir_prior_weight > 0) set_state_prior_adjustment(parent, 1, narrow_emis, dir_prior_weight);
        set_state_parameters(parent, mode, 1, narrow_emis);
    }
}

/*
 * update_model_priors_and_likelihoods - refresh lprior and llikelihood on s1
 * and on every other state with a transition into s1.
 *
 * lprior comes from compute_state_prior() when use_prior > 0 and is 0
 * otherwise. llikelihood comes from compute_state_likelihood().
 */
void update_model_priors_and_likelihoods(state *s1, int num_states, int mode, int use_prior, int narrow_emis) {
    static char rname[]="update_model_priors_and_likelihoods";
    trans *in_edge;
    state *parent;

    if (use_prior > 0) {
        s1->lprior = compute_state_prior(s1, num_states, mode, use_prior, narrow_emis);
    } else {
        s1->lprior = 0.0;
    }
    s1->llikelihood = compute_state_likelihood(s1, num_states, mode, narrow_emis);

    if (DEBUG) fprintf(stderr, "New state %d (%s): lprior = %f, llikelihood = %f\n", s1->id, s1->label, s1->lprior, s1->llikelihood);

    /* Update parents of state */
    for (in_edge = s1->in; in_edge != NULL; in_edge = in_edge->next_dest) {
        parent = in_edge->source;
        if (parent == s1) continue;   /* self-loop: s1 was refreshed above */

        if (use_prior > 0) {
            parent->lprior = compute_state_prior(parent, num_states, mode, use_prior, narrow_emis);
        } else {
            parent->lprior = 0.0;
        }
        parent->llikelihood = compute_state_likelihood(parent, num_states, mode, narrow_emis);

        if (DEBUG) fprintf(stderr, "Parent state %d (%s): lprior = %f, llikelihood = %f\n", parent->id, parent->label, parent->lprior, parent->llikelihood);
    }
}

/* Debug helper: print each (word, label) pair of an observation string to
   stderr, one per line.

   start: string to print; a NULL string prints nothing.

   Labels may legitimately be NULL (callers test label != NULL), and passing a
   NULL pointer to printf's %s is undefined behaviour, so NULL words/labels
   are printed as "(null)". */
void print_observation_string(shead *start) {
    sdata *next;

    if (start == NULL) return;

    for (next = start->string; next != NULL; next = next->next) {
        fprintf(stderr, "word: %s label %s\n",
                next->word != NULL ? next->word : "(null)",
                next->label != NULL ? next->label : "(null)");
    }
}

/* Set the model parameters, write the model out with print_model_to_file,
   and terminate the process. This function never returns.

   NOTE(review): the process exits with status 1 even on this success path;
   confirm callers/scripts rely on that before changing it. */
void run_initial_model_exit(char *output_dir, char *vocab_file, state **states, int num_states, char *emissions_dir, int smooth, int mode, int narrow_emis, int iteration, double dir_prior_weight) {
    fputs("Printing model to file and exiting...\n", stderr);

    set_model_parameters(states, num_states, mode, narrow_emis, dir_prior_weight);
    print_model_to_file(NULL, output_dir, vocab_file, states, num_states,
                        emissions_dir, smooth, mode, iteration);

    exit(1);
}

/* Evaluate the current model on an observation file.

   1. For every state whose emission counts are non-NULL, build a table of
      log emission probabilities (O->lprobs, vocab_size+1 entries) from
      smooth_counts(); zero probabilities are stored as no_prob.
   2. Viterbi-tag every string of obs_file_path (read without ids or labels)
      and write the tagged output to
      <output_dir>/iter_<iteration>_<num_states>.hyp.gz.
   3. Append "<num_states> <perplexity>" to <output_dir>/pp.scores, where the
      perplexity is exp(-total log prob / number of observations) as
      accumulated by collect_pp_stats. If no observations were scored, no
      line is written and a warning goes to stderr (avoids dividing by 0).
   4. Free O->lprobs of every state and reset it to NULL.

   Quits if either output path would not fit in its buffer. */
void evaluate_intermediate_model(char *obs_file_path, char *output_dir, state *initial, state **states, int num_states, int vocab_size, int smooth, int mode, int iteration) {
    static char rname[]="evaluate_intermediate_model";
    FILE *string_file, *hyp_file, *pp_file;
    char hyp_file_path[1000], pp_file_path[1000];
    path_head   *vit_path;
    shead       *start;
    double zeroton_lprob, *double_probs, total_obs_logprob,  total_t_logprob, total_e_logprob;
    double pe_logprob, pt_logprob, pp_e, pp_t, pf_logprob, pp;
    int i, j, total_obs, path_len;
    int string_count = 0;
    int obs_with_id = 0;
    int read_label = 0;
    int vit_details = 0;
    int punc_trans = 0;
    int print_state_id = 0;
    int print_probs = 0;
    float trans_weight = 1.0;

    /* Build output paths with bounded formatting: output_dir is caller
       supplied and may be longer than the fixed buffers. */
    path_len = snprintf(hyp_file_path, sizeof hyp_file_path, "%s/iter_%d_%d.hyp.gz", output_dir, iteration, num_states);
    if (path_len < 0 || (size_t) path_len >= sizeof hyp_file_path)
        quit(-1, "%s: hypothesis file path too long (output dir %s)\n", rname, output_dir);

    path_len = snprintf(pp_file_path, sizeof pp_file_path, "%s/pp.scores", output_dir);
    if (path_len < 0 || (size_t) path_len >= sizeof pp_file_path)
        quit(-1, "%s: perplexity file path too long (output dir %s)\n", rname, output_dir);

    /* Open observation file */
    string_file = kopen_r(obs_file_path);

    /* Open output files (pp.scores is appended to) */
    hyp_file = kopen_wgz(hyp_file_path);
    pp_file = kopen_a(pp_file_path);

    /* Create smoothed probability distributions */
    for (i = 0; i < num_states; i++) {
        if (states[i]->O->counts == NULL) continue;

        states[i]->O->lprobs = (float *) kmalloc((vocab_size + 1)*sizeof(float));

        double_probs = smooth_counts(states[i]->O, vocab_size, smooth, mode, &zeroton_lprob);

        /* Store log probabilities as floats; log(0) is replaced by no_prob */
        for (j = 0; j <= vocab_size; j++) {
            if (double_probs[j] == 0) states[i]->O->lprobs[j] = no_prob;
            else states[i]->O->lprobs[j] = (float) log(double_probs[j]);
        }
        free(double_probs);
    }

    /* Run Viterbi over the observation file */
    fprintf(stderr, "Processing observations - printing one '.' for each 100 strings processed\n");

    total_obs = 0;
    total_obs_logprob = 0.0;
    total_e_logprob = 0.0;
    total_t_logprob = 0.0;

    /* Tag strings as long as there are strings in file */
    start = read_one_string(string_file, obs_with_id, read_label, 0);
    while (start != NULL) {

        /* Find Viterbi path for string */
        vit_path = find_vit_path(start, initial, vit_details, punc_trans, read_label, trans_weight);
        print_path_to_file(vit_path, hyp_file, start, obs_with_id, print_state_id, print_probs);
        collect_pp_stats(vit_path, &total_obs, &total_obs_logprob, &total_t_logprob, &total_e_logprob);

        free_path(vit_path);
        fflush(hyp_file);
        free_string(start);
        string_count++;

        if ((string_count % 100) == 0) fprintf(stderr, ".");
        if ((string_count % 1000) == 0) fprintf(stderr, "\n");

        start = read_one_string(string_file, obs_with_id, read_label, 0);
    }
    fclose(hyp_file);
    fclose(string_file);

    /* Perplexities are per-observation averages; guard against an empty
       file (or no scored observations), which would divide by zero. */
    if (total_obs > 0) {
        pe_logprob = total_e_logprob / total_obs;
        pp_e = exp(-pe_logprob);

        pt_logprob = total_t_logprob / total_obs;
        pp_t = exp(-pt_logprob);

        pf_logprob = total_obs_logprob / total_obs;
        pp = exp(-pf_logprob);

        fprintf(pp_file, "%d %.2f\n", num_states, pp);
        if (DEBUG) fprintf(stderr, "    transition part: %f (%f / %d) PP = %.2f\n", pt_logprob, total_t_logprob, total_obs, pp_t);
        if (DEBUG) fprintf(stderr, "    emission part: %f (%f / %d) PP = %.2f\n", pe_logprob, total_e_logprob, total_obs, pp_e);
    }
    else {
        fprintf(stderr, "%s: no observations were scored in %s; perplexity not written\n", rname, obs_file_path);
    }
    fclose(pp_file);

    fprintf(stderr, "DONE!\n");

    /* Free probability distributions */
    for (i = 0; i < num_states; i++) {
        free(states[i]->O->lprobs);
        states[i]->O->lprobs = NULL;
    }
}


int path_already_exists(state *initial, shead *current_string) {
    static char rname[]="path_already_exists";
    state *next_state;
    char word[200];
    trans *current_trans;
    int word_id, found, string_length, step, i, already_on_list;
    sdata *current_string_position;
    ts **possible, *candidate, *new_candidate;

    /* Checks to see if the current string is identical to a string that has
        already been incorporated into the model. If so, the initial transition
        count from the start state should be updated (but not here). */

    /* string length includes the end state but not the start state */
    string_length = count_num_symbols(current_string);
    if (DEBUG) fprintf(stderr, "string length = %d\n", string_length);

    possible = (ts **) kmalloc((string_length+1)*sizeof(ts *));
    for (i = 0; i <= string_length; i++) {
        possible[i] = NULL;
    }

    possible[0] = (ts *) kmalloc(sizeof(ts));
    possible[0]->s1 = initial;
    possible[0]->prev = NULL;
    possible[0]->next = NULL;

    step = 1;
    current_string_position = current_string->string;
    while (current_string_position != NULL) {

        strcpy(word, current_string_position->word);
        remove_punc(word);
        word_id = bow_word2int_no_add(word);
        if (word_id == -1) {

            /* Free linked state structure */
            for (i = 0; i <= string_length; i++) {
                candidate = possible[i];
                while (candidate != NULL) {
                    new_candidate = candidate->next;
                    free(candidate);
                    candidate = new_candidate;
                }
            }
            free(possible);

            return(0);
        }

        if (DEBUG) fprintf(stderr, "current word = %s\n", word);

        /* See if there is a transition from a possible previous state into a state that emits the given word */
        found = 0;
        candidate = possible[step-1];
        while (candidate != NULL) {
            current_trans = candidate->s1->out;
            while (current_trans != NULL) {
                next_state = current_trans->dest;

                if (DEBUG) fprintf(stderr, "   considering transition from state %d (%s) into state %d (%s)\n", candidate->s1->id, candidate->s1->label, next_state->id, next_state->label);

                if (current_string_position->label != NULL) {
                    if (strcmp(current_string_position->label, next_state->label)) {
                        current_trans = current_trans->next_source;
                        continue;
                    }
                }

                if (retrieve_count(next_state->O, word_id) > 0) {
                    found = 1;

                    if (DEBUG) fprintf(stderr, "      found match...\n");

                    /* See if this state is already on the list of possible states at this time step */
                    new_candidate = possible[step];
                    already_on_list = 0;
                    while (new_candidate != NULL) {
                        if (new_candidate->s1->id == next_state->id ) {
                            already_on_list = 1;
                            if (DEBUG) fprintf(stderr, "         state is already on list\n");
                            break;
                        }
                        new_candidate = new_candidate->next;
                    }

                    /* Add this state to the list of possible states at this time step */
                    if (!already_on_list) {
                        if (DEBUG) fprintf(stderr, "         adding state to list\n");
                        new_candidate = possible[step];
                        if (new_candidate == NULL) {
                            possible[step] = (ts *) kmalloc(sizeof(ts));
                            possible[step]->s1 = next_state;

                            /* It's ok to do this, because there won't be two ways 
                               to get to one state in the initial model. Otherwise,
                               there may need to be backpointers to multiple states */
                            possible[step]->prev = candidate;
                            possible[step]->next = NULL;
                        }
                        else {
                            while (new_candidate->next != NULL) {
                                new_candidate = new_candidate->next;
                            }
                            new_candidate->next = (ts *) kmalloc(sizeof(ts));
                            new_candidate->next->s1 = next_state;
                            new_candidate->next->prev = candidate;
                            new_candidate->next->next = NULL;
                        }
                    }
                }
                current_trans = current_trans->next_source;
            }
            candidate = candidate->next;
        }

        if (!found) {
            /* Free linked state structure */
            for (i = 0; i <= string_length; i++) {
                candidate = possible[i];
                while (candidate != NULL) {
                    new_candidate = candidate->next;
                    free(candidate);
                    candidate = new_candidate;
                }
            }
            free(possible);
            return(0);
        }
        current_string_position = current_string_position->next;
        step++;
    }

    if (DEBUG) fprintf(stderr, "Looking for a transition into the end state\n");

    if (step != string_length) quit(-1, "%s: step = %d, string_length = %d\n", rname, step, string_length);

    /* Make sure there is a transition into the end state */
    found = 0;
    candidate = possible[string_length-1];
    while (candidate != NULL) {
        current_trans = candidate->s1->out;
        while (current_trans != NULL) {
            next_state = current_trans->dest;

            if (!strcmp(next_state->label, "end")) {
                found = 1;
                possible[string_length] = (ts *) kmalloc(sizeof(ts));
                possible[string_length]->s1 = next_state;
                possible[string_length]->prev = candidate;
                possible[string_length]->next = NULL;

                if (DEBUG) fprintf(stderr, "    found transition into end state...\n");
                break;
            }

            current_trans = current_trans->next_source;
        }
        if (found) break;
        candidate = candidate->next;
    }

    if (!found) {
        /* Free linked state structure */
        for (i = 0; i <= string_length; i++) {
            candidate = possible[i];
            while (candidate != NULL) {
                new_candidate = candidate->next;
                free(candidate);
                candidate = new_candidate;
            }
        }
        free(possible);
        return(0);
    }

    /* Print backwards path through model */
    if (DEBUG) fprintf(stderr, "Path through the model = \n");
    candidate = possible[string_length];
    while (candidate != NULL && strcmp(candidate->s1->label, "start")) {
        if (DEBUG) fprintf(stderr, "    state %d (%s) \n", candidate->s1->id, candidate->s1->label);

        if (strcmp(candidate->s1->label, "end")) update_emission_count(candidate->s1, current_string->count);
        update_transition_count(candidate->prev->s1, candidate->s1, current_string->count);

        candidate = candidate->prev;
    }

    /* Free linked state structure */
    for (i = 0; i <= string_length; i++) {
        candidate = possible[i];
        while (candidate != NULL) {
            new_candidate = candidate->next;
            free(candidate);
            candidate = new_candidate;
        }
    }
    free(possible);

}

/* Add `count` to every entry in the emission-count list of state s1
   (s1->O->counts, chained through ->next), and log the update under DEBUG.

   NOTE(review): every entry is incremented, not just the emitted word's. That
   matches a state with a single emission entry, but may over-count states
   that emit several words. Confirm this is intended. */
void update_emission_count(state *s1, int count) {
    tc *entry;

    for (entry = s1->O->counts; entry != NULL; entry = entry->next)
        entry->count += count;

    if (DEBUG) fprintf(stderr, "Updated emission counts from state %d (%s) by %d\n", s1->id, s1->label, count);
}

/* Add `count` to the transition from from_state to to_state.

   The transition is looked up in from_state->out (chained through
   next_source). If no such transition exists, quit() is called with a
   diagnostic. */
void update_transition_count(state *from_state, state *to_state, int count) {
    static char rname[]="update_transition_count";
    trans *t;

    for (t = from_state->out; t != NULL; t = t->next_source) {
        if (t->source != from_state || t->dest != to_state)
            continue;

        t->count += count;
        if (DEBUG) fprintf(stderr, "Updated transition count between states %d (%s) and %d (%s) by %d\n", from_state->id, from_state->label, to_state->id, to_state->label, count);
        return;
    }

    quit(-1, "%s: Couldn't find transition between states %d (%s) and %d (%s)\n", rname, from_state->id, from_state->label, to_state->id, to_state->label);
}

/* Read up to num_to_add more strings from string_file and incorporate each
   into the model.

   For each string: words are mapped with map_words_to_unk. If
   path_already_exists reports that the model can already generate it, that
   call has already credited the counts. Otherwise add_string_to_model adds a
   new path. The string is freed afterwards.

   On return *ptr_num_added holds the number of strings processed (merged or
   added). If end-of-file is reached (feof, or read_one_string returns NULL),
   string_file is closed and *ptr_more_strings is set to 0. Otherwise
   *ptr_more_strings is left untouched. */
void add_incremental_strings(FILE *string_file, int *ptr_num_states, state *initial, state *end, int read_id, int read_label, int read_count, int *ptr_max_state_label, int num_to_add, count_dist  **priors, int num_prior_dists, int vocab_size, double trans_uni_alpha, double obs_uni_alpha, int *ptr_more_strings, int *ptr_num_added) {
    int processed = 0;
    int exhausted = 0;
    shead *str;

    fprintf(stderr, "Incremental: adding more strings...");

    while (processed < num_to_add) {
        /* Stop when the file is drained; feof is checked first so no read is
           attempted once the end has been seen. */
        if (feof(string_file) ||
            (str = read_one_string(string_file, read_id, read_label, read_count)) == NULL) {
            exhausted = 1;
            break;
        }

        map_words_to_unk(str);

        if (path_already_exists(initial, str)) {
            fprintf(stderr, "New string merged with existing model\n");
        }
        else {
            add_string_to_model(str, initial, end, ptr_num_states, ptr_max_state_label, priors, num_prior_dists, vocab_size, trans_uni_alpha, obs_uni_alpha);
            fprintf(stderr, "Added new string to model\n");
        }

        free_string(str);
        processed++;
    }

    if (!exhausted) {
        fprintf(stderr, "%d strings added\n", processed);
        *ptr_num_added = processed;
        return;
    }

    fclose(string_file);
    *ptr_more_strings = 0;
    *ptr_num_added = processed;
    fprintf(stderr, "All strings read...\n");
}







