/* merge.c - routines for merging states of an HMM */

#include "general.h"

/* Build the initial linked list of merge candidates.
 *
 * Every unordered pair (states[a], states[b]) with a < b is considered.
 * Pairs in which either state is labelled "start" or "end" are skipped.
 * When same_label is non-zero only pairs with identical labels are kept
 * (have_same_label()). When neighbors_only is non-zero only pairs joined by
 * a transition in either direction are kept (are_neighbors()).
 * Each kept pair gets the initial value 1000001.0, and the list follows the
 * (a, b) iteration order.
 *
 * model is not read here. Returns the list head, or NULL if no pair qualifies.
 */
state_pairs *compute_all_candidates(state *model, state **states, int num_states, int same_label, int neighbors_only) {
    state_pairs *head = NULL;
    state_pairs **tail = &head;   /* link slot that receives the next node */
    int a, b;

    for (a = 0; a < num_states; a++) {
        state *left = states[a];

        if (strcmp(left->label, "start") == 0 || strcmp(left->label, "end") == 0)
            continue;

        for (b = a + 1; b < num_states; b++) {
            state *right = states[b];
            state_pairs *node;

            if (strcmp(right->label, "start") == 0 || strcmp(right->label, "end") == 0)
                continue;

            /* Restrictions on which pairs may become candidates. */
            if (same_label && !have_same_label(left, right))
                continue;
            if (neighbors_only && !are_neighbors(left, right))
                continue;

            node = (state_pairs *) kmalloc(sizeof(state_pairs));
            node->s1 = left;
            node->s2 = right;
            node->value = 1000001.0;
            node->next = NULL;

            *tail = node;
            tail = &node->next;
        }
    }
    return head;
}

/* Gather every state reachable from `initial` into a newly allocated array.
 *
 * On entry *ptr_num_states is the capacity to use. The array is allocated
 * with one spare slot, and the first *ptr_num_states slots are cleared to
 * NULL. collect_children_of() then fills them in discovery order.
 *
 * The number of filled slots (the index of the first NULL) is written back
 * to *ptr_num_states, and the `seen` flag of every collected state is reset
 * to 0.
 *
 * Returns the array; the caller releases it with free().
 */
state **collect_states(state *initial, int *ptr_num_states) {
    int capacity = *ptr_num_states;
    int filled = capacity;
    int k, reset;
    state **states;

    if (DEBUG) fprintf(stderr, "Received %d states\n", capacity);

    states = (state **) kmalloc((capacity + 1) * sizeof(state *));
    for (k = 0; k < capacity; k++)
        states[k] = NULL;

    collect_children_of(initial, capacity, states);

    /* Slots are filled contiguously from index 0, so the index of the first
       empty slot equals the number of states actually reached. */
    for (k = 0; k < capacity; k++) {
        if (states[k] != NULL)
            continue;
        filled = k;
        fprintf(stderr, "New num states = %d\n", k);
        break;
    }

    if (filled != capacity)
        fprintf(stderr, "Changed number of states to %d\n", filled);

    reset = set_seen_to_zero(states, filled);
    if (DEBUG) fprintf(stderr, "1: Set %d states to unseen\n", reset);

    *ptr_num_states = filled;
    return states;
}

/* Depth-first collection of current_s and every state reachable from it.
 *
 * states has num_states slots, filled contiguously from index 0, with NULL
 * marking the first free slot.
 *   - A state already present (matched by id) stops the recursion, so
 *     cycles terminate.
 *   - Otherwise current_s is stored in the first free slot, and the
 *     destinations of its out-transitions are visited in list order.
 *
 * Calls quit() when all num_states slots are occupied by other states.
 */
void collect_children_of(state *current_s, int num_states, state **states) {
    static char rname[]="collect_children_of";
    trans *current_t;
    int i;
    int added = 0;   /* must start at 0: the loop can finish without setting it */

    for (i = 0; i < num_states; i++) {
        if (states[i] == NULL) {
            states[i] = current_s;
            added = 1;
            break;
        }
        if (states[i]->id == current_s->id) {
            return;   /* already collected */
        }
    }

    if (!added) quit(-1, "%s: found more states than num_states (%d)\n", rname, num_states);

    for (current_t = current_s->out; current_t != NULL; current_t = current_t->next_source) {
        collect_children_of(current_t->dest, num_states, states);
    }
}

/* Update the candidate list and the state array after the two states of
 * max_cand have been merged into new_state.
 *
 * Parameters:
 *   p_candidates  head of the candidate list; updated in place.
 *   max_cand      the pair that was merged. Its s1/s2 are read up front.
 *                 If max_cand is itself a node of *p_candidates, step 1
 *                 frees it, so the caller must not use it afterwards.
 *   states        array of num_states slots (NULL marks an empty slot);
 *                 rewritten in place by steps 3 and 4.
 *   new_state     the merged state to insert.
 *   same_label, neighbors_only
 *                 the same candidate filters used by compute_all_candidates().
 *
 * Steps:
 *   1. Unlink and free every candidate that involves either merged state.
 *   2. Reset to 1000001.0 the value of every remaining candidate that
 *      contains the source state of one of new_state's in-transitions.
 *   3. Append a candidate pairing new_state with each remaining eligible
 *      state (the lower id becomes s1), and set the merged states' slots in
 *      states[] to NULL.
 *   4. Compact states[], append new_state after the survivors, and
 *      NULL-fill the tail. quit() is called if no slot is left.
 *   5. free_state() both merged states.
 */
