/* evaluate.c
# ---------------------------------------------------------------------------------
# Hidden Markov Model Evaluation code
# ----------------------------------------------------------------------------------

Documentation for this program is in /homes/kseymore/ra/hmm/doc/hmm_tools.txt

*/

#include "general.h"

/* Program version string.
 * NOTE(review): this identifier matches the global that GNU argp reads for
 * --version output, but argp is not visibly used in this file - confirm it is
 * still needed (or referenced elsewhere) before removing or renaming it. */
const char *argp_program_version = "evaluate_1.0";

/* Print the command-line usage summary for `evaluate` to stderr. */
static void evaluate_print_usage(void)
{
    fprintf(stderr,"evaluate : Given an HMM, find the optimal state sequences or probabilities of some observation sequences.\n");
    fprintf(stderr,"Usage: evaluate -hmm <model_file>\n");
    fprintf(stderr,"                -vit | -forward | -max_class\n");
    fprintf(stderr,"              [ -obs <observation_file> ]\n");
    fprintf(stderr,"              [ -out <output_file> ]\n");
    fprintf(stderr,"              [ -print_probs ]\n");
    fprintf(stderr,"              [ -calc_pp ]\n");
    fprintf(stderr,"              [ -trans_weight <weight> ]\n");
    fprintf(stderr,"              [ -print_model ]\n");
    fprintf(stderr,"              [ -print_state_id ]\n");
    fprintf(stderr,"              [ -read_obs_id ]\n");
    fprintf(stderr,"              [ -only_punc_trans ]\n");
    fprintf(stderr,"              [ -read_labels ]\n");
    fprintf(stderr,"              [ -server_mode <port_num>]\n");
    fprintf(stderr,"              [ -vit_details ]\n");
    fprintf(stderr,"              [ -forward_details ]\n\n");
    fprintf(stderr,"Documentation available at /homes/kseymore/ra/hmm/doc/hmm_tools.txt\n\n");
}

/* Print the average log probability per observation and the derived
 * perplexity (PP = exp(-avg logprob)) for the whole sequence, its transition
 * part and its emission part, to stderr.
 *
 * total_obs must be positive; the caller guards against zero so that the
 * averages below never divide by zero. */
static void evaluate_report_perplexity(double total_obs_logprob,
                                       double total_t_logprob,
                                       double total_e_logprob,
                                       int total_obs)
{
    double pf_logprob = total_obs_logprob / total_obs;
    double pt_logprob = total_t_logprob / total_obs;
    double pe_logprob = total_e_logprob / total_obs;

    fprintf(stderr, "TOTAL OBSERVATION LOGPROB: %f (%f / %d) PP = %.2f\n", pf_logprob, total_obs_logprob, total_obs, exp(-pf_logprob));
    fprintf(stderr, "    transition part: %f (%f / %d) PP = %.2f\n", pt_logprob, total_t_logprob, total_obs, exp(-pt_logprob));
    fprintf(stderr, "    emission part: %f (%f / %d) PP = %.2f\n", pe_logprob, total_e_logprob, total_obs, exp(-pe_logprob));
}

/* Entry point for the `evaluate` tool.
 *
 * Parses the command line, reads the HMM given by -hmm, and then either:
 *   - server mode (-server_mode <port>): initialises a socket and serves
 *     requests forever via rainbow_serve(); or
 *   - batch mode: reads observation strings from -obs (or stdin) one at a
 *     time and, for each, runs the requested max-class search, Viterbi
 *     search and/or forward procedure, writing results to -out (or stdout).
 *     With -calc_pp, log-probability totals are accumulated and a
 *     perplexity summary is printed to stderr at the end.
 *
 * Returns 0 on normal completion. Argument errors are reported through
 * quit(), and -help (or no arguments) prints usage and exits with status 1.
 */
