/* smooth.c - routines for smoothing unigram distributions */

#include "general.h"


/*
 * maximum_likelihood - build an unsmoothed probability table for O.
 *
 * For every word index 0..vocab_size (inclusive) the value returned by
 * calc_lprob(O, index, mode) is exponentiated; values below -98.0 are
 * stored as exactly 0.0 instead.  *ptr_zeroton_lprob is set to no_prob.
 *
 * Returns a kmalloc'd array of vocab_size+1 doubles, which is not freed here.
 */
double *maximum_likelihood(multinomial *O, int vocab_size, int mode, double *ptr_zeroton_lprob) {
    double *table;
    int w;

    if (DEBUG) fprintf(stderr, "Calculating maximum likelihood estimates...\n\n");

    table = (double *) kmalloc((vocab_size+1)*sizeof(double));

    for (w = 0; w <= vocab_size; w++) {
        double lp = calc_lprob(O, w, mode);

        /* Anything below the -98.0 floor is treated as "no probability" */
        table[w] = (lp < -98.0) ? 0.0 : exp(lp);
    }

    *ptr_zeroton_lprob = no_prob;
    return table;
}


/*
 * absolute_discounting - smooth a count distribution by absolute discounting,
 * with the discounted mass redistributed through a uniform distribution.
 *
 * Counts for word indices 0..vocab_size (inclusive) are read with
 * get_count(O, i, mode).  The discount is fof[1] / (fof[1] + 2 * fof[2]),
 * where fof[1] is the number of words with 0 < count < 2 and fof[2] the
 * number with 2 <= count < 3; it is capped at 0.9.  When fof[1] is zero a
 * fixed discount factor of 0.01 is used instead.
 *
 * Every word receives discount_factor * uniform_prob; a word whose count
 * exceeds the discount additionally receives (count - discount) / num_tokens.
 *
 * The log of the probability given to words whose count does not exceed the
 * discount is stored in *ptr_zeroton_lprob.  Returns a kmalloc'd array of
 * vocab_size+1 probabilities, which is not freed here.
 *
 * NOTE(review): if no word has a positive count, num_types and num_tokens are
 * both zero and the discount becomes 0/0; that case is not handled here.
 */
double *absolute_discounting(multinomial *O, int vocab_size, int mode, double *ptr_zeroton_lprob) {
    static char rname[]="absolute_discounting";
    double discount, discount_factor, total_prob, uniform_prob;
    double zeroton_prob, zeroton_lprob, *prob, count;
    /* Accumulated as a double: the counts are doubles, and adding them into
       an int would truncate the running total on every addition. */
    double num_tokens;
    int i, num_types, fof[3];

    fprintf(stderr, "%s: multinomial label = %s\n", rname, O->label);

    /* Absolute discounting (version where smoothed with a uniform distribution) */
    uniform_prob = 1.0 / (vocab_size + 1);

    /* Initialize freq-of-freq values */
    for (i = 0; i <= 2; i++) {
        fof[i] = 0;
    }
    num_tokens = 0.0;
    num_types = 0;

    /* Count freq-of-freq stats */
    for (i = 0; i <= vocab_size; i++) {
        count = get_count(O, i, mode);
        if (count > 0) {
            num_types++;

            /* This is an approximate freq-of-freq count: counts are bucketed
               by range rather than compared for equality */
            if (count < 2)  fof[1]++;
            else if (count < 3)  fof[2]++;
            num_tokens += count;
        }
    }
    fof[0] = (vocab_size+1) - num_types;

    if (fof[0] == 0) 
        fprintf(stderr, "%s: No zerotons... \n", rname);

    if (fof[1] == 0) {
        fprintf(stderr, "%s: No singletons... using fixed discount of 0.01\n", rname);
        discount_factor = 0.01;
        discount = (discount_factor * num_tokens) / num_types;
    }
    else {
        discount =  ((double) fof[1]) / (fof[1] + 2 * fof[2]);
        discount_factor = discount * num_types / num_tokens;
    }

    fprintf(stderr, "%s: discount = %f\n", rname, discount);

    /* Never subtract more than 0.9 from a count */
    if (discount > 0.9) {
        discount = 0.9;
        discount_factor = discount * num_types / num_tokens;
        fprintf(stderr, "%s: warning - Absolute discount too high, limit of 0.9 enforced... distribution = %s\n", rname, O->label);
    }
    fprintf(stderr, "%s: Subtracting %f from all counts (%d / %d + 2 * %d)\n", rname, discount, fof[1], fof[1], fof[2]);
    fprintf(stderr, "%s: discount = %f, discount factor = %f\n", rname, discount, discount_factor);

    zeroton_prob = discount_factor * uniform_prob;
    zeroton_lprob = log(zeroton_prob);

    prob = (double *) kmalloc((vocab_size+1)*sizeof(double));
    total_prob = 0.0;
    for (i = 0; i <= vocab_size; i++) {
        count = get_count(O, i, mode);

        if (count > discount) {
            prob[i] = ((count - discount) / num_tokens) + discount_factor * uniform_prob;
        }
        else {
            prob[i] = zeroton_prob;
        }
        total_prob += prob[i];
    }

    if (DEBUG) 
        for (i = 0; i <= vocab_size; i++) {
            fprintf(stderr, "word = %d, prob = %f\n", i, prob[i]);
        }

    /* Only warn: the distribution is returned even when it does not sum to 1 */
    if (fabs(1.0 - total_prob) > 0.0001) fprintf(stderr, "%s: warning - total_prob = %f\n", rname, total_prob);

    fprintf(stderr, "%s: Total prob = %f, num zerotons = %d, vocab_size = %d\n\n", rname, total_prob, fof[0], vocab_size+1);

    *ptr_zeroton_lprob = zeroton_lprob;

    return(prob);
}


