/* bw.c */

/* Invalid transitions are designated by a prob of -1.0 */

#include "general.h"

/*
 * iterate_bw - re-estimate the parameters of an HMM with Baum-Welch.
 *
 * strings       array of num_strings observation strings; each entry must
 *               have a non-NULL ->string list or the routine quits
 * num_strings   number of entries in strings
 * hmm           first state of the model; it is used as the start state and
 *               is also the value returned
 * num_states    passed by address to collect_states(), which supplies the
 *               state count used from then on
 * details       non-zero: print transition tables and alphas to stderr
 * punc_trans    forwarded to compute_alphas()/compute_betas()
 * uniform       non-zero: start from uniform transition probabilities
 * random        non-zero (and uniform zero): start from random transition
 *               probabilities; the value is also handed on as the seed
 * vocab_size    emission rows hold vocab_size+1 entries
 * trans_only    non-zero: re-estimate transitions only, leave emissions alone
 * emissions_dir when non-NULL, the emission tables and the model description
 *               are written out (the latter to outfile)
 * outfile       destination for print_reestimated_model_to_file()
 *
 * Iteration stops once the per-observation log probability changes by no
 * more than stop_ratio (0.005, absolute difference) between two consecutive
 * iterations; the first iteration never stops the loop.
 *
 * The re-estimated transition probabilities are copied back into the model
 * through copy_model_trans().  The emission table (bj) is local to this
 * routine and is freed before returning; it leaves the routine only through
 * print_emission_dists().
 */
