/* lm.c */

#include "general.h" 

multinomial *read_arpa_1gram(char *dist_file_path, int closed_vocab) {
    static char rname[]="read_arpa_1gram";
    FILE *dist_file, *vocab_file;
    multinomial *dist;
    tp *temp, *prev, *head;
    int num_1gs, i, word_id, max_id, num_words, num_read, found_zeroton_prob, num_zeroton_words;
    char *word, *line, *temp1, *temp2, *temp3, *temp4, *temp5, *vocab_file_path;
    double logprob_ten, logprob_e, zeroton_l10_prob, zeroton_lprob, zeroton_prob, total_prob;
    double  log_ten_of_e = 1.0 / log(10.0);
    static int line_length = 1024;
    static word_length = 100;

    dist = (multinomial *) kmalloc(sizeof(multinomial));
    line = (char *) kmalloc(line_length*sizeof(char));
    temp1 = (char *) kmalloc(word_length*sizeof(char));
    temp2 = (char *) kmalloc(word_length*sizeof(char));
    temp3 = (char *) kmalloc(word_length*sizeof(char));
    temp4 = (char *) kmalloc(word_length*sizeof(char));
    temp5 = (char *) kmalloc(word_length*sizeof(char));
    word = (char *) kmalloc(word_length*sizeof(char));
    head = NULL;
    prev = NULL;
    temp = NULL;
    found_zeroton_prob = 0;
    num_zeroton_words = 0;
    vocab_file_path = NULL;

    dist_file = kopen_r(dist_file_path);

    if (DEBUG) fprintf(stderr, "Opened distribution file...\n");

    /* Read in file, look for \data\ marker */
    fgets(line,line_length,dist_file);
    while (strncmp(line, "\\data\\",6)) {

        if ((sscanf(line,"%s %s %s %s %s", temp1, temp2, temp3, temp4, temp5) == 5) &&
            !strcmp("#",temp1) &&
            !strcmp("Corresponding",temp2) &&
            !strcmp("vocab",temp3) &&
            !strcmp("=",temp4)) {

            vocab_file_path = strdup(temp5);
            if (DEBUG) fprintf(stderr, "Read in vocab path file of %s\n", vocab_file_path);
        }

        if ((sscanf(line,"%s %s %s %s %lf",temp1, temp2, temp3, temp4, &zeroton_l10_prob) == 5) &&
            !strcmp("#",temp1) &&
            !strcmp("zeroton",temp2) &&
            !strcmp("log10_prob",temp3) &&
            !strcmp("=",temp4)) {
                found_zeroton_prob = 1;
                zeroton_lprob = zeroton_l10_prob / log_ten_of_e;
                zeroton_prob = exp(zeroton_lprob);
                if (DEBUG) fprintf(stderr, "Read in zeroton log10 prob of %f, log prob = %f, prob = %.10f\n", zeroton_l10_prob, zeroton_lprob, zeroton_prob);
        }

        if (feof(dist_file))
            quit(-1, "%s: reached end of lm file without finding \\data\\ marker...\n", rname);
        fgets(line,line_length,dist_file);
    }

    /* Read in number of n-gram types */
    i = 0;
    fgets(line,line_length,dist_file);
    while (strncmp("\\1-grams",line,8)) {
        if (sscanf(line,"%s %s",temp1,temp2) == 2) {
            if (!strcmp("ngram",temp1)) {
                i = temp2[0]-48;
                if (i == 1) {
                    num_1gs = atoi(&(temp2[2]));
                }
            }
        }
        if (feof(dist_file))
            quit(-1, "%s: reached end of lm file before \\1-grams marker...\n",rname);
        fgets(line,line_length,dist_file);
    }
    free(temp1);
    free(temp2);
    free(temp3);
    free(temp4);
    free(temp5);

    if (i == 0)
        quit(-1, "%s: parsing error in arpa lm %s\n", rname, dist_file_path);

    fprintf(stderr, "Reading %d unigrams - printing one '.' for each 10k words read\n", num_1gs);

    if (vocab_file_path != NULL) max_id = read_vocab(vocab_file_path, closed_vocab);
    else max_id = -1;

    num_read = 0;
    fgets(line,line_length,dist_file);
    if (line == NULL)
        quit(-1, "%s: no line was read from %s\n",rname, dist_file_path);

    /* Read in words and probabilities to a temporary linked list */
    while (line != NULL && strncmp("\\end\\",line,5)) {
        if ((sscanf(line,"%lf %s", &logprob_ten, word)) == 2) {

            num_read++;
            if ((num_read % 10000) == 0) { 
                fprintf(stderr, ".");
            }

            logprob_e = logprob_ten/log_ten_of_e;

            word_id = bow_word2int(word);

            if (word_id > max_id) max_id = word_id;

            temp = (tp *) kmalloc(sizeof(tp));
            temp->id = word_id;
            temp->lprob = logprob_e;
            temp->next = NULL;

            if (head == NULL) head = temp;
            if (prev != NULL) prev->next = temp;
            prev = temp;

        }
        else {
            if (strlen(line)>1) {
                fprintf(stderr,"WARNING: reading line %s gave unexpected input\n",line);
            }
        }
        if (feof(dist_file)) {   
            if (DEBUG) fprintf(stderr, "Reached EOF for %s\n", dist_file_path);
            break;
        }
        fgets(line,line_length,dist_file);
    }
    fclose(dist_file);

    if (num_read == 0) quit(-1, "%s: no word read - exiting...\n", rname);
    else fprintf(stderr, "\n");

    /* Make sure unk_word is included */
    if (bow_word2int(unk_word) > max_id) max_id = bow_word2int(unk_word);

    /* Once all words are read in, create actual distribution structure */
    dist->lprobs = (float *) kmalloc((max_id+1)*sizeof(float));
    for (i = 0; i <= max_id; i++) {
        dist->lprobs[i] = no_prob;
    }

    num_words = 0;
    total_prob = 0.0;
    temp = head;
    while (temp != NULL) {
       dist->lprobs[temp->id] = temp->lprob;
       total_prob += (exp(temp->lprob));
       num_words++;
       prev = temp;
       temp = temp->next;
       free(prev);
    }
    if (DEBUG) fprintf(stderr, "Total prob mass from arpa file is %f\n", total_prob);

    /* If the zeroton log prob is listed in the lm file, load zeroton words now: */
    if (found_zeroton_prob && vocab_file_path != NULL) {

        if (DEBUG) fprintf(stderr, "Now reading in vocab file to assign zeroton probs\n");

        total_prob = 0.0;
        vocab_file = kopen_r(vocab_file_path);
        while (fgets(line, 1000, vocab_file) != NULL) {
            word = strtok(line, " \n");
            if (word[0] == '#') continue;

            while (word != NULL) {
                word_id = bow_word2int(word);

                if (dist->lprobs[word_id] > 0) { /* hence, no prob has been assigned */
                    dist->lprobs[word_id] = zeroton_lprob;
                    total_prob += zeroton_prob;
                    num_zeroton_words++;
                }

                word = strtok(NULL, " \n");
            }
        }
        fclose(vocab_file);
        if (DEBUG) fprintf(stderr, "Assigned zeroton prob to %d words\n", num_zeroton_words);
        if (DEBUG) fprintf(stderr, "Total zeroton prob mass is %f\n", total_prob);
    }

    /* Calculate the total prob value */
    total_prob = 0.0;
    for (i = 0; i <= max_id; i++) {
        if (dist->lprobs[i] <= 0) 
            total_prob += exp(dist->lprobs[i]);
    }
    if (DEBUG) fprintf(stderr, "total probability for this distribution is %f\n", total_prob);

    if (total_prob < 1.0 && found_zeroton_prob) {

        /* Check for prob of unk_word  - if unassigned, check to see if total prob 
           is missing one zeroton prob value. If so, then assign the unk_word the 
           zeroton prob value. Otherwise leave the unk_word prob as zero */

        if (dist->lprobs[bow_word2int(unk_word)] > 0) {
            dist->lprobs[bow_word2int(unk_word)] = zeroton_lprob;
            total_prob += zeroton_prob;
            if (DEBUG) fprintf(stderr, "assigned the zeroton prob value to the unk word: total prob = %f\n", total_prob);
        }
    }
    if (fabs(1.0 - total_prob) > 0.0001) {
        fprintf(stderr, "%s: warning - total probability for this distribution is %f\n", rname, total_prob);
    }


    dist->label = NULL;
    dist->path = NULL;
    dist->num_types = num_words;
    dist->vocab_size = max_id;
    dist->num_tokens = -1; /* For arpa LMs, since we do not have count information */
    dist->prior = NULL;
    dist->counts = NULL;
    dist->uni_alpha = 0;
    dist->total_count = 0;
    dist->prior_weight_adjustment = 1.0;

    if (num_words != num_1gs) {
        fprintf(stderr, "WARNING: num_words (%d) != num_1gs (%d)\n", num_words, num_1gs);
    }

    if (vocab_file_path != NULL) free(vocab_file_path);
    free(line);
    free(word);

    return(dist);
}