/*
 * three_way_linear_interpolation - interpolate three count distributions
 * for O: the observed counts, the prior (alpha) counts, and a uniform
 * distribution.
 *
 * The mixture weights come from loo_linear_interpolation() and the final
 * probabilities from combine_mixture_model(); all temporary count tables
 * and the weights are freed before returning.  Returns whatever
 * combine_mixture_model() returned.
 *
 * NOTE(review): mode and ptr_zeroton_lprob are not used in this routine;
 * in particular *ptr_zeroton_lprob is left untouched - confirm that callers
 * do not rely on it being set.
 */
double *three_way_linear_interpolation(multinomial *O, int vocab_size, int mode, double *ptr_zeroton_lprob) {
    enum { OBSERVED = 0, PRIOR = 1, UNIFORM = 2, NUM_COMPONENTS = 3 };
    double **component_counts, *weights, *result;
    int k, w;

    fprintf(stderr, "Linear interpolation for emission distribution %s ", O->label);
    if (O->prior == NULL) fprintf(stderr, "\n");
    else fprintf(stderr, "(prior %s):\n", O->prior->label);

    /* One count table per mixture component */
    component_counts = (double **) kmalloc(NUM_COMPONENTS * sizeof(double *));
    for (k = 0; k < NUM_COMPONENTS; k++) {
        component_counts[k] = (double *) kmalloc((vocab_size+1) * sizeof(double));
    }

    for (w = 0; w <= vocab_size; w++) {
        /* Observed counts from the multinomial itself */
        component_counts[OBSERVED][w] = retrieve_count(O, w);

        /* Prior (alpha) counts */
        component_counts[PRIOR][w] = calc_emis_alpha(O, w, O->prior_weight_adjustment);

        /* Uniform component: one count per word */
        component_counts[UNIFORM][w] = 1.0;
    }

    weights = loo_linear_interpolation(component_counts, NUM_COMPONENTS, vocab_size);

    fprintf(stderr, "Final interpolation weights: ");
    for (k = 0; k < NUM_COMPONENTS; k++) {
        fprintf(stderr, "model %d = %.6f, ", k, weights[k]);
    }
    fprintf(stderr, "\n");

    result = combine_mixture_model(component_counts, NUM_COMPONENTS, vocab_size, weights);

    /* Release everything except the returned probability vector */
    free(weights);
    for (k = 0; k < NUM_COMPONENTS; k++) {
        free(component_counts[k]);
    }
    free(component_counts);

    return result;
}