state *iterate_bw(shead **strings, int num_strings, state *hmm, int num_states, int details, int punc_trans, int uniform, int random, int vocab_size, int trans_only, char *emissions_dir, FILE *outfile) {
    static char rname[]="iterate_bw";
    int i, j, k, iter_num, last_state_index, first_state_index, converged, num_obs, total_obs;
    state **states, *old_model, *first_state;
    double **aij, **alphas, **betas, *scale;
    double **num_aij_local, **num_aij_global, **num_bj_local, **num_bj_global, *denom_local, *denom_global;
    double new_pf_logprob, total_obs_logprob, old_pf_logprob, stop_ratio;
    double total_prob, string_logprob, pp;
    char **obs;
    sdata *current_data;
    float **bj;

    stop_ratio = 0.005;

    iter_num = 0;
    old_model = hmm;
    converged = 0;
    total_obs = 0;
    new_pf_logprob = 0.0;
    old_pf_logprob = 0.0;

    /* The emission accumulators are only allocated when emissions are
       re-estimated, but they are always passed to update_parameters();
       give them a defined value so that no indeterminate pointer is read. */
    num_bj_local = NULL;
    num_bj_global = NULL;

    /******************************/
    /* Initialize data structures */
    /******************************/

    states = collect_states(old_model, &num_states);

    /* Identify first and last states of model */
    first_state = old_model;
    first_state_index = get_state_index(states, first_state->id, num_states);
    if (first_state_index == -1) quit(-1, "%s: first state was not located...\n", rname);

    last_state_index = find_last_state_index(states, num_states);
    if (last_state_index == -1) quit(-1, "%s: last state was not located...\n", rname);

    /* Create transition parameter structures */
    aij = (double **) kmalloc(num_states * sizeof(double *));
    for (i = 0; i < num_states; i++) {
        aij[i] = (double *) kmalloc(num_states * sizeof(double));
    }

    /* Set in initial parameter estimates */
    if (uniform) {
        set_uniform_model_parameters(aij, num_states, states);
    }
    else if (random) {
        set_random_model_parameters(aij, num_states, states, random);
    }
    else {
        fill_in_trans_probs(aij, num_states, states);
    }
    if (details) print_trans(aij, num_states, states);

    /* Create emission parameter structures */
    bj = set_emission_probs(num_states, vocab_size, states);

    /* Per-string ("local") and per-iteration ("global") accumulators for the
       numerators and the shared denominator of the re-estimation formulas */
    num_aij_local = (double **) kmalloc(num_states * sizeof(double *));
    for (i = 0; i < num_states; i++) {
        num_aij_local[i] = (double *) kmalloc(num_states * sizeof(double));
    }

    num_aij_global = (double **) kmalloc(num_states * sizeof(double *));
    for (i = 0; i < num_states; i++) {
        num_aij_global[i] = (double *) kmalloc(num_states * sizeof(double));
    }

    if (!trans_only) {
        num_bj_local = (double **) kmalloc(num_states * sizeof(double *));
        for (i = 0; i < num_states; i++) {
            num_bj_local[i] = (double *) kmalloc((vocab_size+1) * sizeof(double));
        }

        num_bj_global = (double **) kmalloc(num_states * sizeof(double *));
        for (i = 0; i < num_states; i++) {
            num_bj_global[i] = (double *) kmalloc((vocab_size+1) * sizeof(double));
        }
    }

    denom_local = (double *) kmalloc(num_states * sizeof(double));
    denom_global = (double *) kmalloc(num_states * sizeof(double));

    /**********************/
    /* Iterate Baum-Welch */
    /**********************/

    while (!converged) {
        iter_num++;

        fprintf(stderr, "ITER NUM = %d:", iter_num);

        initialize_double(num_aij_global, num_states, num_states);
        if (!trans_only) initialize_double(num_bj_global, num_states, (vocab_size+1));
        for (i = 0; i < num_states; i++) {
            denom_global[i] = 0.0;
        }
        total_obs_logprob = 0.0;
        total_obs = 0;

        /* Iterate over individual observation strings */
        for (k = 0; k < num_strings; k++) {

            /* Progress marker: one dot per ten strings */
            if ((k % 10) == 0) fprintf(stderr, ".");

            if (strings[k]->string == NULL) quit(-1, "%s: no data in string, k = %d\n", rname, k);
            num_obs = count_num_symbols(strings[k]);
            total_obs += (num_obs-1);

            /******************************/
            /* Initialize local variables */
            /******************************/

            /* alphas, betas and scale are indexed by step 0..num_obs */
            alphas = (double **) kmalloc((num_obs + 1) * sizeof(double *));
            for (i = 0; i <= num_obs; i++) {
                alphas[i] = (double *) kmalloc(num_states * sizeof(double));
            }
            initialize_double(alphas, num_obs + 1, num_states);

            betas = (double **) kmalloc((num_obs + 1) * sizeof(double *));
            for (i = 0; i <= num_obs; i++) {
                betas[i] = (double *) kmalloc(num_states * sizeof(double));
            }
            initialize_double(betas, num_obs + 1, num_states);

            scale = (double *) kmalloc((num_obs + 1) * sizeof(double));
            for (i = 0; i <= num_obs; i++) {
                scale[i] = 1.0;
            }

            initialize_double(num_aij_local, num_states, num_states);
            if (!trans_only) initialize_double(num_bj_local, num_states, (vocab_size+1));
            for (i = 0; i < num_states; i++) {
                denom_local[i] = 0.0;
            }

            /* Collect all the observation words into an array; the words
               occupy obs[1..num_obs-1], obs[0] and obs[num_obs] stay NULL */
            obs = (char **) kmalloc((num_obs + 1) * sizeof(char *));
            obs[0] = NULL;
            current_data = strings[k]->string;
            for(i = 1; i < num_obs; i++) {
                if (current_data == NULL) quit(-1, "%s: unexpected null observation, i = %d\n", rname, i);
                if (current_data->word == NULL) quit(-1, "%s: unexpected null observation word, i = %d\n", rname, i);
                obs[i] = strdup(current_data->word);
                current_data = current_data->next;
            }
            obs[num_obs] = NULL;

            /******************************/
            /* Calculate alphas and betas */
            /******************************/

            /* Forward pass: compute the alphas */
            compute_alphas(alphas, aij, bj, scale, obs, num_obs, states, num_states, first_state_index, last_state_index, punc_trans, vocab_size);

            if (details) print_alphas(alphas, obs, num_obs, states, num_states, scale);

            /* Backward pass: compute the betas */
            compute_betas(betas, aij, bj, scale, obs, num_obs, states, num_states, first_state_index, last_state_index, punc_trans, vocab_size);

            /* Update aij's and bj's - hmm parameters */
            update_parameters(num_aij_local, num_bj_local, denom_local, states, num_states, alphas, betas, scale, obs, num_obs, aij, bj, first_state_index, last_state_index, details, vocab_size, trans_only);

            /* Add num_aij_local, num_bj_local and denom_local to global parts */
            for (i = 0; i < num_states; i++) {
                denom_global[i] += denom_local[i];
                for (j = 0; j < num_states; j++) {
                    num_aij_global[i][j] += num_aij_local[i][j];
                }

                if (!trans_only) {
                    for (j = 0; j <= vocab_size; j++) {
                        num_bj_global[i][j] += num_bj_local[i][j];
                    }
                }
            }

            /* Compile the total observation probability: the log
               probability of the string is the sum of the log scale factors */
            string_logprob = 0.0;
            for (i = 1; i <= num_obs; i++) {
                string_logprob += log(scale[i]);
            }
            total_obs_logprob += string_logprob;

            /* Free alphas, betas */
            for (i = 0; i <= num_obs; i++) {
                free(alphas[i]);
            }
            free(alphas);

            for (i = 0; i <= num_obs; i++) {
                free(betas[i]);
            }
            free(betas);
            free(scale);
            for(i = 1; i < num_obs; i++) {
                free(obs[i]);
            }
            free(obs);
        }
        fprintf(stderr, "\n");

        /* Divide num_aij_global and num_bj_global by denom_global to get new parameter estimates */
        for (i = 0; i < num_states; i++) {
            total_prob = 0.0;
            for (j = 0; j < num_states; j++) {
                aij[i][j] = num_aij_global[i][j] / denom_global[i];
                total_prob += aij[i][j];
            }
            /* NOTE(review): when denom_global[i] is 0 the quotient is NaN on
               IEEE-754 platforms and the comparison below is false, so such
               a row passes this check silently -- confirm this is intended. */
            if (fabs(total_prob - 1.0) > 0.00001) quit(-1, "%s: Total transition prob out of state %d (%s) sums to %f -- not 1\n", rname, i, states[i]->label, total_prob);

            if (!trans_only)  {
                /* Update emissions for all but the start and end states */
                if (!(!strcmp(states[i]->label,"end") ||
                    !strcmp(states[i]->label,"start"))) {

                    total_prob = 0.0;
                    for (j = 0; j <= vocab_size; j++) {
                        bj[i][j] = num_bj_global[i][j] / denom_global[i];

                        /* May need to put a weird data check here... */

                        total_prob += bj[i][j];
                    }
                    if (fabs(total_prob - 1.0) > 0.00001) quit(-1, "%s: Total emission prob out of state %d (%s) sums to %f -- not 1\n", rname, i, states[i]->label, total_prob);
                }
            }
        }

        /* Convergence test on the per-observation log probability */
        new_pf_logprob = total_obs_logprob / total_obs;
        pp = exp(-new_pf_logprob);
        fprintf(stderr, "ITER %d: TOTAL OBS LOGPROB: %f (%f / %d) PP = %.2f\n", iter_num, new_pf_logprob, total_obs_logprob, total_obs, pp);
        if (iter_num == 1 || fabs(new_pf_logprob - old_pf_logprob) > stop_ratio)
            old_pf_logprob = new_pf_logprob;
        else converged = 1;
    }

    if (details) print_trans(aij, num_states, states);

    /* Redundant check here... (skips states labelled "end" or "e") */
    for (i = 0; i < num_states; i++) {
        if (!(!strcmp(states[i]->label,"end") ||
              !strcmp(states[i]->label,"e"))) {

            total_prob = 0.0;
            for (j = 0; j < num_states; j++) {
                total_prob += aij[i][j];
            }
            if (fabs(total_prob - 1.0) > 0.00001) quit(-1, "%s: extra check - Total transition prob out of state %d (%s) sums to %f -- not 1\n", rname, i, states[i]->label, total_prob);
        }
    }

    /* Copy new parameters to new model */
    copy_model_trans(states, num_states, aij);

    /* If an emissions directory is given, print new distributions */
    if (emissions_dir != NULL) {
        fprintf(stderr, "Printing to %s\n", emissions_dir);
        print_emission_dists(emissions_dir, bj, num_states, vocab_size, states);
        print_reestimated_model_to_file(outfile, states, num_states, emissions_dir);
    }

    /* Free things here... */
    free(states);
    for (i = 0; i < num_states; i++) {
        free(aij[i]);
    }
    free(aij);

    for (i = 0; i < num_states; i++) {
        free(bj[i]);
    }
    free(bj);

    for (i = 0; i < num_states; i++) {
        free(num_aij_local[i]);
        if (!trans_only) free(num_bj_local[i]);
    }
    free(num_aij_local);
    if (!trans_only) free(num_bj_local);

    for (i = 0; i < num_states; i++) {
        free(num_aij_global[i]);
        if (!trans_only) free(num_bj_global[i]);
    }
    free(num_aij_global);
    if (!trans_only) free(num_bj_global);

    free(denom_local);
    free(denom_global);

    fprintf(stderr, "\nTrained model in %d iterations\n", iter_num);

    return (first_state);
}