/*
 * load_dists - load every distribution listed in a list file.
 *
 * Each non-blank line of dist_file_path has the form "LABEL PATH"; that
 * distribution is built with
 * create_distribution(LABEL, PATH, NULL, 0, NULL, 0, 0, 0).
 * Lines that do not hold two fields are skipped, with a warning unless the
 * line is blank.
 *
 * *p_num_dists receives the number of distributions actually loaded, and the
 * returned array holds that many entries.
 */
multinomial **load_dists(char *dist_file_path, int *p_num_dists) {
    static char rname[]="load_dists";
    FILE *dist_file;
    int num_lines, num_dists, num_fields, i;
    char buffer[10000], dist_label[100], dist_path[500];
    multinomial **dists;

    dist_file = kopen_r(dist_file_path);

    /* First pass: the line count is an upper bound on the number of entries */
    num_lines = 0;
    while (fgets(buffer, sizeof buffer, dist_file)) {
        num_lines++;
    }
    if (DEBUG) fprintf(stderr, "%d lines read from distribution file %s\n", num_lines, dist_file_path);

    /* Initialize distribution data structure (at least one slot, so an empty
       list never asks for a zero-byte block) */
    dists = (multinomial **) kmalloc((num_lines > 0 ? num_lines : 1) * sizeof(multinomial *));
    for (i = 0; i < num_lines; i++) {
        dists[i] = NULL;
    }

    /* Second pass: parse one "LABEL PATH" pair per line and load it */
    if (DEBUG) fprintf(stderr, "Reading distribution path names...\n");
    rewind(dist_file);
    num_dists = 0;
    while (num_dists < num_lines && fgets(buffer, sizeof buffer, dist_file)) {
        /* Widths are sizeof(dist_label) - 1 and sizeof(dist_path) - 1 */
        num_fields = sscanf(buffer, "%99s %499s", dist_label, dist_path);
        if (num_fields != 2) {
            if (num_fields == 1)
                fprintf(stderr, "%s: warning - skipping malformed line: %s", rname, buffer);
            continue;
        }
        if (DEBUG) fprintf(stderr, "label = %s, file = %s\n", dist_label, dist_path);
        dists[num_dists++] = create_distribution(dist_label, dist_path, NULL, 0, NULL, 0, 0, 0);
    }
    fclose(dist_file);

    *p_num_dists = num_dists;
    return(dists);
}