void update_candidates(state_pairs **p_candidates, state_pairs *max_cand, state **states, int num_states, state *new_state, int same_label, int neighbors_only) {
    static char rname[]="update_candidates";
    state_pairs *candidate, *prev, *last, *temp, *next;
    state *max_cand_s1, *max_cand_s2;
    trans *current_t;
    int i, j, last_pointer;
    int num_removed, num_added;

    /* Save the merged states now: max_cand itself may be freed in step 1. */
    max_cand_s1 = max_cand->s1;
    max_cand_s2 = max_cand->s2;

    last = NULL;
    prev = NULL;
    next = NULL;

    num_removed = 0;
    num_added = 0;

    /* Step 1: remove all candidates from the list that involve a state from max_cand */
    candidate = *p_candidates;
    while (candidate != NULL) {
        next = candidate->next;
        if (candidate->s1 == max_cand_s1 || 
            candidate->s1 == max_cand_s2 || 
            candidate->s2 == max_cand_s1 || 
            candidate->s2 == max_cand_s2) {

            /* Unlink: the head moves if there is no surviving predecessor. */
            if (prev == NULL) *p_candidates = candidate->next;
            else prev->next = candidate->next;
            free(candidate);
            num_removed++;
        }
        else prev = candidate;

        candidate = next;
    }
    /* prev is now the last surviving node (NULL if the list is empty);
       new candidates are appended after it in step 3. */
    last = prev;

    /* Step 2: reset the values of candidate pairs involving parents of the
        new state (these were also the parents of the states from max_cand) */
    candidate = *p_candidates;
    while (candidate != NULL) {

        /* Check to see if the parents of the new state match this candidate's states: */
        current_t = new_state->in;
        while (current_t != NULL) {

            if (candidate->s1 == current_t->source || candidate->s2 == current_t->source) {
                candidate->value = 1000001.0;
                if (DEBUG) fprintf(stderr, "Reset candidate pair [%d (%s), %d (%s)]\n", candidate->s1->id, candidate->s1->label, candidate->s2->id, candidate->s2->label);
                break;
            }
            current_t = current_t->next_dest;
        }
        candidate = candidate->next;
    }

    /* Step 3: add candidate pairs with the new state. The merged states' slots
       are cleared to NULL here so step 4 can compact them away. */
    for (i = 0; i < num_states; i++) {
        if (states[i] == max_cand_s1 || states[i] == max_cand_s2) {
            if (DEBUG) fprintf(stderr, "Set space %d in states array to null (for state %d)\n", i, states[i]->id);
            states[i] = NULL;
        }
        else if (states[i] != NULL) {
            if (!strcmp(states[i]->label, "start") || !strcmp(states[i]->label, "end")) continue;

            /* Put checks here for valid candidate pairs */
            if (same_label)
               if (!have_same_label(states[i], new_state)) continue;

            if (neighbors_only)
                if (!are_neighbors(states[i], new_state)) continue;

            temp = (state_pairs *) kmalloc(sizeof(state_pairs));
            /* Keep the lower-id state in s1. */
            if (states[i]->id < new_state->id) {
                temp->s1 = states[i];
                temp->s2 = new_state;
            }
            else {
                temp->s1 = new_state;
                temp->s2 = states[i];
            }
            temp->value = 1000001.0;
            temp->next = NULL;

            /* Append at the tail; the new node becomes the head if the list was empty. */
            if (last == NULL) {
                last = temp;
                *p_candidates = temp;
            }
            else last->next = temp;
            last = temp;
            num_added++;
        }
    }

    /* Step 4: collapse the state array and add the new state.
       i is the write index, j the read index that skips NULL slots;
       last_pointer is the last index written (-1 if none). */
    j = 0;
    last_pointer = -1;
    for (i = 0; i < num_states; i++) {
        if (DEBUG) fprintf(stderr, "i = %d, j = %d, last = %d\n", i,j,last_pointer);
        if (j == num_states) break;
        while (states[j] == NULL) {
            j++;
            if (j >= num_states) break;
        }
        if (j >= num_states) break;

        if (DEBUG) fprintf(stderr, "   Assigning states[%d] = states[%d] (state id = %d)\n", i, j, states[j]->id);
        if (i != j) states[i] = states[j];   
        last_pointer = i;
        j++;
    }

    if (last_pointer+1 >= num_states) quit(-1, "%s: not enough space in state array for new state...\n",rname);

    states[last_pointer+1] = new_state;
    if (DEBUG) fprintf(stderr, "states[%d] = new (state id = %d)\n", last_pointer+1, new_state->id);

    /* Fill up remaining spaces in states with null pointers */
    for (i = last_pointer+2; i < num_states; i++) {
        states[i] = NULL;
        if (DEBUG) fprintf(stderr, "   Assigning states[%d] = NULL\n", i);
    }

    /* Step 5: free the old states (no longer referenced by states[] or the candidate list) */
    free_state(max_cand_s1);
    free_state(max_cand_s2);

    if (DEBUG) fprintf(stderr, "Updated candidates: removed %d pairs, added %d pairs, num_states = %d\n", num_removed, num_added, num_states);

}