/*
 * set_uniform_model_parameters - give every existing transition out of a
 * state the same probability.
 *
 * aij is first loaded from the model with fill_in_trans_probs(); every entry
 * of a row that is greater than zero is then replaced by 1/count, where count
 * is the number of such entries in that row.  Entries that are not positive
 * are left untouched, and a row with no positive entry is skipped (which also
 * avoids dividing by zero).
 */
void set_uniform_model_parameters(double **aij, int num_states, state **states) {
    static char rname[]="set_uniform_model_parameters";
    int i, j, count;
    double prob;

    fill_in_trans_probs(aij, num_states, states);
    for (i = 0; i < num_states; i++) {
        count = 0;
        for (j = 0; j < num_states; j++) {
            if (aij[i][j] > 0) {
                count++;
            }
        }
        if (count == 0) continue; /* no outgoing transitions to share */

        prob = 1.0 / count;
        for (j = 0; j < num_states; j++) {
            if (aij[i][j] > 0) {
                aij[i][j] = prob;
            }
        }
    }
}

/*
 * set_random_model_parameters - replace every existing transition
 * probability by a random draw and renormalise each row.
 *
 * The matrix is loaded from the model first; only entries greater than zero
 * are redrawn (with ran0, seeded through the local copy of `random`), and
 * each redrawn row is then divided by the sum of its draws.
 */
void set_random_model_parameters(double **aij, int num_states, state **states, int random) {
    static char rname[]="set_random_model_parameters";
    int row, col;

    fill_in_trans_probs(aij, num_states, states);

    for (row = 0; row < num_states; row++) {
        double *probs = aij[row];
        double row_sum = 0.0;

        /* Draw a fresh value for each transition that exists */
        for (col = 0; col < num_states; col++) {
            if (!(probs[col] > 0)) continue;
            probs[col] = ran0(&random);
            row_sum += probs[col];
        }

        /* Renormalize probs */
        for (col = 0; col < num_states; col++) {
            if (!(probs[col] > 0)) continue;
            probs[col] /= row_sum;
        }
    }
}

/*
 * fill_in_trans_probs - load the transition matrix from the model.
 *
 * aij is zeroed, then for every state the list of outgoing transitions
 * (linked through next_source) is walked and each transition's prob is
 * stored at aij[source][dest].  The row index is looked up once per state,
 * from the source of the first transition on its list.  Quits if a state id
 * cannot be found in `states`.
 */