double *loo_linear_interpolation(double **counts, int num_dists, int vocab_size) {
    static char rname[]="loo_linear_interpolation";
    int model, word, iter_num;
    double new_pp, old_pp, stop_ratio, total_prob, sum_logprob;
    double *lambda, *fractions, *prob_component, prob, *total_count;
    double num_parent_counts, numerator_count, denominator_count, num_counts;

    /* We assume that counts contains num_dists count distributions.
	These count distributions will be interpolated together with
	mixture weights as determined by leave-one-out expectation-maximization
	over the counts in the first distribution (counts[0]).
	The mixture weights are returned by the routine. */

    stop_ratio = 0.999;

    /* Count number of tokens in each count distribution */
    total_count = (double *) kmalloc(num_dists*sizeof(double));
    for (model = 0; model < num_dists; model++) {
        total_count[model] = 0;
        for (word = 0; word <= vocab_size; word++) {
            total_count[model] += counts[model][word];
        }
    }
    if (total_count[0] == 0) quit(-1, "%s: error - leaf distribution has no counts!\n", rname);

    /* Initialize the weights (lambdas) */
    lambda = (double *) kmalloc(num_dists*sizeof(double));
    for (model = 0; model < num_dists; model++) {
        lambda[model] = 1.0 / num_dists;
    }

    /* Return if leaf only contains one count */ 
    if (fabs(1.0 - total_count[0]) < 0.0001)  {
        free(total_count);
        return(lambda);
    }

    /* Initialize array to hold partial probabilities */
    prob_component = (double *) kmalloc(num_dists*sizeof(double));

    /* Initialize array to hold fractional counts */
    fractions = (double *) kmalloc(num_dists*sizeof(double));

    /* TRAINING: iterate EM */
    new_pp = 10e98;
    iter_num = 1;
    while (iter_num == 1 || (new_pp/old_pp < stop_ratio)) {
        old_pp = new_pp;

        /* M-step:  */
        /* Re-estimate lambdas before all but the first iteration */
        if (iter_num > 1) {
            for (model = 0; model < num_dists; model++) {
                lambda[model] = fractions[model] / num_counts;
            }
        }

        /* E-step: */
        sum_logprob = 0;
        num_counts = total_count[0];

        /* Reset fractional counts */
        for (model = 0; model < num_dists; model++) {
            fractions[model] = 0;
        }

        for (word = 0; word <= vocab_size; word++) {

            if (counts[0][word] < 1) continue;

            if (counts[0][word] <= 1) {

                /* If there are no counts in the parent distributions, then there's
			no data for LOO estimation */
                num_parent_counts = 0;
                for (model = 1; model < num_dists; model++) {
                    num_parent_counts += counts[model][word];
                }
                if (num_parent_counts == 0) {
                    num_counts -= 1;
                    continue;
                }
            }

            /* For each model, estimate the probability of this word.
		The leaf distribution is estimated using LOO */

            total_prob = 0;
            for (model = 0; model < num_dists; model++) {

                numerator_count = counts[model][word];
                if (model == 0) numerator_count--;

                denominator_count = total_count[model];
                if (model == 0) denominator_count--;

                if (denominator_count == 0) prob = 0;
                else prob = numerator_count / denominator_count;

                prob_component[model] = lambda[model] * prob;
                total_prob += prob_component[model];
            }

            for (model = 0; model < num_dists; model++) {
                fractions[model] += (counts[0][word] * prob_component[model] / total_prob);
            }

            if (DEBUG) fprintf(stderr, "total prob = %f\n", total_prob);

            sum_logprob += (counts[0][word] * log(total_prob));
            if (DEBUG) fprintf(stderr, "sum logprob = %f\n", sum_logprob);
        }

        fprintf(stderr, "Lambdas: ");
        for (model = 0; model < num_dists; model++) {
            fprintf(stderr, "%.6f ", lambda[model]);
        }
        if (num_counts == 0) break;

        new_pp = exp(-sum_logprob/num_counts);
        fprintf(stderr, "\tPP = %.3f (%f items)\n", new_pp, num_counts);
        iter_num++;
    }

    free(total_count);
    free(prob_component);
    free(fractions);

    return(lambda);
}


/*
 * combine_mixture_model - combine num_dists count distributions into a
 * single probability distribution using the mixture weights in lambda.
 *
 * For each word w in 0..vocab_size (inclusive):
 *     prob[w] = sum over models m of lambda[m] * counts[m][w] / total_count[m]
 * where total_count[m] is the sum of counts[m][*].
 *
 * quit(-1, ...) is called when counts[0] sums to zero, or when the resulting
 * probabilities differ from a total of 1.0 by more than 0.00001.
 *
 * Returns a kmalloc'd array of vocab_size+1 probabilities, which is not
 * freed here.
 *
 * NOTE(review): only model 0 is checked for a zero total; any other model
 * whose counts sum to zero makes the division below 0/0, whereas
 * loo_linear_interpolation() treats that case as probability 0.
 */
double *combine_mixture_model(double **counts, int num_dists, int vocab_size, double *lambda) {
    static char rname[]="combine_mixture_model";
    int model, word;
    double temp_prob, *prob, *total_count, total_prob;

    /* Count number of tokens in each count distribution */
    total_count = (double *) kmalloc(num_dists*sizeof(double));
    for (model = 0; model < num_dists; model++) {
        total_count[model] = 0;
        for (word = 0; word <= vocab_size; word++) {
            total_count[model] += counts[model][word];
        }
    }
    if (total_count[0] == 0) quit(-1, "%s: error - leaf distribution has no counts!\n", rname);

    if (DEBUG) 
        for (model = 0; model < num_dists; model++) {
            fprintf(stderr, "total count for model %d = %f\n", model, total_count[model]);
        }

    /* Initialize prob vector */
    prob = (double *) kmalloc((vocab_size+1)*sizeof(double));

    /* Weighted sum of each model's relative frequency for every word */
    total_prob = 0.0;
    for (word = 0; word <= vocab_size; word++) {
        temp_prob = 0.0;
        for (model = 0; model < num_dists; model++) {
            temp_prob += lambda[model] * counts[model][word] / total_count[model];
            if (DEBUG) fprintf(stderr, "model = %d, word = %d, prob = %f (%f * %f / %f)\n", model, word, temp_prob, lambda[model], counts[model][word], total_count[model]);
        }
        if (DEBUG) fprintf(stderr, "word = %d, temp prob = %f\n", word, temp_prob);
        prob[word] = temp_prob;
        total_prob += temp_prob;
    }

    /* The format string expects the routine name first, then the total */
    if (fabs(total_prob - 1.0) > 0.00001) quit(-1, "%s: error - total prob of interpolated distributions equals %f, not 1.0\n", rname, total_prob);

    free(total_count);
    return(prob);
}