/* Create a new state representing the merge of s1 and s2.
 *
 * The merged state takes:
 *   - the smaller of the two ids;
 *   - the larger of the two durations;
 *   - the shared label if both labels are equal, otherwise "mixed";
 *   - the emission distribution built by combine_emissions().
 * Transitions are combined by combine_trans(), which receives p_trans_list
 * and trans_uni_alpha unchanged. s1 and s2 are not freed here.
 *
 * Returns the newly created state.
 */
state *merge_states(state *s1, state *s2, trans **p_trans_list, double trans_uni_alpha) {
    int merged_id = (s1->id < s2->id) ? s1->id : s2->id;
    int merged_duration = (s1->duration > s2->duration) ? s1->duration : s2->duration;
    const char *label_source;
    char *merged_label;
    multinomial *emissions;
    state *merged;

    if (DEBUG) fprintf(stderr, "Merging states %d and %d\n", s1->id, s2->id);

    label_source = (strcmp(s1->label, s2->label) == 0) ? s1->label : "mixed";
    merged_label = strdup(label_source);

    emissions = combine_emissions(s1->O, s2->O, merged_id);

    merged = create_state(merged_id, merged_label, NULL, merged_duration);
    merged->O = emissions;

    combine_trans(merged, s1, s2, p_trans_list, trans_uni_alpha);

    /* NOTE(review): freeing the label here assumes create_state() keeps its
       own copy rather than the pointer - confirm in create_state(). */
    free(merged_label);

    return merged;
}

/* Undo a tentative merge.
 *
 * new_state is released with free_state(). Each node of trans_list (linked
 * through next_source) is then re-created with add_trans(), using the same
 * source, dest, count, alpha and prob, and the node is freed.
 *
 * NOTE(review): presumably trans_list is the list collected by
 * merge_states()/combine_trans(); confirm against callers.
 */
void unmerge_states(state *new_state, trans *trans_list) {
    trans *current_t, *next_t;

    if (DEBUG) fprintf(stderr, "Unmerging state %d\n", new_state->id);

    free_state(new_state);

    current_t = trans_list;
    while (current_t != NULL) {
        add_trans(current_t->source, current_t->dest, current_t->count, current_t->alpha, current_t->prob);
        next_t = current_t->next_source;
        free(current_t);
        current_t = next_t;
    }
}

/* Recursively merge chains of adjacent states that share a label.
 *
 * The walk stops at NULL, at a state with no out-transitions, and at a
 * state already marked seen.
 *
 * Merge case: s1 has exactly one out-transition to a state other than
 * itself (num_different_dest_states() == 1), that destination s2 has exactly
 * one in-transition from a state other than itself, and both labels are
 * equal. Then:
 *   - s1 and s2 are merged with merge_states();
 *   - the transition list it fills is freed;
 *   - both old states are freed and *p_num_states is decremented;
 *   - the walk continues from the merged state.
 *
 * Otherwise s1 is marked seen and the walk continues into its children,
 * skipping self-loops.
 */
void collapse_same_tags(state *s1, int *p_num_states, double trans_uni_alpha) {
    static char rname[]="collapse_same_tags";
    trans *current_t, *trans_list, *next_t;
    state *new, *s2;
    int num_dest, num_source;

    if (s1 == NULL) return;
    if (s1->out == NULL) return;
    if (s1->seen == 1) return;

    /* Number of out-transitions that are not self-loops. */
    num_dest = num_different_dest_states(s1);

    /* If current state has only one child */
    if (num_dest == 1) {
        /* Skip self-loops to find the single other destination. */
        current_t = s1->out;
        while (current_t != NULL && current_t->dest->id == s1->id) {
            current_t = current_t->next_source; 
        }
        if (current_t == NULL) return;
        s2 = current_t->dest;
        if (s1->id == s2->id) return;
        num_source = num_different_source_states(s2);

        /* If s2 has no other parent and the labels of the current and next state are equal */
        if (num_source == 1 && !strcmp(s1->label, s2->label)) {

            trans_list = NULL;
            new = merge_states(s1, s2, &trans_list, trans_uni_alpha);
            /* The merge is permanent, so the old transitions it collected are discarded. */
            free_trans_list(trans_list);
            free_state(s2);
            free_state(s1);
            (*p_num_states)--;

            if (DEBUG) fprintf(stderr, "New state id = %d, label = %s\n", new->id, new->label);

            /* The merged state may itself be mergeable with its next state. */
            collapse_same_tags(new, p_num_states, trans_uni_alpha);
        }
        else {
            s1->seen = 1;
            collapse_same_tags(s2, p_num_states, trans_uni_alpha);
        }
    }
    else {
        s1->seen = 1;
        current_t = s1->out;
        while (current_t != NULL) {
            /* Read the successor before recursing; the recursion may merge/free states. */
            next_t = current_t->next_source;
            if (s1->id != current_t->dest->id) {
                collapse_same_tags(current_t->dest, p_num_states, trans_uni_alpha);
            }
            current_t = next_t;
        }
    }
}

/* Count the out-transitions of s1 whose destination has a different id
 * from s1, i.e. self-loops are ignored. This counts transitions, so it
 * equals the number of distinct destinations only if each destination
 * appears at most once in s1->out.
 */
int num_different_dest_states(state *s1) {
    int count = 0;
    trans *t;

    for (t = s1->out; t != NULL; t = t->next_source) {
        if (t->dest->id == s1->id)
            continue;
        count++;
    }
    return count;
}

