/* bmm.c
# ---------------------------------------------------------------------------------
# Bayesian Model Merging code
# ----------------------------------------------------------------------------------

Documentation for this code: /homes/kseymore/ra/hmm/doc/hmm_tools.txt

*/

#include "general.h"

/* Program version string.
 * NOTE(review): the name follows the GNU argp convention (argp reads a global
 * called argp_program_version for --version), but this file parses options
 * with the read_com_* helpers, not argp -- confirm whether anything reads it. */
const char *argp_program_version = "bmm_1.0";

/*
 * bmm: Bayesian Model Merging driver.
 *
 * Reads a vocabulary and a set of training strings, and builds an initial HMM
 * from the data (one state per word of training data). It then greedily
 * merges pairs of states. Each round scores every candidate pair (the change
 * in model score the merge would cause) and applies the best merge while it
 * is judged good. Otherwise it stops, unless -incremental still has strings
 * to add or -same_label_at_first still has a constraint to lift.
 * Intermediate and final models are written to the -out directory via
 * print_model_to_file().
 *
 * Running with no arguments, or with -help, prints usage and exits with
 * status 1. Invalid option combinations are reported through quit().
 * Returns 0 after the final model has been written.
 */
int main(int argc, char **argv) {
    static char rname[] = "bmm";
    FILE        *string_file;
    int         num_states, num_strings, i;
    int         need_help, print_hmm, iteration, vocab_size, lookahead_threshold;
    int         read_count, read_label, initial_adj_collapse, mode, lookahead_count;
    int         map, mean, ml, structure, smooth, evaluate_model, closed_vocab;
    int         same_label, neighbors_only, num_cands, eff_sample, same_label_at_first;
    int         initial_collapse_exit, initial_V_collapse, initial_model_exit;
    int         print_start, read_id, no_prior, use_prior, structure_prior, cand_to_compute;
    int         num_prior_dists, step_size, cand_count, num_added, max_state_label;
    int         evaluate_start, evaluate_step, incremental, orig_num_to_add, num_to_add, narrow_emission_prior;
    int         more_strings = 0;       /* only meaningful with -incremental */
    int         last_good_num_states, last_good_iteration;
    int         model_written = 0;      /* 1 once the current model has been printed to file */
    int         model_evaluated = 0;    /* 1 once the current model has been evaluated */
    double      lMAP, existing_states_contribution, new_state_contribution;
    double      obs_uni_alpha, trans_uni_alpha;
    double      prior_weight, dir_prior_weight;
    char        *string_file_path, *output_dir, *prior_dist_file_path, *vocab_file_path, *emissions_dir;
    char        *obs_file_path;
    shead       **strings;
    state       *initial, *end, *new_state, **states;
    state_pairs *candidates, *candidate, *max_cand;
    trans       *trans_to_keep;
    count_dist  **priors;

    /* Process command line arguments */
    need_help = read_com_noarg(&argc, argv, "-help");
    if (argc == 1 || need_help) {

        /* Display help message */
        fprintf(stderr,"bmm : Given a set of training observations, find the hidden Markov model\n      that maximizes the probability of the model given the observations.\n");
        fprintf(stderr,"Usage: bmm -data <data_file>\n");
        fprintf(stderr,"           -map | -mean | -ml | -structure\n");
        fprintf(stderr,"           -vocab <vocab_file>\n");
        fprintf(stderr,"           -out <output_dir>\n");
        fprintf(stderr,"           [ -closed_vocab ]\n");
        fprintf(stderr,"           [ -incremental <orig_num_strings_to_add> <num_strings_to_add>\n");
        fprintf(stderr,"           [ -lookahead <num_steps>\n");
        fprintf(stderr,"           [ -priors <dist_file>  (format of <word> <count>)]\n");
        fprintf(stderr,"           [ -print_model <start_iteration> <step> ]\n");
        fprintf(stderr,"           [ -evaluate_model <observation_file> <start_iteration> <step> ]\n");
        fprintf(stderr,"           [ -print_dist <directory> ]\n");
        fprintf(stderr,"           [ -smooth_dists <mode> (absolute discounting = 1, linear interpolation = 2) ]\n");
        fprintf(stderr,"           [ -read_count ]\n");
        fprintf(stderr,"           [ -read_label ]\n");
        fprintf(stderr,"           [ -read_id ]\n");
        fprintf(stderr,"           [ -initial_adj_collapse ]\n");
        fprintf(stderr,"           [ -initial_V_collapse ]\n");
        fprintf(stderr,"           [ -exit_after_initial_model ]\n");
        fprintf(stderr,"           [ -exit_after_initial_collapse ]\n");
        fprintf(stderr,"           [ -same_label ]\n");
        fprintf(stderr,"           [ -same_label_at_first ]\n");
        fprintf(stderr,"           [ -neighbors_only ]\n");
        fprintf(stderr,"           [ -no_prior | -structure_prior ]\n");
        fprintf(stderr,"           [ -narrow_emission_prior  (broad is default) ]\n");
        fprintf(stderr,"           [ -pw <prior_weight (1.0)> | -eff_sample_size <count> ]\n");
        fprintf(stderr,"           [ -fixed_dirichlet_prior_weight <weight> ]\n");
        fprintf(stderr,"           [ -trans_uniform_alpha <count (1.0)> ]\n");
        fprintf(stderr,"           [ -obs_uniform_alpha <count (1.0)> ]\n");
        fprintf(stderr,"\nDocumentation available at /homes/kseymore/ra/hmm/doc/hmm_tools.txt\n\n");
        exit(1);
    }
    else {
        /* Echo the command line so logs record how the run was invoked */
        for (i = 0; i < argc; i++) {
            fprintf(stderr, "%s ", argv[i]);
        }
        fprintf(stderr, "\n");
    }

    string_file_path = read_com_string(&argc, argv, "-data");
    map = read_com_noarg(&argc, argv, "-map");
    mean = read_com_noarg(&argc, argv, "-mean");
    ml = read_com_noarg(&argc, argv, "-ml");
    structure = read_com_noarg(&argc, argv, "-structure");
    output_dir = read_com_string(&argc, argv, "-out");
    incremental = read_com_two_int(&argc, argv, "-incremental", &orig_num_to_add, &num_to_add);
    prior_dist_file_path = read_com_string(&argc, argv, "-priors");
    vocab_file_path = read_com_string(&argc, argv, "-vocab");
    closed_vocab = read_com_noarg(&argc, argv, "-closed_vocab");
    print_hmm = read_com_two_int(&argc, argv, "-print_model", &print_start, &step_size);
    obs_file_path = read_com_one_string_two_int(&argc, argv, "-evaluate_model", &evaluate_start, &evaluate_step);
    emissions_dir = read_com_string(&argc, argv, "-print_dist");
    read_count = read_com_noarg(&argc, argv, "-read_count");
    read_label = read_com_noarg(&argc, argv, "-read_label");
    read_id = read_com_noarg(&argc, argv, "-read_id");
    same_label = read_com_noarg(&argc, argv, "-same_label");
    same_label_at_first = read_com_noarg(&argc, argv, "-same_label_at_first");
    neighbors_only = read_com_noarg(&argc, argv, "-neighbors_only");
    no_prior = read_com_noarg(&argc, argv, "-no_prior");
    structure_prior = read_com_noarg(&argc, argv, "-structure_prior");
    narrow_emission_prior = read_com_noarg(&argc, argv, "-narrow_emission_prior");
    initial_adj_collapse = read_com_noarg(&argc, argv, "-initial_adj_collapse");
    initial_V_collapse = read_com_noarg(&argc, argv, "-initial_V_collapse");
    initial_collapse_exit = read_com_noarg(&argc, argv, "-exit_after_initial_collapse");
    initial_model_exit = read_com_noarg(&argc, argv, "-exit_after_initial_model");
    if (!read_com_double(&argc, argv, "-pw", &prior_weight)) prior_weight = 1.0;
    if (!read_com_int(&argc, argv, "-eff_sample_size", &eff_sample)) eff_sample = 0;
    if (!read_com_double(&argc, argv, "-trans_uniform_alpha", &trans_uni_alpha)) trans_uni_alpha = 1.0;
    if (!read_com_double(&argc, argv, "-obs_uniform_alpha", &obs_uni_alpha)) obs_uni_alpha = 1.0;
    if (!read_com_int(&argc, argv, "-lookahead", &lookahead_threshold)) lookahead_threshold = 0;
    if (!read_com_double(&argc, argv, "-fixed_dirichlet_prior_weight", &dir_prior_weight)) dir_prior_weight = 0.0;
    if (!read_com_int(&argc, argv, "-smooth_dists", &smooth)) smooth = 0;

    check_extra_args(&argc, argv);

    fprintf(stderr, "-----\n");

    /* Check for consistency */
    if (string_file_path == NULL) quit(-1, "%s: must specify a data file with -data\n",rname);
    else fprintf(stderr, "String file path = %s\n", string_file_path);

    /* Determine how model prior should be calculated: exactly one of the
       four mode flags may be set, and it is mapped to mode 1..4. */
    mode = map + mean + ml + structure;
    if (mode != 1) quit(-1, "%s: must select one (and only one) of -map, -mean, -ml and -structure\n", rname);
    else {
        if (map) {
            mode = 1;
            fprintf(stderr, "Model parameters will be set to their maximum a posteriori estimates\n");
        }
        else if (mean) {
            mode = 2;
            fprintf(stderr, "Model parameters will be set to their mean posterior estimates\n");
        }
        else if (ml) {
            mode = 3;
            fprintf(stderr, "Model parameters will be set to their maximum likelihood estimates\n");
        }
        else if (structure) {
            mode = 4;
            fprintf(stderr, "Model structures will be compared by integrating over all parameter settings\n");
        }
    }

    if (output_dir == NULL) quit(-1, "%s: Must specify an output directory\n", rname);
    else fprintf(stderr, "Output directory = %s\n", output_dir);

    if (incremental) {
        /* The first count sizes the initial strings array; a non-positive
           add count would never add anything. */
        if (orig_num_to_add <= 0 || num_to_add <= 0)
            quit(-1, "%s: -incremental requires two positive string counts\n", rname);
        fprintf(stderr, "Incremental merging: will start with %d strings, then incrementally add %d strings at a time\n", orig_num_to_add, num_to_add);
    }

    if (vocab_file_path == NULL) quit(-1, "%s: Must specify a vocabulary file\n", rname);
    else fprintf(stderr, "Vocab file = %s\n", vocab_file_path);

    if (closed_vocab) fprintf(stderr, "Closed vocabulary: out-of-vocabulary words will not be handled\n");
    else fprintf(stderr, "Open vocabulary: the unknown word will be added to the vocab, if it isn't already in the vocab file\n");

    if (print_hmm == 0) {
        print_start = 0;
        step_size = 0;
    }
    else {
        /* step_size is used as a modulus below; zero would divide by zero */
        if (step_size <= 0) quit(-1, "%s: -print_model step must be a positive integer\n", rname);
        fprintf(stderr, "Will start printing hmm at iteration number %d, and will print model every %d iterations\n", print_start, step_size);
    }

    if (obs_file_path != NULL) {
        /* evaluate_step is used as a modulus below; zero would divide by zero */
        if (evaluate_step <= 0) quit(-1, "%s: -evaluate_model step must be a positive integer\n", rname);
        fprintf(stderr, "Will evaluate error and PP of intermediate models using viterbi decoding\n    and observation file %s, starting\n    when at iteratio number %d and repeating every %d iterations\n", obs_file_path, evaluate_start, evaluate_step);
        evaluate_model = 1;
    }
    else evaluate_model = 0;

    if (smooth > 0 && emissions_dir == NULL && !evaluate_model)
        quit(-1, "%s: -smooth_dists has been specified, but smoothing is only needed if emission distrbutions are printed (-print_dist) or if the models are evaluated on test data (-evaluate_model) ...\n", rname);

    if (emissions_dir != NULL) {
        if (!print_hmm) quit(-1, "%s: must choose -print_model to print output emissions to %s\n", rname, emissions_dir);
        fprintf(stderr, "Will print output emission distributions to %s\n", emissions_dir);
    }

    if (smooth == 1) fprintf(stderr, "Output emission distributions will be smoothed using absolute discounting\n");
    else if (smooth == 2) fprintf(stderr, "Output emission distributions will be smoothed using shrinkage of state counts, prior counts and uniform counts\n");
    else fprintf(stderr, "Output emissions will not be smoothed\n");

    if (same_label && same_label_at_first)
        quit(-1, "%s: should not specify both -same_label and -same_label_at_first\n", rname);

    /* -same_label_at_first starts out as -same_label; the constraint is
       dropped later in the merge loop. */
    if (same_label_at_first) same_label = 1;

    if ((initial_adj_collapse  || initial_V_collapse) && !read_label)
        quit(-1, "%s: must have labels in order to perform initial collapse... set read_label flag.\n", rname);

    if (same_label && !read_label)
        quit(-1, "%s: must have labels in order to perform same label merging... set read_label flag.\n", rname);

    if (prior_dist_file_path != NULL && !read_label)
        quit(-1, "%s: must have labels in order to use prior distributions... set read_label flag.\n", rname);

    if (read_id && read_count)
        quit(-1, "%s: cannot read both an id and a count!\n", rname);

    if (prior_dist_file_path != NULL && !same_label && !(initial_collapse_exit || initial_model_exit))
        quit(-1, "%s: should restrict merges to same labels when using preset prior distributions.\n", rname);

    /* -pw defaults to 1.0, so "set" means "differs noticeably from 1.0" */
    if (eff_sample != 0 && fabs(prior_weight - 1.0) > 0.001)
        quit(-1, "%s: must set either -pw or -eff_sample_size -- cannot set both.\n", rname);

    if (fabs(prior_weight - 1.0) > 0.001) fprintf(stderr, "Prior weight set to %f\n", prior_weight);

    if (structure && structure_prior) quit(-1, "%s: unneccessary to specify both -structure and -structure_prior...\n", rname);
    if (no_prior && structure_prior) quit (-1, "%s: cannot specify both -no_prior and -structure_prior...\n",rname);
    if (no_prior) fprintf(stderr, "Priors will not be used in merging decisions\n");
    if (structure_prior) fprintf(stderr, "Only structure priors will be used in merging decisions\n");

    if (narrow_emission_prior && no_prior) quit(-1, "%s: cannot specify both -no_prior and -narrow_emission_prior...\n", rname);
    if (narrow_emission_prior && structure_prior) quit(-1, "%s: cannot specify both -structure_prior and -narrow_emission_prior...\n", rname);

    /* use_prior: 0 = no prior, 2 = structure prior only, 1 = full prior */
    if (no_prior) use_prior = 0;
    else if (structure_prior) use_prior = 2;
    else use_prior = 1;

    fprintf(stderr, "Uniform transition alpha prior values are set to %f\n", trans_uni_alpha);
    fprintf(stderr, "Uniform emission alpha prior values are set to %f\n", obs_uni_alpha);
    if (dir_prior_weight > 0) fprintf(stderr, "Prior distributions will be normalized to have a Dirichlet total prior weight of %f\n", dir_prior_weight);

    fprintf(stderr, "-----\n");

    iteration = 0;

    /* Open observation file */
    if (string_file_path != NULL) string_file = kopen_r(string_file_path);
    else string_file = stdin;

    /* Read vocab file */
    vocab_size = read_vocab(vocab_file_path, closed_vocab);
    if (vocab_size < 0) quit(-1, "%s: no vocab words!\n", rname);
    else fprintf(stderr, "Read in fixed vocab of %d words (max id = %d)\n", vocab_size+1, vocab_size);

    /* Read in training data string observations */
    num_strings = 0;
    if (!incremental) {
        strings = read_strings(string_file, &num_strings, read_count, read_label, read_id);
        fclose(string_file);
        orig_num_to_add = num_strings;   /* bounds the free loop below */
    }
    else {
        /* Read only the first orig_num_to_add strings; the file stays open
           for add_incremental_strings() unless it is exhausted here. */
        more_strings = 1;
        strings = (shead **) kmalloc(orig_num_to_add*sizeof(shead *));
        for (i = 0; i < orig_num_to_add; i++) {
            strings[i] = NULL;
        }
        for (i = 0; i < orig_num_to_add; i++) {
            strings[i] = read_one_string(string_file, read_id, read_label, read_count);
            if (strings[i] != NULL) {
                map_words_to_unk(strings[i]);
                num_strings++;
            }
            if (strings[i] == NULL || feof(string_file)) {
                more_strings = 0;
                fclose(string_file);
                break;
            }
        }
        if (num_strings != orig_num_to_add) {
            fprintf(stderr, "%s: warning - only read in %d strings, which is less than the desired number of %d specified\n", rname, num_strings, orig_num_to_add);
        }
        fprintf(stderr, "Incremental: starting with %d strings\n", num_strings);
    }

    /* Adjust prior weight */
    if (eff_sample != 0) {
        prior_weight = ((double) num_strings)/eff_sample;
        fprintf(stderr, "Effective sample size = %d, current number of samples = %d; prior weight set to %f to reflect this\n", eff_sample, num_strings, prior_weight);
    }

    /* Read in preset prior distributions */
    if (prior_dist_file_path != NULL) {
        priors = load_prior_counts(prior_dist_file_path, &num_prior_dists, vocab_size);

        /* Precalculate log beta factors if a broad emission prior is used and if there are prior dists */
        if (!narrow_emission_prior)
            calculate_log_beta_factors(priors, num_prior_dists, obs_uni_alpha, vocab_size, dir_prior_weight);
    }
    else {
        priors = NULL;
        num_prior_dists = 0;
    }

    /* Create initial trivial HMM - one state per word in training data - from data */
    create_initial_model_from_data(strings, &num_states, num_strings, trans_uni_alpha, obs_uni_alpha, &initial, &end, priors, num_prior_dists, vocab_size);
    fprintf(stderr, "Finished creating initial HMM with %d states\n", num_states);
    max_state_label = num_states;

    /* Free string structure */
    for (i = 0; i < orig_num_to_add; i++) {
        if (strings[i] != NULL) free_string(strings[i]);
    }
    free(strings);

    /* Create state array */
    states = collect_states(initial, &num_states);

    /* Print model to file and exit if initial_model_exit was set */
    if (initial_model_exit) run_initial_model_exit(output_dir, vocab_file_path, states, num_states, emissions_dir, smooth, mode, narrow_emission_prior, iteration, dir_prior_weight);

    /* Perform initial collapsing of states */
    if (initial_adj_collapse) states = collapse_adjacent_states(states, &num_states, initial, trans_uni_alpha);
    if (initial_V_collapse) states = collapse_V_states(states, &num_states, initial, end, trans_uni_alpha);

    /* Set parameters of model */
    set_model_parameters(states, num_states, mode, narrow_emission_prior, dir_prior_weight);
    set_model_priors_and_likelihoods(states, num_states, mode, use_prior, narrow_emission_prior);
    fprintf(stderr, "Set parameters of model\n");

    /* Print model to file and exit if initial_collapse_exit was set.
       NOTE(review): this exits with status 1 even though it succeeded. */
    if (initial_collapse_exit) {
        fprintf(stderr, "Printing model to file and exiting...\n");
        print_model_to_file(NULL, output_dir, vocab_file_path, states, num_states, emissions_dir, smooth, mode, iteration);
        exit(1);
    }

    if (print_hmm && (print_start <= iteration) && ((iteration - print_start) % step_size == 0)) {
        print_model_to_file(NULL, output_dir, vocab_file_path, states, num_states, emissions_dir, smooth, mode, iteration);
        model_written = 1;
    }

    if (evaluate_model && (evaluate_start <= iteration) && ((iteration - evaluate_start) % evaluate_step == 0)) {
        evaluate_intermediate_model(obs_file_path, output_dir, initial, states, num_states, vocab_size, smooth, mode, iteration);
        model_evaluated = 1;
    }

    /* Compute initial candidate list (containing all candidate pairs) */
    if (same_label_at_first) fprintf(stderr, "Only considering candidate pairs with same label at the beginning - constraint will be lifted later\n");
    else if (same_label) fprintf(stderr, "Only considering candidate pairs with same label\n");
    if (neighbors_only) fprintf(stderr, "Only considering candidate pairs that are neighbors\n");
    candidates = compute_all_candidates(initial, states, num_states, same_label, neighbors_only);
    fprintf(stderr, "Finished computing candidate pairs\n");

    /* Until a merge is judged good, the starting model is the best one seen.
       Initializing here keeps the FINAL MODEL report defined even when no
       merge is ever accepted. */
    last_good_num_states = num_states;
    last_good_iteration = iteration;

    lookahead_count = 0;
    while (1) {

        fprintf(stderr, "\nSTATES = %d, Merge num = %d\n", num_states, iteration);

        num_cands = 0;
        cand_count = 0;
        cand_to_compute = 0;

        /* Count number of candidates. A value above 1000000 marks a pair
           whose merge value has not been computed yet. */
        candidate = candidates;
        while (candidate != NULL) {
            num_cands++;
            if (candidate->value > 1000000) cand_to_compute++;
            candidate = candidate->next;
        }
        fprintf(stderr, "Number of candidates = %d (%d to compute)\n", num_cands, cand_to_compute);
        fprintf(stderr, "One . for every 100 computed: ");

        if (DEBUG) {
            fprintf(stderr, "STATE LIST: (%d states)\n", num_states);
            for (i = 0; i < num_states; i++) {
                fprintf(stderr, "i = %d, state id = %d\n", i, states[i]->id);
            }
            print_model(states, num_states, mode);
        }

        /* 1. Compute a set of candidate merges K among the states of the model Mi --
              already computed. */

        /* 2. For each candidate k in K, compute the merged model k(Mi), and its
              posterior probability P(k(Mi)|X). Values are cached on the
              candidate; only uncomputed pairs are scored here. */
        candidate = candidates;
        lMAP = -10000000.0;
        max_cand = NULL;
        while (candidate != NULL) {
            trans_to_keep = NULL;

            /* See if value has already been computed for this candidate pair */
            if (candidate->value > 1000000) {   /* Don't have merge value - compute it */
                cand_count++;

                if (DEBUG) fprintf(stderr, "Computing value for %d (%s) and %d (%s)\n", candidate->s1->id, candidate->s1->label, candidate->s2->id, candidate->s2->label);

                if ((cand_count % 100) == 0) fprintf(stderr, ".");
                if ((cand_count % 10000) == 0) fprintf(stderr, "\n");

                /* Compute value for candidate states before merge */
                existing_states_contribution = compute_candidate_contribution(candidate->s1, candidate->s2, num_states, prior_weight, use_prior, mode);

                if (DEBUG) fprintf(stderr, "Existing states contribution: %f\n", existing_states_contribution);

                /* Tentatively merge the states and compute value for the new state */
                new_state = merge_states(candidate->s1, candidate->s2, &trans_to_keep, trans_uni_alpha);
                update_model_parameters(new_state, mode, narrow_emission_prior, dir_prior_weight);
                update_model_priors_and_likelihoods(new_state, num_states-1, mode, use_prior, narrow_emission_prior);
                new_state_contribution = compute_new_state_contribution(new_state, num_states-1, prior_weight, use_prior, mode);

                if (DEBUG) fprintf(stderr, "New state contribution: %f\n", new_state_contribution);

                candidate->value = new_state_contribution - existing_states_contribution;
                if (DEBUG) fprintf(stderr, "Candidate value = %f\n", candidate->value);

                /* Undo the tentative merge and restore both original states */
                unmerge_states(new_state, trans_to_keep);
                update_model_parameters(candidate->s1, mode, narrow_emission_prior, dir_prior_weight);
                update_model_parameters(candidate->s2, mode, narrow_emission_prior, dir_prior_weight);
                update_model_priors_and_likelihoods(candidate->s1, num_states, mode, use_prior, narrow_emission_prior);
                update_model_priors_and_likelihoods(candidate->s2, num_states, mode, use_prior, narrow_emission_prior);
            }
            if (DEBUG) fprintf(stderr, "CANDIDATE PAIR: %d (%s) and %d (%s), value = %f\n", candidate->s1->id, candidate->s1->label, candidate->s2->id, candidate->s2->label, candidate->value);

            if (candidate->value > lMAP) {
                lMAP = candidate->value;
                max_cand = candidate;
            }

            candidate = candidate->next;
        }

        fprintf(stderr, "\n");
        /* max_cand can be NULL even with a non-empty list if every value was
           <= the initial lMAP, so test it rather than the list head. */
        if (max_cand != NULL)
            fprintf(stderr, "BEST CANDIDATE PAIR: %d (%s) and %d (%s), candidate value = %f\n", max_cand->s1->id, max_cand->s1->label, max_cand->s2->id, max_cand->s2->label, lMAP);

        /* 4. If P(Mi+1|X) < P(Mi|X), break from the loop.
              A merge is taken when it improves the score (lMAP > 0). With
              -lookahead, up to lookahead_threshold further non-improving
              merges may be taken once no incremental strings remain. */
        if (lMAP > 0 || (lookahead_count < lookahead_threshold && (!incremental || !more_strings) && max_cand != NULL)) { /* merge is good */
            if (lMAP < 0) lookahead_count++;
            else {
                lookahead_count = 0;
                last_good_num_states = num_states - 1;
                last_good_iteration = iteration + 1;
            }

            fprintf(stderr, "MERGING STATES\n");

            /* 3. Let k* be the merge that maximized P(k(Mi)|X). Then let Mi+1 = k*(Mi). */
            trans_to_keep = NULL;
            new_state = merge_states(max_cand->s1, max_cand->s2, &trans_to_keep, trans_uni_alpha);
            update_model_parameters(new_state, mode, narrow_emission_prior, dir_prior_weight);
            update_model_priors_and_likelihoods(new_state, num_states-1, mode, use_prior, narrow_emission_prior);

            /* Free merged states, the transition list and the candidate list */
            if (DEBUG) fprintf(stderr, "Updating candidate list\n");
            update_candidates(&candidates, max_cand, states, num_states, new_state, same_label, neighbors_only);
            num_states--;

            if (DEBUG) fprintf(stderr, "Freeing transition list\n");
            free_trans_list(trans_to_keep); /* Since we don't call unmerge_states */

            /* The model changed, so it has been neither written nor evaluated yet */
            model_written = 0;
            model_evaluated = 0;

            /* 5. i = i+1 */
            iteration++;
            if (print_hmm && (print_start <= iteration) && ((iteration - print_start) % step_size == 0)) {
                print_model_to_file(NULL, output_dir, vocab_file_path, states, num_states, emissions_dir, smooth, mode, iteration);
                model_written = 1;
            }

            if (evaluate_model && (evaluate_start <= iteration) && ((iteration - evaluate_start) % evaluate_step == 0)) {
                evaluate_intermediate_model(obs_file_path, output_dir, initial, states, num_states, vocab_size, smooth, mode, iteration);
                model_evaluated = 1;
            }
        }
        else { /* Merge is not good */

            fprintf(stderr, "NO MERGE\n");

            if (incremental && more_strings) {

                add_incremental_strings(string_file, &num_states, initial, end, read_id, read_label, read_count, &max_state_label, num_to_add, priors, num_prior_dists, vocab_size, trans_uni_alpha, obs_uni_alpha, &more_strings, &num_added);

                if (num_added == 0) more_strings = 0;
                else {

                    num_strings += num_added;
                    if (eff_sample != 0) {
                        prior_weight = ((double) num_strings)/eff_sample;
                        fprintf(stderr, "Effective sample size = %d, current number of samples = %d; prior weight set to %f to reflect this\n", eff_sample, num_strings, prior_weight);
                    }

                    free(states);
                    states = collect_states(initial, &num_states);

                    if (initial_adj_collapse) states = collapse_adjacent_states(states, &num_states, initial, trans_uni_alpha);
                    if (initial_V_collapse) states = collapse_V_states(states, &num_states, initial, end, trans_uni_alpha);
                    set_model_parameters(states, num_states, mode, narrow_emission_prior, dir_prior_weight);
                    if (DEBUG) fprintf(stderr, "Set model parameters\n");
                    set_model_priors_and_likelihoods(states, num_states, mode, use_prior, narrow_emission_prior);
                    free_candidates(candidates);
                    candidates = compute_all_candidates(initial, states, num_states, same_label, neighbors_only);

                    /* The model was rebuilt with more data at the same
                       iteration number: it is now the best model so far, and
                       any earlier print/evaluation at this iteration is stale. */
                    last_good_num_states = num_states;
                    last_good_iteration = iteration;
                    model_written = 0;
                    model_evaluated = 0;
                }
            }
            else if (same_label_at_first) {
                same_label = 0;
                same_label_at_first = 0;

                fprintf(stderr, "Removing same_label constraint...\n");

                free_candidates(candidates);
                candidates = compute_all_candidates(initial, states, num_states, same_label, neighbors_only);
            }
            else break; /* Call it quits */
        }
    }

    /* Return Mi as the induced model: write it unless this exact model was
       already written inside the loop. */
    if (!model_written)
        print_model_to_file(NULL, output_dir, vocab_file_path, states, num_states, emissions_dir, smooth, mode, iteration);
    fprintf(stderr, "\nWrote final model to file %s/iter_%d_%d.model\n", output_dir, iteration, num_states);

    if (evaluate_model && (evaluate_start <= iteration) && !model_evaluated)
        evaluate_intermediate_model(obs_file_path, output_dir, initial, states, num_states, vocab_size, smooth, mode, iteration);

    /* NOTE(review): with -lookahead the last good model may be from an
       iteration that -print_model's step did not write. */
    fprintf(stderr, "FINAL MODEL: %s/iter_%d_%d.model.gz\n", output_dir, last_good_iteration, last_good_num_states);

    /* Free everything */
    free(states);
    states = collect_states(initial, &num_states);
    for (i = 0; i < num_states; i++) {
        if (states[i] != NULL) free_state(states[i]);
    }
    free(states);
    free_candidates(candidates);

    fprintf(stderr, "DONE!\n");
    return 0;
}