void fill_in_trans_probs(double **aij, int num_states, state **states) {
    static char rname[]="fill_in_trans_probs";
    int i;

    initialize_double(aij, num_states, num_states);

    for (i = 0; i < num_states; i++) {
        trans *edge;
        int from, to;

        /* Nothing to record for a missing state or one without transitions */
        if (states[i] == NULL || states[i]->out == NULL) continue;

        edge = states[i]->out;
        from = get_state_index(states, edge->source->id, num_states);
        if (from == -1) quit(-1, "%s: state %s was not found on state list\n", rname, states[i]->label);

        for (; edge != NULL; edge = edge->next_source) {
            to = get_state_index(states, edge->dest->id, num_states);
            if (to == -1) quit(-1, "%s: state %s was not found on state list\n", rname, states[i]->label);

            aij[from][to] = edge->prob;
        }
    }
}

/*
 * print_trans - dump a transition matrix to stderr.
 *
 * For each state the label is printed, followed by one line per destination
 * whose entry in `var` is greater than zero.
 */
void print_trans(double **var, int num_states, state **states) {
    static char rname[]="print_trans";
    int src, dst;

    fprintf(stderr, "New transition estimates:\n");

    src = 0;
    while (src < num_states) {
        const double *row = var[src];

        fprintf(stderr, "%s -> \n", states[src]->label);
        for (dst = 0; dst < num_states; dst++) {
            if (!(row[dst] > 0)) continue; /* only show existing transitions */
            fprintf(stderr, "    %s: %.3f\n", states[dst]->label, row[dst]);
        }
        src++;
    }
}

/*
 * initialize_double - set the first tot_j entries of the first tot_i rows
 * of a row-pointer matrix to 0.0.  Nothing is touched when either count is
 * zero or negative.
 */
void initialize_double(double **var, int tot_i, int tot_j) {
    static char rname[]="initialize_double";
    int row, col;

    row = 0;
    while (row < tot_i) {
        /* Clear the row back to front; the final contents are the same */
        col = tot_j;
        while (col-- > 0) {
            var[row][col] = 0.0;
        }
        row++;
    }
}


/*
 * compute_alphas - scaled forward pass over one observation string.
 *
 * alphas   (num_obs+1) x num_states table, expected to be zeroed by the
 *          caller; filled for steps 0..num_obs
 * aij_old  current transition probabilities
 * bj       current emission probabilities (looked up via get_bj_prob)
 * scale    receives one scale factor per step 1..num_obs
 * obs      observation words in obs[1..num_obs-1]
 * punc_trans  when non-zero, a change of state between steps t-1 and t
 *          (2 <= t < num_obs) is only allowed if ends_in_punc() accepts the
 *          word at t-1; otherwise only self-transitions are followed
 *
 * Step 0 puts all mass on first_state_index.  At steps 1..num_obs-1 each
 * alpha row is divided by its largest entry, which is stored as the scale
 * factor.  Step num_obs is the non-emitting transition into
 * last_state_index: its unscaled value becomes scale[num_obs] and the alpha
 * itself is set to 1.
 *
 * The previous observation is read straight from obs[] rather than being
 * copied into a fixed-size scratch buffer, which removes one unbounded
 * strcpy.
 *
 * NOTE(review): `word` is still a fixed 500-byte buffer filled by get_word()
 * with an unbounded strcpy; longer observations would overflow it.
 * NOTE(review): if no state can account for a word, max_alpha is 0 and the
 * division below is 0/0; this case is not handled here.
 */
void compute_alphas(double **alphas, double **aij_old, float **bj, double *scale, char **obs, int num_obs, state **states, int num_states, int first_state_index, int last_state_index, int punc_trans, int vocab_size) {
    static char rname[]="compute_alphas";
    char *word;
    int current_step, i, j, allow_transition;
    double b_j, max_alpha;

    word = (char *) kmalloc(500*sizeof(char));

    /* Only consulted when punc_trans is set, where it is recomputed at
       every step; initialised so that it never holds an indeterminate value */
    allow_transition = 1;

    /* Step 1 - Initialization */
    /* ----------------------- */

    current_step = 1;
    alphas[current_step-1][first_state_index] = 1.0;
    get_word(obs[current_step], word);

    /* Find non-zero transitions out of initial state, and which states could emit word */
    i = first_state_index;
    for (j = 0; j < num_states; j++) {
        if (aij_old[i][j] > 0) {
            b_j = get_bj_prob(bj, states, j, word, vocab_size);
            alphas[current_step][j] = aij_old[i][j] * b_j;
        }
    }

    /* Set scaling factor */
    max_alpha = 0.0;
    for (j = 0; j < num_states; j++) {
        if (alphas[current_step][j] > max_alpha) max_alpha = alphas[current_step][j];
    }
    for (j = 0; j < num_states; j++) {
        alphas[current_step][j] /= max_alpha;
    }
    scale[current_step] = max_alpha;


    /* Step 2 - Recursion */
    /* ------------------ */

    for (current_step = 2; current_step < num_obs; current_step++) {
        if (punc_trans) {
            /* Whether the state may change here depends on the previous word */
            allow_transition = ends_in_punc(obs[current_step-1]);
            if (DEBUG) fprintf(stderr, "Prev word = %s, allow_transitions = %d\n", obs[current_step-1], allow_transition);
        }
        get_word(obs[current_step],word);

        for (j = 0; j < num_states; j++) {
            for (i = 0; i < num_states; i++) {
                if (punc_trans && !allow_transition) {
                    if (i != j) continue; /* self-transitions only */
                }
                if (alphas[current_step-1][i] > 0 && aij_old[i][j] > 0) {
                    alphas[current_step][j] += (alphas[current_step-1][i] * aij_old[i][j]);
                }
            }
            b_j = get_bj_prob(bj, states, j, word, vocab_size);
            alphas[current_step][j] *= b_j;
        }

        /* Set scaling factor */
        max_alpha = 0.0;
        for (j = 0; j < num_states; j++) {
            if (alphas[current_step][j] > max_alpha) max_alpha = alphas[current_step][j];
        }
        for (j = 0; j < num_states; j++) {
            alphas[current_step][j] /= max_alpha;
        }
        scale[current_step] = max_alpha;

    }

    /* Deal with last transition into end state with no emission */
    current_step = num_obs;
    j = last_state_index;
    for (i = 0; i < num_states; i++) {
        if (alphas[current_step-1][i] > 0 && aij_old[i][j] > 0) {
            alphas[current_step][j] += alphas[current_step-1][i] * aij_old[i][j];
        }
    }

    /* Set scaling factor */
    scale[current_step] = alphas[current_step][j];
    alphas[current_step][j] = 1.0; /* scaled alpha is divided by itself */

    free(word);
}