/* Count the in-transitions of s1 whose source has a different id from s1,
 * i.e. self-loops are ignored. This counts transitions, so it equals the
 * number of distinct parents only if each source appears at most once in
 * s1->in.
 */
int num_different_source_states(state *s1) {
    int count = 0;
    trans *t;

    for (t = s1->in; t != NULL; t = t->next_dest) {
        if (t->source->id == s1->id)
            continue;
        count++;
    }
    return count;
}

/* Free every node of a transition list linked through next_source.
 * Only the list nodes are released; the states they reference are untouched.
 * A NULL list is accepted.
 */
void free_trans_list(trans *trans_list) {
    while (trans_list != NULL) {
        trans *doomed = trans_list;
        trans_list = trans_list->next_source;
        free(doomed);
    }
}

/* Free every node of a candidate-pair list linked through next.
 * The states referenced by the pairs are not freed. A NULL list is accepted.
 */
void free_candidates(state_pairs *candidates) {
    while (candidates != NULL) {
        state_pairs *doomed = candidates;
        candidates = candidates->next;
        free(doomed);
    }
}

double compute_candidate_contribution(state *s1, state *s2, int num_states, float prior_weight, int use_prior, int mode) {
    static char rname[]="compute_candidate_contribution";
    double lposterior, lprior, llikelihood;
    trans *current_t;
    state **already_computed;
    int i, max, found;

    lprior = 0;
    llikelihood = 0;

    already_computed = (state **) kmalloc((num_states+1)*sizeof(state *));
    for (i = 0; i < num_states; i++) {
        already_computed[i] = NULL;
    }

    /* Compute the prior contribution and likelihood contribution for each state */
    if (use_prior > 0) lprior = s1->lprior;
    if (use_prior > 0) lprior += s2->lprior;

    llikelihood = s1->llikelihood;
    llikelihood += s2->llikelihood;

    already_computed[0] = s1;
    already_computed[1] = s2;
    max = 1;

    /* Compute the prior contribution of each state that transitions into the candidate states */
    current_t = s1->in;
    while (current_t != NULL) {

        /* See if source state has already been added into calculation */
        found = 0;
        for (i = 0; i <= max; i++) {
            if (already_computed[i] == current_t->source) {
                found = 1;
                break;
            }
        }
        if (!found) {
            if (use_prior > 0) lprior += current_t->source->lprior;
            llikelihood += current_t->source->llikelihood;
            max++;
            if (max > num_states) quit(-1, "%s: more states to add than space: max = %d, num_states = %d\n", rname, max, num_states);
            already_computed[max] = current_t->source;
        }
        current_t = current_t->next_dest;
    }

    current_t = s2->in;
    while (current_t != NULL) {

        /* See if source state has already been added into calculation */
        found = 0;
        for (i = 0; i <= max; i++) {
            if (already_computed[i] == current_t->source) {
                found = 1;
                break;
            } 
        }
        if (!found) {
            if (use_prior > 0) lprior += current_t->source->lprior;
            llikelihood += current_t->source->llikelihood;
            max++;
            if (max > num_states) quit(-1, "%s: more states to add than space: max = %d, num_states = %d\n", rname, max, num_states);
            already_computed[max] = current_t->source;
        }
        current_t = current_t->next_dest;
    }

    free(already_computed);

    lposterior = llikelihood + (prior_weight * lprior);
    if (DEBUG) fprintf(stderr, "Candidate prior = %f, likelihood = %f, posterior = %f\n", lprior, llikelihood, lposterior);
    return(lposterior);
}

/* Score the part of the model affected by a freshly merged state.
 *
 * The sum covers new_state and the source state of each of its
 * in-transitions, skipping self-loops. The likelihood sum uses the cached
 * llikelihood fields; the prior sum uses the cached lprior fields, but only
 * when use_prior > 0.
 *
 * Returns llikelihood + prior_weight * lprior.
 * num_states and mode are not used.
 */
double compute_new_state_contribution(state *new_state, int num_states, float prior_weight, int use_prior, int mode) {
    double lprior = (use_prior > 0) ? new_state->lprior : 0.0;
    double llikelihood = new_state->llikelihood;
    double lposterior;
    trans *t;

    for (t = new_state->in; t != NULL; t = t->next_dest) {
        state *parent = t->source;

        if (parent == new_state)
            continue;
        if (use_prior > 0)
            lprior += parent->lprior;
        llikelihood += parent->llikelihood;
    }

    lposterior = llikelihood + (prior_weight * lprior);
    if (DEBUG) fprintf(stderr, "New state prior = %f, likelihood = %f, posterior = %f\n", lprior, llikelihood, lposterior);
    return lposterior;
}

/* Log prior of a single state: structure prior plus parameter prior.
 *
 * mode:      1 = map, 2 = mean, 3 = ml, 4 = structure
 * use_prior: 0 = no prior, 1 = normal prior,
 *            2 = structure prior only (no parameter prior)
 *
 * The "end" state always contributes 0.0. The structure prior is included
 * when use_prior > 0. The parameter prior is included only when mode < 4
 * and use_prior == 1.
 */
double compute_state_prior(state *s1, int num_states, int mode, int use_prior, int narrow_emis) {
    static char rname[]="compute_state_prior";
    double str_lprior = 0.0;
    double par_lprior = 0.0;

    if (strcmp(s1->label, "end") == 0)
        return 0.0;

    if (use_prior > 0)
        str_lprior = compute_structure_prior(s1, num_states);

    if (mode < 4 && use_prior == 1)
        par_lprior = compute_parameter_prior(s1, num_states, narrow_emis, mode);

    if (DEBUG) fprintf(stderr, "%s: state  = %d (%s), structure lprior = %f, parameter lprior = %f\n", rname, s1->id, s1->label, str_lprior, par_lprior);

    return str_lprior + par_lprior;
}