/*
 * load_prior_counts - load every prior count file listed in a list file.
 *
 * Each non-blank line of dist_file_path has the form "LABEL PATH"; the counts
 * are read with read_counts(PATH, LABEL, vocab_size).  Lines that do not hold
 * two fields are skipped, with a warning unless the line is blank.
 *
 * *p_num_dists receives the number of priors actually loaded, and the
 * returned array holds that many entries.
 */
count_dist **load_prior_counts(char *dist_file_path, int *p_num_dists, int vocab_size) {
    static char rname[]="load_prior_counts";
    FILE *dist_file;
    int num_lines, num_dists, num_fields, i;
    char buffer[10000], dist_label[100], dist_path[500];
    count_dist **priors;

    dist_file = kopen_r(dist_file_path);

    /* First pass: the line count is an upper bound on the number of entries */
    num_lines = 0;
    while (fgets(buffer, sizeof buffer, dist_file)) num_lines++;
    if (DEBUG) fprintf(stderr, "%d lines read from distribution file %s\n", num_lines, dist_file_path);

    /* Initialize distribution data structure (at least one slot, so an empty
       list never asks for a zero-byte block) */
    priors = (count_dist **) kmalloc((num_lines > 0 ? num_lines : 1) * sizeof(count_dist *));
    for (i = 0; i < num_lines; i++) {
        priors[i] = NULL;
    }

    fprintf(stderr, "Reading in priors...\n");

    /* Second pass: parse one "LABEL PATH" pair per line and load it */
    rewind(dist_file);
    num_dists = 0;
    while (num_dists < num_lines && fgets(buffer, sizeof buffer, dist_file)) {
        /* Widths are sizeof(dist_label) - 1 and sizeof(dist_path) - 1 */
        num_fields = sscanf(buffer, "%99s %499s", dist_label, dist_path);
        if (num_fields != 2) {
            if (num_fields == 1)
                fprintf(stderr, "%s: warning - skipping malformed line: %s", rname, buffer);
            continue;
        }
        if (DEBUG) fprintf(stderr, "label = %s, file = %s\n", dist_label, dist_path);

        priors[num_dists++] = read_counts(dist_path, dist_label, vocab_size);
    }
    fclose(dist_file);

    *p_num_dists = num_dists;

    return(priors);
}