/*
 * print_alphas - dump the forward table to stderr, one section per step
 * 0..num_obs, listing only the states whose alpha is greater than zero.
 *
 * The caller in this file (iterate_bw) stores NULL in obs[0] and
 * obs[num_obs]; passing NULL to %s is undefined, so those steps are printed
 * with the literal text "(null)" instead.
 */
void print_alphas(double **alphas, char **obs, int num_obs, state **states, int num_states, double *scale) {
    int i,current_step;
    const char *word;

    for (current_step = 0; current_step <= num_obs; current_step++) {
        word = (obs[current_step] != NULL) ? obs[current_step] : "(null)";
        fprintf(stderr, "Alphas of observation %d, word = %s, scale = %f\n", current_step, word, scale[current_step]);
        for (i = 0; i < num_states; i++) {
            if (alphas[current_step][i] > 0) {
                fprintf(stderr, "   state = %s, alpha = %.20f\n", states[i]->label, alphas[current_step][i]);
            }
        }
    }
}


/*
 * compute_betas - scaled backward pass over one observation string.
 *
 * betas    (num_obs+1) x num_states table, expected to be zeroed by the
 *          caller; filled for steps num_obs..0
 * scale    scale factors produced by the forward pass; every beta row is
 *          divided by the factor of its own step
 * obs      observation words in obs[1..num_obs-1]
 * punc_trans  when non-zero, a change of state between steps t and t+1
 *          (0 < t < num_obs-1) is only allowed if ends_in_punc() accepts the
 *          word at t; the transitions out of step 0 and into step num_obs
 *          are always allowed
 *
 * Step num_obs starts with 1/scale[num_obs] on last_state_index.  Going
 * backwards, betas[t][i] accumulates aij[i][j] * b_j(word at t+1) *
 * betas[t+1][j]; the emission term is 1 for the final, non-emitting
 * transition.  At step 0 only first_state_index receives mass.
 *
 * ends_in_punc() is consulted only when punc_trans is set, as in
 * compute_alphas; its result is never read otherwise.
 *
 * NOTE(review): `word` is a fixed 500-byte buffer filled by get_word() with
 * an unbounded strcpy; longer observations would overflow it.
 */
void compute_betas(double **betas, double **aij_old, float **bj, double *scale, char **obs, int num_obs, state **states, int num_states, int first_state_index, int last_state_index, int punc_trans, int vocab_size) {
    static char rname[]="compute_betas";
    char *word;
    int current_step, i, j, allow_transition;
    double b_j;

    word = (char *) kmalloc(500*sizeof(char));

    /* Step 1 - Initialization */
    /* ----------------------- */

    current_step = num_obs;
    i = last_state_index;
    betas[current_step][i] = 1.0/scale[current_step];

    /* Step 2 - Recursion */
    /* ------------------ */

    for (current_step = num_obs-1; current_step >= 0; current_step--) {

        if (!punc_trans || current_step == (num_obs - 1) || current_step == 0) allow_transition = 1;
        else allow_transition = ends_in_punc(obs[current_step]);

        /* Word emitted at the next step (there is none at step num_obs) */
        if ((current_step + 1) < num_obs) get_word(obs[current_step+1], word);

        for (j = 0; j < num_states; j++) {
            if (betas[current_step+1][j] > 0) {

                if ((current_step + 1) == num_obs) {
                    b_j = 1.0;
                }
                else {
                    b_j = get_bj_prob(bj, states, j, word, vocab_size);
                    if (b_j == 0.0) continue;
                }
                for (i = 0; i < num_states; i++) {

                    /* See if a transition between source and dest is allowed */
                    if (current_step > 0 && punc_trans && !allow_transition && i != j) {
                        /* Don't allow this transition */
                        continue;
                    }

                    /* If we're at step 0, only accept transitions out of the start state */
                    if (current_step == 0 && i != first_state_index) {
                        /* Don't allow this transition */
                        continue;
                    }

                    if (aij_old[i][j] > 0) {
                        betas[current_step][i] += aij_old[i][j] * b_j * betas[current_step+1][j];
                    }
                }
            }
        }

        /* Incorporate scaling factor */
        for (i = 0; i < num_states; i++) {
            betas[current_step][i] /= scale[current_step];
        }
    }
    free(word);
}