/* Log likelihood of a single state.
 *
 * The "end" state contributes 0.0. For mode < 4 (map/mean/ml) the count-
 * weighted log-probability sum from compute_vit_likelihood() is used.
 * Otherwise (mode 4, structure) compute_struct_state_likelihood() is used.
 */
double compute_state_likelihood(state *s1, int num_states, int mode, int narrow_emis) {
    static char rname[]="compute_state_likelihood";
    double llikelihood;

    if (strcmp(s1->label, "end") == 0)
        return 0.0;

    llikelihood = (mode < 4)
        ? compute_vit_likelihood(s1, mode)
        : compute_struct_state_likelihood(s1, num_states, narrow_emis);

    if (DEBUG) fprintf(stderr, "%s: log likelihood for state %d = %f\n", rname, s1->id, llikelihood);

    return llikelihood;
}

/* Structure log prior of s1. It currently delegates to mdl_prior();
 * narrow_structure_prior() is an alternative that is not used.
 */
double compute_structure_prior(state *s1, int num_states) {
    return mdl_prior(s1, num_states);
}

/* n_on * log(p) + n_off * log(1 - p), treating 0 * log(0) as 0.
 * Without this, p == 0 (n_on == 0) or p == 1 (n_off == 0) would produce
 * 0 * -inf = NaN instead of the correct finite value.
 */
static double narrow_log_bernoulli_terms(int n_on, int n_off, double p) {
    double total = 0.0;

    if (n_on != 0) total += n_on * log(p);
    if (n_off != 0) total += n_off * log(1 - p);
    return total;
}

/* Alternative ("narrow") structure log prior of s1. It is not currently
 * selected by compute_structure_prior().
 *
 * Transition part: with k out-transitions out of num_states possible ones
 * and p = k / num_states, the term is k*log(p) + (num_states-k)*log(1-p).
 *
 * Emission part: the same form using num_types out of vocab_size + 1
 * symbols. For "start"/"end" states the emission part is the constant 1.0.
 *
 * Returns the sum of both parts.
 */
double narrow_structure_prior(state *s1, int num_states) {
    double trans_lprior, emis_lprior, total_lprior, prob;
    int num_trans, num_slots;
    trans *current_t;

    /* Transition component: count the number of out transitions */
    num_trans = 0;
    for (current_t = s1->out; current_t != NULL; current_t = current_t->next_source) {
        num_trans++;
    }
    prob = (double) num_trans / num_states;
    trans_lprior = narrow_log_bernoulli_terms(num_trans, num_states - num_trans, prob);
    if (DEBUG) fprintf(stderr, "Transition structure prior: prob = %f, num_trans = %d, narrow transition lprior = %f\n", prob, num_trans, trans_lprior);

    /* Emission component */
    if (!strcmp(s1->label, "start") || !strcmp(s1->label, "end")) {
        /* NOTE(review): a log prior of 1.0 rather than 0.0 for start/end looks
           unintended; kept as-is pending confirmation. */
        emis_lprior = 1.0;
    }
    else {
        num_slots = s1->O->vocab_size + 1;
        prob = ((double) s1->O->num_types) / num_slots;
        emis_lprior = narrow_log_bernoulli_terms(s1->O->num_types, num_slots - s1->O->num_types, prob);
        if (DEBUG) fprintf(stderr, "Emission structure prior: prob = %f, num types = %d, emission lprior = %f\n", prob, s1->O->num_types, emis_lprior);
    }

    total_lprior = trans_lprior + emis_lprior;
    if (DEBUG) fprintf(stderr, "Trans lprior = %f, emis lprior = %f, total lprior = %f\n", trans_lprior, emis_lprior, total_lprior);

    return(total_lprior);
}

/* Minimum-description-length style structure log prior of s1:
 *
 *   -(#out-transitions) * log(num_states + 1)
 *   -(#emission types)  * log(vocab_size + 1)
 *
 * A vocab_size of -1 marks a state with no emission distribution (such as
 * the start state). It is treated as 0 so that log(vocab_size + 1) == 0 and
 * the emission term vanishes.
 */
double mdl_prior(state *s1, int num_states) {
    int vocab_size = s1->O->vocab_size;
    int num_trans = 0;
    int num_emissions;
    trans *t;
    double lprior;

    if (vocab_size == -1)
        vocab_size = 0;

    for (t = s1->out; t != NULL; t = t->next_source)
        num_trans++;

    num_emissions = s1->O->num_types;

    if (DEBUG) fprintf(stderr, "num_trans = %d, num_state = %d, num_emissions = %d, vocab_size = %d\n", num_trans, num_states, num_emissions, vocab_size+1);

    lprior = -num_trans * log(num_states+1) + -num_emissions * log(vocab_size+1);

    if (DEBUG) fprintf(stderr, "mdl lprior = %f\n", lprior);

    return lprior;
}