/*
 * read_counts - read a "WORD COUNT" count file into a new count_dist.
 *
 * dist_path:  count file with one "WORD COUNT" pair per line.  Blank lines
 *             and lines whose first token starts with '#' are ignored; a
 *             line with a word but no integer count is skipped with a warning.
 * dist_label: label copied (strdup) into the result.
 * vocab_size: alphas gets vocab_size + 1 entries, all initialised to 0.
 *
 * Integer counts are accumulated as doubles in alphas[word_id].  A word that
 * bow_word2int_no_add() does not know (-1), or whose id is beyond vocab_size,
 * is credited to unk_word when unk_word has a usable id, and dropped
 * otherwise.  num_types counts accepted lines; num_tokens sums their counts.
 */
count_dist *read_counts(char *dist_path, char *dist_label, int vocab_size) {
    static char rname[]="read_counts";
    FILE *count_file;
    int i, count, word_id, num_fields;
    char line[1024], word[200];
    count_dist *prior;

    prior = (count_dist *) kmalloc(sizeof(count_dist));
    prior->label = strdup(dist_label);
    prior->path = strdup(dist_path);
    prior->num_types = 0;
    prior->num_tokens = 0;
    prior->vocab_size = vocab_size;
    prior->log_beta = 0;

    prior->alphas = (double *) kmalloc((vocab_size+1)*sizeof(double));
    for (i = 0; i <= vocab_size; i++) {
        prior->alphas[i] = 0;
    }

    count_file = kopen_r(dist_path);

    /* Parse one line at a time and only use fields sscanf actually filled,
       so comment text is never taken as a word/count pair and no value from
       a previous line is reused */
    while (fgets(line, sizeof line, count_file) != NULL) {
        num_fields = sscanf(line, "%199s %d", word, &count);   /* 199 = sizeof word - 1 */
        if (num_fields < 1 || word[0] == '#') continue;        /* blank or comment */
        if (num_fields != 2) {
            fprintf(stderr, "%s: warning - no count on line in %s: %s", rname, dist_path, line);
            continue;
        }

        word_id = bow_word2int_no_add(word);

        /* Map OOV (or out-of-range) words to the unknown word, if the unknown
           word is in vocab; alphas only has room for ids 0..vocab_size */
        if (word_id < 0 || word_id > vocab_size) word_id = bow_word2int_no_add(unk_word);
        if (word_id < 0 || word_id > vocab_size) continue;

        prior->alphas[word_id] += (double) count;
        prior->num_types++;
        prior->num_tokens += count;
    }
    fclose(count_file);

    return(prior);
}

/*
 * calculate_log_beta_factors - precompute priors[i]->log_beta for each of the
 * num_dists priors.
 *
 * For every prior the parameters passed to log_beta() are
 * obs_uni_alpha + alphas[j] for j = 0..vocab_size.  When dir_prior_weight > 0
 * they are first divided by (their sum / dir_prior_weight), i.e. rescaled so
 * they sum to dir_prior_weight; otherwise they are used unscaled.
 *
 * This should only be called with broad emission priors.
 */