/*
 * print_betas - dump the backward table to stderr, one section per step
 * from num_obs down to 0, listing only the states whose beta is greater
 * than zero.
 *
 * The caller in this file (iterate_bw) stores NULL in obs[0] and
 * obs[num_obs]; passing NULL to %s is undefined, so those steps are printed
 * with the literal text "(null)" instead.
 */
void print_betas(double **betas, char **obs, int num_obs, state **states, int num_states) {
    int i,current_step;
    const char *word;

    for (current_step = num_obs; current_step >= 0; current_step--) {
        word = (obs[current_step] != NULL) ? obs[current_step] : "(null)";
        fprintf(stderr, "Betas of observation %d, word = %s\n", current_step, word);
        for (i = 0; i < num_states; i++) {
            if (betas[current_step][i] > 0) {
                fprintf(stderr, "   state = %s, beta = %f\n", states[i]->label, betas[current_step][i]);
            }
        }
    }
}

/*
 * find_last_state_index - return the index of the first state whose label
 * is "end" or "e", or -1 when no state carries either label.
 */
int find_last_state_index(state **states, int num_states) {
    static char rname[]="find_last_state_index";
    int idx = 0;

    while (idx < num_states) {
        const char *name = states[idx]->label;

        if (strcmp(name, "end") == 0) return(idx);
        if (strcmp(name, "e") == 0) return(idx);
        idx++;
    }
    return(-1);
}

/*
 * copy_model_trans - write a transition matrix back into the model.
 *
 * For every state the list of outgoing transitions (linked through
 * next_source) is walked and each transition's prob is replaced by
 * aij[source][dest].  Unless the state is labelled "end" or "e", the
 * probabilities written for it must sum to 1 (within 0.00001) or the routine
 * quits.
 *
 * The routine also quits when a transition refers to a state id that is not
 * on the state list, rather than indexing aij with -1.
 */
void copy_model_trans(state **states, int num_states, double **aij) {
    static char rname[]="copy_model_trans";
    int i, from_state_index, to_state_index;
    trans *current_trans;
    double total_prob;

    for (i = 0; i < num_states; i++) {
        total_prob = 0.0;
        for (current_trans = states[i]->out; current_trans != NULL; current_trans = current_trans->next_source) {
            from_state_index = get_state_index(states, current_trans->source->id, num_states);
            to_state_index = get_state_index(states, current_trans->dest->id, num_states);
            if (from_state_index == -1 || to_state_index == -1)
                quit(-1, "%s: a transition out of state %s refers to a state not found on state list\n", rname, states[i]->label);

            current_trans->prob = aij[from_state_index][to_state_index];
            total_prob += current_trans->prob;
        }

        if (strcmp(states[i]->label, "end") && strcmp(states[i]->label, "e")) {
            if (fabs(total_prob - 1.0) > 0.00001)
                quit(-1, "%s: new transition probs do not sum to 1 for state index %d, label %s: total_prob = %f\n", rname, i, states[i]->label, total_prob);
        }
    }
}

/*
 * update_parameters - accumulate the Baum-Welch re-estimation counts for one
 * observation string.
 *
 * num_aij_local[i][j]  += alpha[t][i] * aij[i][j] * b_j(word at t+1) *
 *                         beta[t+1][j]            for t = 0..num_obs-1,
 *                         with b_j taken as 1 for the last, non-emitting step
 * num_bj_local[i][w]   += alpha[t][i] * beta[t][i] * scale[t]
 *                         for t = 1..num_obs-1, w being the id of the word
 *                         at t (skipped entirely when trans_only is set)
 * denom_local[i]       += alpha[t][i] * beta[t][i] * scale[t]
 *                         for t = 0..num_obs-1
 *
 * The accumulators are only added to; the caller is expected to zero them.
 * A word whose id is -1 or greater than vocab_size is counted under
 * unk_word; the routine quits if that has no usable id either.
 * first_state_index, last_state_index and details are not used here.
 *
 * The word lookup and the emission probability depend only on the step and
 * on j, so they are computed once per (step, j) instead of once per
 * (i, j, step).  Each num_aij_local cell still receives its terms in
 * ascending step order, so the sums are unchanged.
 *
 * NOTE(review): punc_trans is not available here, so the transition counts
 * do not apply the restriction that the forward/backward passes apply --
 * confirm whether that is intended.
 * NOTE(review): `word` is a fixed 500-byte buffer filled by get_word() with
 * an unbounded strcpy; longer observations would overflow it.
 */