/* Log parameter prior of s1, written as a log Dirichlet density over its
 * transition and emission distributions.
 *
 * Transition part: for each out-transition t, add (alpha - 1) * log(t->prob)
 * with alpha = calc_trans_alpha(s1, t->alpha). Then subtract log_beta() of
 * those alphas.
 *
 * Emission part: computed only when s1->O is non-NULL, has num_types != 0
 * and has vocab_size > -1. lprob comes from calc_lprob(s1->O, id, mode).
 *   - narrow_emis != 0: sum (alpha - 1) * lprob over the symbols in
 *     s1->O->counts, then subtract log_beta() of those alphas.
 *   - otherwise: sum over word ids 0..vocab_size with lprob >= -98.0, then
 *     subtract s1->O->prior->log_beta when a prior is attached, or else
 *     log_beta() of vocab_size + 1 copies of s1->O->uni_alpha.
 *
 * quit() is called when:
 *   - the emissions hold log probabilities (num_tokens == -1);
 *   - an emission sum is empty;
 *   - an alpha array would overflow.
 *
 * NOTE(review): the original author flagged that start-state handling still
 * needs checking here.
 */
double compute_parameter_prior(state *s1, int num_states, int narrow_emis, int mode) {
    static char rname[]="compute_parameter_prior";
    int i, j, num_trans, num_emissions, vocab_size;
    trans *current_t;
    double lprior, trans_beta, emissions_beta, lprob, trans_lprior, emis_lprior;
    double *t_alphas, *e_alphas, alpha;
    tc *temp;

    trans_lprior = 0;
    emis_lprior = 0;
    lprior = 0;

    /* Transition prior: */
    t_alphas = (double *) kmalloc((num_states+1) * sizeof(double));
    num_trans = 0;
    current_t = s1->out;
    while (current_t != NULL) {
        /* t_alphas has num_states+1 slots; never write past them. */
        if (num_trans > num_states) quit(-1, "%s: not enough room in t_alphas array...\n", rname);
        alpha = calc_trans_alpha(s1, current_t->alpha);
        trans_lprior += (alpha - 1) * log(current_t->prob);
        t_alphas[num_trans] = alpha;
        num_trans++;
        current_t = current_t->next_source;
    }
    if (num_trans > 0) {
        trans_beta = log_beta(t_alphas, num_trans);
        if (DEBUG) fprintf(stderr, "Log of transition beta = %f\n", trans_beta);
        if (DEBUG) fprintf(stderr, "transition lprior = %f, log beta value = %f, total = %f\n", trans_lprior, trans_beta, trans_lprior - trans_beta);
        trans_lprior -= trans_beta;
    }
    free(t_alphas);

    /* Emission prior: s1->O is dereferenced only after the NULL check. */
    num_emissions = 0;
    if (s1->O != NULL && s1->O->num_types != 0) { 
        vocab_size = s1->O->vocab_size;

        if (s1->O->num_tokens == -1) /* distribution holds lprobs */
            quit(-1, "%s: expecting a distribution of counts, not lprobs - exiting...\n", rname);

        if (vocab_size > -1) {

            if (narrow_emis) { /* Narrow emissions prior: only observed symbols */
                e_alphas = (double *) kmalloc((s1->O->num_types+1) * sizeof(double));

                temp = s1->O->counts;
                while (temp != NULL) {
                    alpha = calc_emis_alpha(s1->O, temp->id, s1->O->prior_weight_adjustment);
                    lprob = calc_lprob(s1->O, temp->id, mode);
                    emis_lprior += (alpha - 1) * lprob;
                    if (num_emissions > s1->O->num_types) quit(-1, "%s: not enough room in e_alphas array...\n", rname);
                    e_alphas[num_emissions] = alpha;
                    num_emissions++;
                    temp = temp->next;
                }
                if (num_emissions == 0) quit(-1, "%s: why are there no emissions??\n", rname);
                emissions_beta = log_beta(e_alphas, num_emissions);
                free(e_alphas);
            }
            else { /* Broad emissions prior: every word id in the vocabulary */
                for (i = 0; i <= vocab_size; i++) {
                    alpha = calc_emis_alpha(s1->O, i, s1->O->prior_weight_adjustment);
                    lprob = calc_lprob(s1->O, i, mode);
                    if (lprob < -98.0) continue;   /* skip effectively-zero probabilities */
                    emis_lprior += (alpha - 1) * lprob;
                    num_emissions++;
                }
                if (num_emissions == 0) quit(-1, "%s: why are there no emissions??\n", rname);
                if (s1->O->prior != NULL) 
                    emissions_beta = s1->O->prior->log_beta; /* get pre-calculated emissions beta */
                else {
                    e_alphas = (double *) kmalloc((vocab_size+1) * sizeof(double)); 
                    for (j = 0; j <= vocab_size; j++) {
                        e_alphas[j] = s1->O->uni_alpha;
                    }
                    emissions_beta = log_beta(e_alphas, vocab_size+1);
                    free(e_alphas);
                }
            }
            if (DEBUG) fprintf(stderr, "Log of emissions beta = %f\n", emissions_beta); 
            if (DEBUG) fprintf(stderr, "emission lprior = %f, log beta value = %f, total = %f\n", emis_lprior, emissions_beta, emis_lprior - emissions_beta);

            emis_lprior -= emissions_beta;
        }
    }

    lprior = trans_lprior + emis_lprior;
    if (DEBUG) fprintf(stderr, "lprior = %f, trans = %f, emis = %f\n", lprior, trans_lprior, emis_lprior);

    return(lprior);
}