void calculate_log_beta_factors(count_dist **priors, int num_dists, double obs_uni_alpha, int vocab_size, double dir_prior_weight) {
    int i, j;
    double *new_alphas, total_prior_weight, prior_weight_adjustment;

    /* Scratch buffer reused for every prior; released before returning */
    new_alphas = (double *) kmalloc((vocab_size+1)*sizeof(double));

    /* Precompute the log beta value for each prior distribution */
    for (i = 0; i < num_dists; i++) {

        prior_weight_adjustment = 1.0;
        if (dir_prior_weight > 0) {
            total_prior_weight = 0;
            for (j = 0; j <= vocab_size; j++) {
                total_prior_weight += (obs_uni_alpha + priors[i]->alphas[j]);
            }
            prior_weight_adjustment = total_prior_weight / dir_prior_weight;
        }

        for (j = 0; j <= vocab_size; j++) {
            new_alphas[j] = (obs_uni_alpha + priors[i]->alphas[j]) / prior_weight_adjustment;
        }
        priors[i]->log_beta = log_beta(new_alphas, vocab_size+1);

        if (DEBUG) fprintf(stderr, "Log beta value for prior %s = %f\n", priors[i]->label, priors[i]->log_beta);
    }

    free(new_alphas);
}

/*
 * matching_prior - look up a prior by label.
 *
 * Walks priors[0 .. num_prior_dists-1] in order and returns the first entry
 * whose label is strcmp-equal to label, or NULL when none matches.
 */
count_dist *matching_prior(char *label, count_dist **priors, int num_prior_dists) {
    static char rname[]="matching_prior";
    int idx = 0;

    /* Advance until a label matches or the list is exhausted */
    while (idx < num_prior_dists && strcmp(priors[idx]->label, label) != 0)
        idx++;

    return (idx < num_prior_dists) ? priors[idx] : NULL;
}

/*
 * get_max_id - return the largest known word id among the words of a
 * "WORD COUNT" count file.
 *
 * Lines are parsed the same way read_counts() parses them: blank lines,
 * lines whose first token starts with '#', and lines without an integer
 * count are ignored.  Ids come from bow_word2int_no_add(), so unknown words
 * (-1) never raise the maximum.  Returns 0 if no word has a larger id.
 */
int get_max_id(char *dist_path) {
    static char rname[]="get_max_id";
    FILE *count_file;
    int word_id, max_id, count, num_fields;
    char line[1024], word[200];

    count_file = kopen_r(dist_path);

    max_id = 0;
    /* Read line by line and only use words sscanf actually parsed */
    while (fgets(line, sizeof line, count_file) != NULL) {
        num_fields = sscanf(line, "%199s %d", word, &count);   /* 199 = sizeof word - 1 */
        if (num_fields != 2 || word[0] == '#') continue;

        word_id = bow_word2int_no_add(word);
        if (word_id > max_id) max_id = word_id;
    }
    fclose(count_file);

    return(max_id);
}

/*
 * print_arpa_unigram - write p[0..vocab_size] to outfile as a complete ARPA
 * unigram model.
 *
 * Every word id is listed, so the header declares vocab_size + 1 unigrams.
 * Each entry is "LOG10PROB WORD" with 4 decimals; a probability of exactly 0
 * is written as -99.0.  After the \end\ marker, quit() is called if the
 * nonzero probabilities do not sum to 1 within 1e-6.
 */
void print_arpa_unigram(FILE *outfile, float *p, int vocab_size) {
    static char rname[] = "print_arpa_unigram";
    const double ln10 = log(10.0);
    double mass = 0.0;
    double lp10;
    char *w;
    int id;

    fprintf(outfile, "\\data\\\n");
    fprintf(outfile, "ngram 1=%d\n\n", vocab_size+1);
    fprintf(outfile, "\\1-grams:\n");

    for (id = 0; id <= vocab_size; id++) {
        w = bow_int2word(id);

        if (p[id] != 0) {
            lp10 = log(p[id]) / ln10;
            mass += p[id];
        } else {
            lp10 = -99.0;   /* conventional stand-in for log10(0) */
        }

        fprintf(outfile, "%.4f %s\n", lp10, w);
    }
    fprintf(outfile, "\n\\end\\\n");

    if (fabs(mass - 1.0) > 0.000001)
        quit(-1, "%s: output distribution sums to %f  - not 1.0...\n", rname, mass);
}