void update_parameters(double **num_aij_local, double **num_bj_local, double *denom_local, state **states, int num_states, double **alphas, double **betas, double *scale, char **obs, int num_obs, double **aij_old, float **bj, int first_state_index, int last_state_index, int details, int vocab_size, int trans_only) {
    static char rname[]="update_parameters";
    char *word;
    int i, j, current_step, word_id;
    double b_j;

    word = (char *) kmalloc(500*sizeof(char));

    /* Update aij numerator */
    for (current_step = 0; current_step < num_obs; current_step++) {
        /* The final step enters the end state without emitting a word */
        if (current_step != num_obs - 1) get_word(obs[current_step+1], word);

        for (j = 0; j < num_states; j++) {
            if (current_step == num_obs - 1) b_j = 1.0;
            else b_j = get_bj_prob(bj, states, j, word, vocab_size);

            for (i = 0; i < num_states; i++) {
                num_aij_local[i][j] += alphas[current_step][i] * aij_old[i][j] * b_j * betas[current_step+1][j];
            }
        }
    }

    if (!trans_only) {
        /* Update bj numerator */
        for (current_step = 1; current_step < num_obs; current_step++) {
            get_word(obs[current_step], word);
            word_id = bow_word2int_no_add(word);

            if ((word_id == -1) || (word_id > vocab_size)) { /* Word does not exist in distribution */
                word_id = bow_word2int_no_add(unk_word);
                if ((word_id == -1) || (word_id > vocab_size)) {
                    quit(-1, "%s: can't get reasonable word id for word %s\n", rname, word);
                }
            }

            for (i = 0; i < num_states; i++) {
                num_bj_local[i][word_id] += (alphas[current_step][i] * betas[current_step][i] * scale[current_step]);
            }
        }
    }

    /* Update denominator */
    for (i = 0; i < num_states; i++) {
        for (current_step = 0; current_step < num_obs; current_step++) {
            /* The betas were divided by the scale factor of their own step,
               hence the extra scale[current_step] term here */
            denom_local[i] += (alphas[current_step][i] * betas[current_step][i] * scale[current_step]);
        }
    }
    free(word);
}

/*
 * get_word - produce the lookup form of an observation.
 *
 * orig_word is copied into word and passed through remove_punc(); if the
 * result has no id in the vocabulary (bow_word2int_no_add returns -1), word
 * is overwritten with unk_word.  The caller supplies a buffer large enough
 * for either string.
 */
void get_word(char *orig_word, char *word) {
    static char rname[]="get_word";
    int id;

    strcpy(word, orig_word);
    remove_punc(word);

    id = bow_word2int_no_add(word);
    if (id != -1) return; /* known word: keep it as is */

    /* Word does not exist in distribution */
    strcpy(word, unk_word);
}

/*
 * ran0 - shuffled wrapper around rand(), returning a value in [0, 1).
 *
 * On the first call, or whenever *idum is negative, rand() is seeded with
 * *idum, *idum is set to 1, 97 values are discarded as warm-up and a
 * 97-entry shuffle table is filled.  Each call then uses the previous output
 * to pick a table slot, returns that slot's value scaled into [0, 1) and
 * refills the slot from rand().
 *
 * The scale is RAND_MAX + 1, the number of distinct values rand() can
 * return.  Deriving it from the width of `unsigned` gives the same value
 * only where RAND_MAX is 2^(bits-1) - 1; with a smaller RAND_MAX the results
 * would be squeezed towards zero and always use the same table slot.
 */
double ran0(int *idum) {
    static char rname[]="ran0";
    static double y,maxran,v[98];
    static int iff=0;
    int j;

    if (*idum < 0 || iff == 0) {
        iff=1;
        maxran=(double) RAND_MAX + 1.0;
        srand(*idum);
        *idum=1;
        for (j=1;j<=97;j++) (void) rand(); /* warm-up draws, discarded */
        for (j=1;j<=97;j++) v[j]=rand();
        y=rand();
    }
    /* Slot 1..97 selected by the previous output; v[0] is not used */
    j=1+97.0*y/maxran;
    if (j > 97 || j < 1) quit(-1, "%s: this cannot happen.", rname);
    y=v[j];
    v[j]=rand();
    return y/maxran;
}


/*
 * get_bj_prob - emission probability of `word` in state i.
 *
 * States labelled "end" or "start" emit nothing and yield 0.  Otherwise the
 * word's id selects the entry of bj[i]; an id of -1 or one greater than
 * vocab_size falls back to the id of unk_word, and 0 is returned when that
 * is unusable as well.
 */
double get_bj_prob(float **bj, state **states, int i, char *word, int vocab_size) {
    static char rname[]="get_bj_prob";
    const char *name = states[i]->label;
    int id;

    if (strcmp(name, "end") == 0 || strcmp(name, "start") == 0) {
        return(0.0);
    }

    id = bow_word2int_no_add(word);
    if (id != -1 && id <= vocab_size) {
        return(bj[i][id]);
    }

    /* Word does not exist in distribution */
    id = bow_word2int_no_add(unk_word);
    if (id != -1 && id <= vocab_size) {
        return(bj[i][id]);
    }
    return(0.0);
}


/*
 * print_emission_dists - write one emission table per emitting state.
 *
 * For every state not labelled "end" or "start", the file
 * <emissions_dir>/<state index>.arpa.gz is opened with kopen_wgz() and the
 * state's row of bj is written with print_arpa_unigram().
 *
 * The path is built with a bounded snprintf; the routine quits if the
 * directory name is too long for the path buffer instead of writing past it.
 *
 * NOTE(review): print_reestimated_model_to_file refers to these files as
 * "<index>.arpa" without the ".gz" suffix -- confirm that the reader of the
 * model file accounts for this.
 */