int main(int argc, char **argv) {
    static char rname[] = "evaluate";
    FILE        *output_file, *string_file;
    int         i, num_states, vit_details, print_state_id;
    int         string_count, need_help, forward_details, punc_trans;
    int         run_vit, run_forward, print_hmm, obs_with_id, print_probs;
    int         server_mode, max_class, calc_pp, total_obs, num_obs, read_label;
    double      flprob, total_obs_logprob, elprob, tlprob;
    double      total_e_logprob, total_t_logprob;
    char        *model_file_path, *data_file_path, *output_file_path;
    char        *port_num;
    shead       *start;
    state       *hmm, **states;
    path_head   *vit_path, *max_path;
    float       trans_weight;

    num_states = 0;

    elprob = 0.0;
    tlprob = 0.0;

    /* Process command line arguments */

    need_help = read_com_noarg(&argc, argv, "-help");
    if (argc == 1 || need_help) {
        evaluate_print_usage();
        exit(1);
    }

    model_file_path = read_com_string(&argc, argv, "-hmm");
    data_file_path = read_com_string(&argc, argv, "-obs");
    output_file_path = read_com_string(&argc, argv, "-out");
    if (!read_com_float(&argc, argv, "-trans_weight", &trans_weight)) trans_weight = 1.0;
    port_num = read_com_string(&argc, argv, "-server_mode");
    run_vit = read_com_noarg(&argc, argv, "-vit");
    max_class = read_com_noarg(&argc, argv, "-max_class");
    print_probs = read_com_noarg(&argc, argv, "-print_probs");
    run_forward = read_com_noarg(&argc, argv, "-forward");
    calc_pp = read_com_noarg(&argc, argv, "-calc_pp");
    print_hmm = read_com_noarg(&argc, argv, "-print_model");
    print_state_id = read_com_noarg(&argc, argv, "-print_state_id");
    obs_with_id = read_com_noarg(&argc, argv, "-read_obs_id");
    punc_trans = read_com_noarg(&argc, argv, "-only_punc_trans");
    read_label = read_com_noarg(&argc, argv, "-read_labels");
    vit_details = read_com_noarg(&argc, argv, "-vit_details");
    forward_details = read_com_noarg(&argc, argv, "-forward_details");
    check_extra_args(&argc, argv);
    server_mode = (port_num == NULL) ? 0 : 1;

    /* Print out command line arguments and check for consistency */
    if (model_file_path == NULL) quit(-1, "%s: no HMM file specified - exiting...\n", rname);
    else fprintf(stderr, "Model file path = %s\n", model_file_path);

    if (!run_vit && !run_forward && !max_class)
        quit(-1, "%s: either the viterbi search, the forward procedure, or max class must be chosen.\n", rname);
    /* The flag is spelled -vit on the command line; report it that way. */
    if (run_vit && max_class) quit(-1, "%s: -vit and -max_class can not both be specified\n", rname);
    if (run_vit) fprintf(stderr, "Viterbi search chosen\n");
    if (print_state_id && !run_vit) quit(-1, "%s: -print_state_id can only be selected with the -vit option.\n", rname);
    if (print_state_id) fprintf(stderr, "Will print state ids on output\n");
    if (run_forward) fprintf(stderr, "Forward procedure chosen\n");
    if (calc_pp) fprintf(stderr, "Perplexity will be calculated\n");
    /* Perplexity statistics are only collected by the Viterbi and forward
     * passes below; make a useless or mixed configuration visible. */
    if (calc_pp && !run_vit && !run_forward)
        fprintf(stderr, "%s: warning: -calc_pp has no effect without -vit or -forward\n", rname);
    if (calc_pp && run_vit && run_forward)
        fprintf(stderr, "%s: warning: -calc_pp with both -vit and -forward combines statistics from both passes\n", rname);
    if (print_probs) fprintf(stderr, "Transition and emission log probabilities will be printed\n");
    if (max_class) fprintf(stderr, "Maximum class search chosen\n");
    if (print_hmm) fprintf(stderr, "Will print the Hidden Markov model to stderr\n");
    if (obs_with_id) fprintf(stderr, "Expecting integer ids to precede each observation sequence\n");
    if (punc_trans) fprintf(stderr, "Permitting state transitions only after words ending in punctuation [.,:!]\n");
    if (read_label) fprintf(stderr, "Reading labels with observations - will be used to constrain search (forced alignment)\n");
    if (trans_weight < 1.0 || trans_weight > 1.0) fprintf(stderr, "Using a transition weight of %f\n", trans_weight);
    if (server_mode) fprintf(stderr, "Running in server mode using port number %s\n", port_num);

    if (!server_mode) {
        if (data_file_path == NULL) fprintf(stderr, "No observation file specified - using stdin...\n");
        else fprintf(stderr, "Observation file path = %s\n", data_file_path);

        if (output_file_path == NULL) fprintf(stderr, "No output file specified - using stdout\n");
        else fprintf(stderr, "Output file path = %s\n", output_file_path);
    }

    /* Read in model */
    fprintf(stderr, "Reading in HMM...\n");
    hmm = read_model_from_file(model_file_path, &num_states);

    /* Create state array. One extra slot is allocated; initialise every slot
     * (including the last) so no entry is ever read uninitialised. */
    states = (state **) kmalloc((num_states + 1) * sizeof(state *));
    if (states == NULL) quit(-1, "%s: could not allocate state array\n", rname);
    for (i = 0; i <= num_states; i++) {
        states[i] = NULL;
    }
    collect_children_of(hmm, num_states, states);

    if (print_hmm) {
        print_model(states, num_states, 0);
    }

    if (server_mode) {
        rainbow_socket_init(port_num, 0);
        /* Serve requests forever; this branch does not return. */
        while (1) {
            rainbow_serve(hmm, obs_with_id, run_vit, print_probs, run_forward, vit_details, forward_details, punc_trans, max_class, states, num_states, read_label, trans_weight, print_state_id);
        }
    }
    else {

        /* Open observation file */
        if (data_file_path != NULL) string_file = kopen_r(data_file_path);
        else string_file = stdin;

        /* Open output file */
        if (output_file_path != NULL) output_file = kopen_w(output_file_path);
        else output_file = stdout;

        fprintf(stderr, "Processing observations - printing one '.' for each 100 strings processed\n");

        string_count = 0;
        total_obs = 0;
        total_obs_logprob = 0.0;
        total_e_logprob = 0.0;
        total_t_logprob = 0.0;

        /* Read first line from string file */
        start = read_one_string(string_file, obs_with_id, read_label, 0);

        /* Tag strings as long as there are strings in file */
        while (start != NULL) {

            if (max_class) {

                /* Find max class for each observation in string */
                max_path = find_max_path(start, states, num_states, read_label);
                print_path_to_file(max_path, output_file, start, obs_with_id, print_state_id, print_probs);
                free_path(max_path);
            }

            if (run_vit) {

                /* Find Viterbi path for string */
                vit_path = find_vit_path(start, hmm, vit_details, punc_trans, read_label, trans_weight);
                print_path_to_file(vit_path, output_file, start, obs_with_id, print_state_id, print_probs);
                if (calc_pp) collect_pp_stats(vit_path, &total_obs, &total_obs_logprob, &total_t_logprob, &total_e_logprob);
                free_path(vit_path);
                fflush(output_file);
            }

            if (run_forward) {

                /* Calculate forward path probability */
                flprob = calc_forward_prob(start, hmm, forward_details, punc_trans, read_label, trans_weight, &elprob, &tlprob);
                print_forward_prob_to_file(flprob, output_file, start, obs_with_id);

                if (calc_pp) {
                    num_obs = count_num_symbols(start);
                    total_obs += (num_obs - 1);
                    total_obs_logprob += flprob;
                    total_e_logprob += elprob;
                    total_t_logprob += tlprob;
                }
            }

            free_string(start);
            string_count++;

            if ((string_count % 100) == 0) fprintf(stderr, ".");
            if ((string_count % 1000) == 0) fprintf(stderr, "\n");

            /* Read in next line from string file */
            start = read_one_string(string_file, obs_with_id, read_label, 0);

        }
        fprintf(stderr, "\n");

        /* Close only the streams this program opened; flush stdout instead of
         * closing it. Report a failed close of the output file, since that is
         * where buffered write errors surface. */
        if (output_file != stdout) {
            if (fclose(output_file) != 0)
                fprintf(stderr, "%s: warning: error closing output file %s\n", rname, output_file_path);
        }
        else {
            fflush(output_file);
        }
        if (string_file != stdin) fclose(string_file);

        if (calc_pp) {
            /* Guard the per-observation averages against division by zero
             * (empty input, or no pass that collects statistics). */
            if (total_obs > 0)
                evaluate_report_perplexity(total_obs_logprob, total_t_logprob, total_e_logprob, total_obs);
            else
                fprintf(stderr, "%s: no observations counted - perplexity not computed\n", rname);
        }

        fprintf(stderr, "Processed %d observation sequences\n", string_count);
        fprintf(stderr, "DONE!\n");
    }
    return 0;
}