/*
 * print_abbrev_arpa_unigram - write p[0..vocab_size] as an ARPA unigram
 * model, optionally leaving out words that carry the zeroton probability.
 *
 * zeroton_lprob < 0:  it is the natural-log zeroton probability.  Words
 *     whose log probability is within 1e-4 of it are not written; comment
 *     lines naming vocab_file and the zeroton log10 probability are emitted
 *     before \data\ (the header read_arpa_1gram() in this file looks for).
 *     Entries are written with 7 decimals; zero probabilities as -99.0.
 * zeroton_lprob >= 0: no zeroton handling; only words with p > 0 are
 *     written, with 4 decimals.
 *
 * quit() is called if the written plus omitted probability mass differs
 * from 1 by more than 1e-6.
 */
void print_abbrev_arpa_unigram(FILE *outfile, char *vocab_file, double *p, int vocab_size, double zeroton_lprob) {
    static char rname[] = "print_abbrev_arpa_unigram";
    int i, num_types, use_zeroton;
    char *word;
    double log10_p, total_prob, zeroton_log10prob, zeroton_prob;
    double log_10 = log(10.0);

    /* Decide once so the counting and printing passes agree; a value of
       exactly 0 means "no zeroton" in both */
    use_zeroton = (zeroton_lprob < 0);
    zeroton_prob = use_zeroton ? exp(zeroton_lprob) : 0.0;
    zeroton_log10prob = zeroton_lprob / log_10;
    num_types = 0;

    if (use_zeroton) {
        /* Count words whose prob differs from the zeroton prob; log(0) is
           -inf, so zero-prob words always count as different */
        for (i = 0; i <= vocab_size; i++) {
            if (fabs(log(p[i]) - zeroton_lprob) > 0.0001) num_types++;
        }

        fprintf(outfile, "# Abbreviated arpa unigram format:\n");
        fprintf(outfile, "# Only words with probs different than the zeroton prob are printed.\n\n");
        fprintf(outfile, "# Corresponding vocab = %s\n", vocab_file);
        fprintf(outfile, "# zeroton log10_prob = %f\n\n", zeroton_log10prob);
    }
    else {
        for (i = 0; i <= vocab_size; i++) {
            if (p[i] > 0.0) num_types++;
        }
    }

    fprintf(outfile, "\\data\\\n");
    fprintf(outfile, "ngram 1=%d\n\n", num_types);
    fprintf(outfile, "\\1-grams:\n");

    total_prob = 0.0;
    if (!use_zeroton) { /* no zeroton probs to deal with */
        for (i = 0; i <= vocab_size; i++) {
            if (p[i] > 0.0) {
                word = bow_int2word(i);
                log10_p = log(p[i]) / log_10;
                total_prob += p[i];

                fprintf(outfile, "%.4f %s\n", log10_p, word);
            }
        }
    }
    else {
        for (i = 0; i <= vocab_size; i++) {
            if (fabs(log(p[i]) - zeroton_lprob) > 0.0001) { /* This word has a prob different than the zeroton prob */
                word = bow_int2word(i);
                /* -99 instead of log10(0) = -inf, matching print_arpa_unigram */
                log10_p = (p[i] > 0.0) ? log(p[i]) / log_10 : -99.0;
                total_prob += p[i];

                fprintf(outfile, "%.7f %s\n", log10_p, word);
            }
            else total_prob += zeroton_prob;   /* omitted word still carries mass */
        }
    }
    fprintf(outfile, "\n\\end\\\n");

    if (fabs(total_prob - 1.0) > 0.000001)
        quit(-1, "%s: abbreviated output distribution sums to %.10f  - not 1.0...\n", rname, total_prob);

    if (DEBUG) fprintf(stderr, "printed %d words, %d words were not printed\n", num_types, (vocab_size - num_types+1));
}