/* Structure-mode log likelihood of s1 with its parameters integrated out.
 *
 * Transitions and emissions are handled separately; each contributes
 * log_beta(alpha + counts) - log_beta(alpha).
 *   - Transitions: alpha = calc_trans_alpha(s1, t->alpha), and counts are
 *     the transition counts.
 *   - Emissions: computed only when s1->O is non-NULL with num_types != 0.
 *     The narrow prior uses only the symbols in s1->O->counts. The broad
 *     prior uses every word id 0..vocab_size; its denominator is taken from
 *     s1->O->prior->log_beta when a prior is attached.
 *
 * quit() is called when:
 *   - the emissions hold log probabilities (num_tokens == -1);
 *   - a word id is outside 0..vocab_size;
 *   - an alpha array would overflow.
 */
double compute_struct_state_likelihood(state *s1, int num_states, int narrow_emis) {
    static char rname[]="compute_struct_state_likelihood";
    double *num_t_alphas, *num_e_alphas, *denom_t_alphas, *denom_e_alphas, alpha;
    double lprob, num_trans_beta, num_emissions_beta, denom_trans_beta, denom_emissions_beta;
    trans *current_t;
    int j, num_trans, num_emissions, vocab_size;
    tc *temp;

    lprob = 0.0;
    num_trans_beta = 0.0;
    denom_trans_beta = 0.0;
    num_emissions_beta = 0.0;
    denom_emissions_beta = 0.0;

    /* Collect values of transition parameters (arrays hold num_states+1 entries) */
    num_t_alphas = (double *) kmalloc((num_states+1) * sizeof(double));
    denom_t_alphas = (double *) kmalloc((num_states+1) * sizeof(double));

    num_trans = 0;
    current_t = s1->out;
    while (current_t != NULL) {
        if (num_trans > num_states) quit(-1, "%s: not enough room in transition alpha arrays...\n", rname);
        alpha = calc_trans_alpha(s1, current_t->alpha);
        denom_t_alphas[num_trans] = alpha;
        num_t_alphas[num_trans] = alpha + current_t->count;
        num_trans++;
        current_t = current_t->next_source;
    }

    if (num_trans > 0) {
        num_trans_beta = log_beta(num_t_alphas, num_trans);
        denom_trans_beta = log_beta(denom_t_alphas, num_trans);
    }
    free(num_t_alphas);
    free(denom_t_alphas);

    if (DEBUG) fprintf(stderr, "%s: num_trans_beta = %f, denom_trans_beta = %f\n", rname, num_trans_beta, denom_trans_beta);

    /* Collect values of emission parameters */
    num_emissions = 0;
    if (s1->O != NULL && s1->O->num_types != 0) {

        /* Validate before allocating so the error path leaks nothing. */
        if (s1->O->num_tokens == -1) /* distribution holds lprobs */
            quit(-1, "%s: expecting a distribution of counts, not lprobs - exiting...\n",rname);

        vocab_size = s1->O->vocab_size;
        num_e_alphas = (double *) kmalloc((vocab_size+1) * sizeof(double));
        denom_e_alphas = (double *) kmalloc((vocab_size+1) * sizeof(double));

        if (narrow_emis) { /* Narrow emission priors: one entry per observed symbol */
            temp = s1->O->counts;
            while (temp != NULL) {
                /* Arrays hold vocab_size+1 entries; never write past them. */
                if (num_emissions > vocab_size) quit(-1, "%s: not enough room in emission alpha arrays...\n", rname);
                alpha = calc_emis_alpha(s1->O, temp->id, s1->O->prior_weight_adjustment);
                denom_e_alphas[num_emissions] = alpha;
                num_e_alphas[num_emissions] = alpha + temp->count;
                num_emissions++;
                temp = temp->next;
            }

            if (num_emissions > 0) {
                num_emissions_beta = log_beta(num_e_alphas, num_emissions);
                denom_emissions_beta = log_beta(denom_e_alphas, num_emissions);
            }
        }
        else { /* Broad emissions priors: one entry per word id 0..vocab_size */
            for (j = 0; j <= vocab_size; j++) {
                alpha = calc_emis_alpha(s1->O, j, s1->O->prior_weight_adjustment);
                denom_e_alphas[j] = alpha;
                num_e_alphas[j] = alpha;
            }
            temp = s1->O->counts;
            while (temp != NULL) {

                if (temp->id < 0 || temp->id > vocab_size) 
                    quit(-1, "%s: error - adding count of %d to word id %d for state %d (%s); vocab size = %d\n", rname, temp->count, temp->id, s1->id, s1->label, vocab_size);

                num_e_alphas[temp->id] += temp->count;
                temp = temp->next;
            }

            num_emissions_beta = log_beta(num_e_alphas, vocab_size+1);
            if (s1->O->prior != NULL)
                denom_emissions_beta = s1->O->prior->log_beta; /* get pre-calculated emissions beta */
            else
                denom_emissions_beta = log_beta(denom_e_alphas, vocab_size+1);
        }

        free(num_e_alphas);
        free(denom_e_alphas);
    }

    lprob = num_trans_beta + num_emissions_beta - denom_trans_beta - denom_emissions_beta;

    if (DEBUG) fprintf(stderr, "%s: lprob = %f, num_trans_beta = %f, num_emissions_beta = %f, denom_trans_beta = %f, denom_emissions_beta = %f\n", rname, lprob, num_trans_beta, num_emissions_beta, denom_trans_beta, denom_emissions_beta);

    return(lprob);
}