void print_emission_dists(char *emissions_dir, float **bj, int num_states, int vocab_size, state **states) {
    static char rname[]="print_emission_dists";
    FILE *output_file;
    char output_file_path[1000];
    int i, path_len;

    for (i = 0; i < num_states; i++) {
        /* The start and end states have no emission table */
        if (!strcmp(states[i]->label,"end") ||
            !strcmp(states[i]->label,"start")) continue;

        path_len = snprintf(output_file_path, sizeof(output_file_path), "%s/%d.arpa.gz", emissions_dir, i);
        if (path_len < 0 || (size_t) path_len >= sizeof(output_file_path))
            quit(-1, "%s: emissions directory name is too long: %s\n", rname, emissions_dir);

        output_file = kopen_wgz(output_file_path);

        print_arpa_unigram(output_file, bj[i], vocab_size);
        fclose(output_file);
    }
}

/*
 * print_reestimated_model_to_file - write the model description to outfile.
 *
 * Output, in order:
 *   - num_states-2, then one line "<index> <emissions_dir>/<index>.arpa" for
 *     every state not labelled "end" or "start"
 *   - num_states, then one line "<id> <label> <duration> <dist>" per state,
 *     where <dist> is "null" for the "end"/"start" states and the state
 *     index otherwise
 *   - one line "<source id> <dest id> <prob>" per transition collected by
 *     collect_state_info()
 *
 * The paths and labels are printed directly rather than being formatted
 * into fixed-size buffers first, so a long emissions_dir cannot overflow
 * anything; the bytes written are the same.  All num_states+1 slots of the
 * distribution array are cleared before it is handed to
 * collect_state_info().
 *
 * NOTE(review): the first count assumes exactly two non-emitting states.
 */
void print_reestimated_model_to_file(FILE *outfile, state **states, int num_states, char *emissions_dir) {
    static char rname[]="print_reestimated_model_to_file";
    int   i;
    ltrans *first_trans, *current_trans;
    multinomial **dists;

    /* Initialize distribution array - at most there are as many distributions
        as there are states */
    dists = (multinomial **) kmalloc((num_states+1) * sizeof(multinomial *));
    for (i = 0; i <= num_states; i++) {
        dists[i] = NULL;
    }

    /* Initialize transitions - don't know how many there will be */
    first_trans = NULL;

    /* Go through model and collect pointers to the states, distributions and transitions */
    for (i = 0; i < num_states; i++) {
        collect_state_info(states[i], num_states, dists, &first_trans);
    }

    /* Print distribution information to output file */
    fprintf(outfile, "%d\n", num_states-2);
    for (i = 0; i < num_states; i++) {

        if (!(!strcmp(states[i]->label,"end") ||
              !strcmp(states[i]->label,"start"))) {
            fprintf(outfile, "%d %s/%d.arpa\n", i, emissions_dir, i);
        }
    }

    /* Print state information to output file */
    fprintf(outfile, "%d\n", num_states);
    for (i = 0; i < num_states; i++) {
        if (!strcmp(states[i]->label,"end") ||
            !strcmp(states[i]->label,"start")) {
            fprintf(outfile, "%d %s %d null\n", states[i]->id, states[i]->label, states[i]->duration);
        }
        else {
            fprintf(outfile, "%d %s %d %d\n", states[i]->id, states[i]->label, states[i]->duration, i);
        }
    }

    /* Print transition information to output file */
    current_trans = first_trans;
    while (current_trans != NULL) {
        fprintf(outfile, "%d %d %.20f\n", current_trans->trans->source->id, current_trans->trans->dest->id, current_trans->trans->prob);
        current_trans = current_trans->next;
    }

    free(dists);
    free_linked_transitions(first_trans);
}

/*
 * set_emission_probs - build the emission table from the model.
 *
 * Returns a newly allocated table of num_states rows, each holding
 * vocab_size+1 floats; the caller frees the rows and the table.
 *
 * Every entry starts at 0.  For each state that is present and is not
 * labelled "end" or "start", entry j (for j up to the smaller of vocab_size
 * and the state's own O->vocab_size) is set to exp(O->lprobs[j]); a log
 * probability greater than zero is treated as invalid and leaves the entry
 * at 0.
 *
 * Rows belonging to NULL states are therefore all zero instead of being
 * returned uninitialised, and the label test is made once per state rather
 * than once per vocabulary entry.
 */
float **set_emission_probs(int num_states, int vocab_size, state **states) {
    static char rname[] = "set_emission_probs";
    float **bj;
    int i, j;

    /* One zeroed array per state */
    bj = (float **) kmalloc(num_states * sizeof(float *));
    for (i = 0; i < num_states; i++) {
        bj[i] = (float *) kmalloc((vocab_size+1) * sizeof(float));
        for (j = 0; j <= vocab_size; j++) {
            bj[i][j] = 0.0;
        }
    }

    /* Fill in initial values from existing model */
    for (i = 0; i < num_states; i++) {
        if (states[i] == NULL) continue;

        /* The start and end states do not emit */
        if (!strcmp(states[i]->label,"end") ||
            !strcmp(states[i]->label,"start")) continue;

        for (j = 0; j <= vocab_size && j <= states[i]->O->vocab_size; j++) {
            if (!(states[i]->O->lprobs[j] > 0)) {
                bj[i][j] = exp(states[i]->O->lprobs[j]);
            }
        }
    }
    return(bj);
}





