/* train_hmm.c
---------------------------------------------------------------------------------
Hidden Markov Model Parameter Training code
----------------------------------------------------------------------------------

See documentation for code in /homes/kseymore/ra/hmm/doc/hmm_tools.txt

*/

#include "general.h"

/* Version string for this tool.
   NOTE(review): the identifier matches the global that GNU argp reads for its
   --version output, but nothing visible in this file uses argp -- confirm
   whether it is consumed elsewhere before removing or renaming it. */
const char *argp_program_version = "train_hmm_1.0";

/* main
   Command-line driver for HMM parameter training.

   Steps, in order:
     1. Parse the command line (prints usage and exits with status 1 when run
        with no arguments or with -help).
     2. Echo the chosen options to stderr and reject inconsistent combinations.
     3. Read the model, open the observation/output streams, read the vocab.
     4. Read the observation strings and hand everything to iterate_bw().
     5. When no emissions directory was requested, print the resulting model
        to the output stream.

   argc, argv : the usual command-line arguments; the read_com_* helpers are
                given &argc and argv so they can consume the options they match.

   Returns 0 on normal completion, 1 if the output stream could not be closed
   cleanly. Argument errors are reported through quit(). */
int main (int argc, char **argv) {
    static char rname[]="train_hmm";
    FILE        *output_file, *string_file;
    int         need_help, num_states, num_strings, details, closed_vocab;
    int         read_count, read_label, read_id, vocab_size, punc_trans;
    int         uniform, random, seed, trans_only, mode, smooth, iteration;
    char        *model_file_path, *data_file_path, *output_file_path, *vocab_file_path;
    char        *emissions_dir;
    state       *hmm, *new_model, **states;
    shead       **strings;


    num_states = 0;
    num_strings = 0;
    read_count = 0;
    seed = 0;   /* defined value even when -random is not given */

    /* Process command line arguments */

    need_help = read_com_noarg(&argc, argv, "-help");
    if (argc == 1 || need_help) {

        /* Display help message */
        fprintf(stderr,"train_hmm : Given an HMM with initial parameter estimates and observation data, iterate forward-backward training until local maximum likelihood parameter estimates are attained.\n");
        fprintf(stderr,"Usage: train_hmm -hmm <model_file>\n");
        fprintf(stderr,"                 -vocab <vocab_file>\n");
        fprintf(stderr,"               [ -closed_vocab ]\n");
        fprintf(stderr,"               [ -obs <observation_file> ]\n");
        fprintf(stderr,"               [ -out <output_file> ]\n");
        fprintf(stderr,"               [ -print_new_emissions <directory> ]\n");
        fprintf(stderr,"               [ -read_label ]\n");
        fprintf(stderr,"               [ -read_id ]\n");
        fprintf(stderr,"               [ -only_punc_trans ]\n");
        fprintf(stderr,"               [ -uniform ]\n");
        fprintf(stderr,"               [ -random <seed> ]\n");
        fprintf(stderr,"               [ -trans_only ]\n");
        fprintf(stderr,"               [ -details ]\n\n");
        fprintf(stderr,"Documentation available at /homes/kseymore/ra/hmm/doc/hmm_tools.txt\n\n");
        exit(1);
    }

    model_file_path = read_com_string(&argc, argv, "-hmm");
    vocab_file_path = read_com_string(&argc, argv, "-vocab");
    closed_vocab = read_com_noarg(&argc, argv, "-closed_vocab");
    data_file_path = read_com_string(&argc, argv, "-obs");
    output_file_path = read_com_string(&argc, argv, "-out");
    emissions_dir = read_com_string(&argc, argv, "-print_new_emissions");
    read_label = read_com_noarg(&argc, argv, "-read_label");
    read_id = read_com_noarg(&argc, argv, "-read_id");
    punc_trans = read_com_noarg(&argc, argv, "-only_punc_trans");
    uniform = read_com_noarg(&argc, argv, "-uniform");
    trans_only = read_com_noarg(&argc, argv, "-trans_only");
    random = read_com_int(&argc, argv, "-random", &seed);
    details = read_com_noarg(&argc, argv, "-details");
    check_extra_args(&argc, argv);

    /* Print out command line arguments and check for consistency */
    if (model_file_path == NULL) quit(-1, "%s: no HMM file specified.\n", rname);
    else fprintf(stderr, "Model file path = %s\n", model_file_path);

    if (punc_trans) fprintf(stderr, "Permitting state transitions only after words ending in punctuation [.,:!]\n");

    if (data_file_path == NULL) fprintf(stderr, "No observation file specified - using stdin...\n");
    else fprintf(stderr, "Observation file path = %s\n", data_file_path);

    if (output_file_path == NULL) fprintf(stderr, "No output file specified - using stdout\n");
    else fprintf(stderr, "Output file path = %s\n", output_file_path);

    if (vocab_file_path == NULL) quit(-1, "%s: must specify a vocabulary file!\n", rname);
    else fprintf(stderr, "Will use fixed vocab as specified in %s\n", vocab_file_path);

    if (read_id) fprintf(stderr, "Expecting to read an integer id before each observation string\n");
    if (read_label) fprintf(stderr, "Expecting each observation word to be followed by a label\n");

    if (uniform && random) quit(-1, "%s: cannot initialize parameters both uniformly and randomly\n", rname);

    if (!uniform && !random) {
        fprintf(stderr, "Initial transition parameter estimates will be taken from model file\n");
    }
    if (random) {
        /* The seed itself is what gets passed on as the 'random' argument
           (see the assignment below). A seed of 0 would therefore hand
           iterate_bw exactly the same arguments as a run without -random,
           contradicting the message printed here, so reject it up front. */
        if (seed == 0) quit(-1, "%s: random seed must be nonzero (a seed of 0 cannot be told apart from not specifying -random)\n", rname);
        fprintf(stderr, "Random transition parameter estimates will be selected using seed %d\n", seed);
        random = seed;
    }
    if (uniform) {
        fprintf(stderr, "Transition parameter estimates will be initialized with uniform probabilities.\n");
    }

    if (trans_only && emissions_dir != NULL) quit(-1, "%s: do not specify an emissions directory if only transitions are being reestimated\n", rname);
    if (trans_only) fprintf(stderr, "Will only reestimate transition parameters -- emission values will be held constant.\n");
    if (emissions_dir != NULL) fprintf(stderr, "Will print reestimated emission distributions to directory %s\n", emissions_dir);

    /* Read in model */
    fprintf(stderr, "Reading in HMM...\n");
    hmm = read_model_from_file(model_file_path, &num_states);

    /* Open observation file (stdin when none was given) */
    if (data_file_path != NULL) string_file = kopen_r(data_file_path);
    else string_file = stdin;

    /* Open output file (stdout when none was given) */
    if (output_file_path != NULL) output_file = kopen_w(output_file_path);
    else output_file = stdout;

    /* Read vocab file */
    vocab_size = read_vocab(vocab_file_path, closed_vocab);
    if (vocab_size < 0) quit(-1, "%s: no vocab words!\n",rname);
    else fprintf(stderr, "Read in fixed vocab of %d words\n", vocab_size+1);

    /* Read in training data string observations */
    strings = read_strings(string_file, &num_strings, read_count, read_label, read_id);
    fprintf(stderr, "Read in strings\n");
    fclose(string_file);

    /* Iterate Baum-Welch training until convergence */
    fprintf(stderr, "Iterating BW\n");
    new_model = iterate_bw(strings, num_strings, hmm, num_states, details, punc_trans, uniform, random, vocab_size, trans_only, emissions_dir, output_file);
    fprintf(stderr, "Exited BW\n");

    /* If there is an emissions dir, model and emission distributions will be printed right after
       bw is finished. Otherwise, print the model now. */
    if (emissions_dir == NULL) {

        /* Create state array */
        states = collect_states(new_model, &num_states);

        /* Print new model to file */
        mode = 1;
        smooth = 0;
        iteration = 0;
        print_model_to_file(output_file, NULL, vocab_file_path, states, num_states, emissions_dir, smooth, mode, iteration);
    }

    /* Buffered output is flushed by fclose, so a failure here means the
       written model may be incomplete; report it instead of claiming success. */
    if (fclose(output_file) != 0) {
        fprintf(stderr, "%s: error closing output file\n", rname);
        return 1;
    }
    fprintf(stderr, "DONE!\n");
    return 0;
}