/* Count-weighted log likelihood of s1's observed transitions and emissions.
 *
 *   sum over out-transitions t:   t->count * log(t->prob)
 * + sum over emission counts e:   e->count * calc_lprob(s1->O, e->id, mode)
 *
 * s1->O is dereferenced unconditionally.
 */
double compute_vit_likelihood(state *s1, int mode) {
    static char rname[]="compute_vit_likelihood";
    double lprob = 0.0;
    trans *t;
    tc *entry;

    if (DEBUG) fprintf(stderr, "%s: begin\n", rname);

    for (t = s1->out; t != NULL; t = t->next_source)
        lprob += t->count * log(t->prob);

    if (DEBUG) fprintf(stderr, "%s: moving to emissions...\n", rname);

    for (entry = s1->O->counts; entry != NULL; entry = entry->next)
        lprob += entry->count * calc_lprob(s1->O, entry->id, mode);

    if (DEBUG) fprintf(stderr, "%s: end\n", rname);

    return lprob;
}


/* Log of the multivariate Beta function of values[0..num_values-1]:
 *
 *   log B(a) = sum_i gammln(a_i) - gammln(sum_i a_i)
 */
double log_beta(double *values, int num_values) {
    double sum_lgamma = 0.0;
    double alpha_total = 0;
    int k;

    for (k = 0; k < num_values; k++) {
        sum_lgamma += gammln(values[k]);
        alpha_total += values[k];
    }
    return sum_lgamma - gammln(alpha_total);
}

/* From Numerical Recipes in C */
double gammln(double xx) {
    static char rname[]="gammln";
    double x,tmp,ser;
    static double cof[6]={76.18009173,-86.50532033,24.01409822,
                -1.231739516,0.120858003e-2,-0.536382e-5};
    int j;

    x = xx - 1.0;
    tmp = x + 5.5;
    tmp -= (x+0.5) * log(tmp);
    ser = 1.0;
    for (j = 0; j <= 5; j++) {
        x += 1.0;
        ser += cof[j]/x;
    }
    return -tmp + log(2.50662827465 * ser);
}

/* Return 1 if s1 and s2 carry identical label strings, 0 otherwise. */
int have_same_label(state *s1, state *s2) {
    return strcmp(s1->label, s2->label) == 0;
}

/* Return 1 if a transition (compared by destination id) goes from s1 to s2
 * or from s2 to s1, and 0 otherwise.
 */
int are_neighbors(state *s1, state *s2) {
    trans *t;

    for (t = s1->out; t != NULL; t = t->next_source) {
        if (t->dest->id == s2->id)
            return 1;
    }
    for (t = s2->out; t != NULL; t = t->next_source) {
        if (t->dest->id == s1->id)
            return 1;
    }
    return 0;
}

/* Clear the `seen` flag of every non-NULL entry in states[0..num_states-1].
 * Returns the number of states that were reset.
 */
int set_seen_to_zero(state **states, int num_states) {
    int reset = 0;
    int k;

    for (k = 0; k < num_states; k++) {
        if (states[k] == NULL)
            continue;
        states[k]->seen = 0;
        reset++;
    }
    return reset;
}

/* Merge chains of adjacent same-label states starting from `initial`
 * (see collapse_same_tags()), report the before/after state counts, and
 * rebuild the state array.
 *
 * The old `states` array is freed, because collapse_same_tags() frees the
 * states it merges. Returns the new array produced by collect_states();
 * *ptr_num_states is updated.
 */
state **collapse_adjacent_states(state **states, int *ptr_num_states, state *initial, double trans_uni_alpha) {
    int before = *ptr_num_states;

    fprintf(stderr, "Collapsing adjacent states with the same label...\n");
    collapse_same_tags(initial, ptr_num_states, trans_uni_alpha);
    fprintf(stderr, "%d states were collapsed to %d states\n", before, *ptr_num_states);

    free(states);
    return collect_states(initial, ptr_num_states);
}

/* Collapse same-label V-shaped state patterns in two passes.
 *
 * Pass 1 runs collapse_V_tags_forward() from `initial`. Pass 2 runs
 * collapse_V_tags_backward() from `end`. The state array is rebuilt with
 * collect_states() after each pass, and the incoming `states` array is
 * freed. Returns the final array; *ptr_num_states is updated.
 */
state **collapse_V_states(state **states, int *ptr_num_states, state *initial, state *end, double trans_uni_alpha) {
    int before = *ptr_num_states;

    fprintf(stderr, "Collapsing V-states with the same label (forward and backward) ...\n");

    /* Pass 1: forward, from the initial state. */
    fprintf(stderr, "Forward V collapse...\n");
    collapse_V_tags_forward(initial, ptr_num_states, trans_uni_alpha);
    fprintf(stderr, "%d states were collapsed to %d states\n", before, *ptr_num_states);
    before = *ptr_num_states;

    /* Rebuild the array from the graph before the second pass. */
    free(states);
    states = collect_states(initial, ptr_num_states);

    /* Pass 2: backward, from the end state. */
    fprintf(stderr, "Backward V collapse...\n");
    collapse_V_tags_backward(end, ptr_num_states, trans_uni_alpha);
    fprintf(stderr, "%d states were collapsed to %d states\n", before, *ptr_num_states);

    free(states);
    return collect_states(initial, ptr_num_states);
}






