/* basic.c - Basic functions */

#include "general.h"

/*
 * create_distribution - build a multinomial from one of three sources.
 *
 *   dist_file_path != NULL : the distribution is loaded by read_arpa_1gram()
 *                            (a unigram distribution in arpa format).
 *   dist_word != NULL      : a one-word distribution holding `word_count`
 *                            occurrences of `dist_word`, using `prior_dist`,
 *                            `vocab_size` and `obs_uni_alpha`.
 *   both NULL              : an empty ("null") distribution.
 *
 * Supplying both dist_file_path and dist_word is fatal, and so is a
 * dist_word whose id from bow_word2int_no_add() is -1 or above vocab_size.
 * The result stores strdup'ed copies of `label` and `dist_file_path`;
 * the string "none" is stored in place of a NULL argument.
 */
multinomial *create_distribution(char *label, char *dist_file_path, char *dist_word, int word_count, count_dist *prior_dist, int vocab_size, double obs_uni_alpha, int closed_vocab) {
    static char rname[]="create_distribution";
    multinomial *result;
    int id;

    if (dist_file_path != NULL && dist_word != NULL)
        quit(-1, "%s: in create_distribution, either dist_file_path or dist_word must be NULL\n", rname);

    if (dist_file_path != NULL) {
        /* No vocab size imposed on arpa LMs - use vocab present in model */
        result = read_arpa_1gram(dist_file_path, closed_vocab);
    }
    else {
        result = (multinomial *) kmalloc(sizeof(multinomial));

        /* Values common to the one-word and the null distribution */
        result->lprobs = NULL;
        result->total_count = 0;
        result->prior_weight_adjustment = 1.0;

        if (dist_word == NULL) {
            /* Null distribution: no prior, no counts, no vocab */
            result->prior = NULL;
            result->counts = NULL;
            result->num_types = 0;
            result->num_tokens = 0;
            result->vocab_size = -1;
            result->uni_alpha = 0;
        }
        else {
            /* One-word distribution; prior_dist may be NULL (no prior) */
            result->prior = prior_dist;
            result->vocab_size = vocab_size;
            result->num_types = 1;
            result->num_tokens = word_count;
            result->uni_alpha = obs_uni_alpha;

            id = bow_word2int_no_add(dist_word);
            if (id == -1 || id > vocab_size)
                quit(-1, "%s: error - word %s is not in vocab... assigned id of %d\n", rname, dist_word, id);

            /* Single-node count list for the one word */
            result->counts = (tc *) kmalloc(sizeof(tc));
            result->counts->id = id;
            result->counts->count = word_count;
            result->counts->next = NULL;
        }
    }

    result->label = strdup((label != NULL) ? label : "none");
    result->path = strdup((dist_file_path != NULL) ? dist_file_path : "none");

    return(result);
}

/*
 * get_dist_index - position of the first distribution in dists[0..num_dists)
 * whose label equals `label`, or -1 when no label matches.
 */
int get_dist_index(multinomial **dists, char *label, int num_dists) {
    int pos = 0;

    while (pos < num_dists) {
        if (strcmp(dists[pos]->label, label) == 0)
            return(pos);
        pos++;
    }

    return(-1);
}

/*
 * get_state_index - position of the first state in states[0..num_states)
 * whose id field equals `id`, or -1 when there is none.
 */
int get_state_index(state **states, int id, int num_states) {
    int pos;

    for (pos = 0; pos < num_states; ++pos) {
        if (states[pos]->id != id)
            continue;
        return(pos);
    }

    return(-1);
}

/*
 * create_state - allocate a state and fill in its fields.
 *
 * The state keeps its own strdup'ed copy of `label` but stores the `dist`
 * pointer itself as the emission distribution (O). It starts with no
 * incoming or outgoing transitions, seen == 0, zeroed lprior/llikelihood
 * and a transition prior weight adjustment of 1.0.
 */
state *create_state(int id, char *label, multinomial *dist, int duration) {
    state *fresh = (state *) kmalloc(sizeof(state));

    /* Identity and caller-supplied payload */
    fresh->id = id;
    fresh->label = strdup(label);
    fresh->O = dist;
    fresh->duration = duration;

    /* Bookkeeping starts at neutral values */
    fresh->seen = 0;
    fresh->lprior = 0.0;
    fresh->llikelihood = 0.0;
    fresh->trans_prior_weight_adjustment = 1.0;

    /* No transitions attached yet */
    fresh->in = NULL;
    fresh->out = NULL;

    return(fresh);
}

/*
 * add_trans - record a transition from `from_state` to `to_state`.
 *
 * If from_state already has an outgoing transition to to_state, its count
 * is increased by `count` and nothing else changes (prob and
 * trans_uni_alpha are ignored in that case). Otherwise a new transition is
 * allocated and appended to the tail of from_state's outgoing list
 * (next_source/prev_source links) and of to_state's incoming list
 * (next_dest/prev_dest links).
 */
void add_trans(state *from_state, state *to_state, int count, double trans_uni_alpha, double prob) {
    trans *fresh, *walk, *tail;

    if (DEBUG) fprintf(stderr, "Adding a transition between %d and %d\n", from_state->id, to_state->id);

    /* One pass over the outgoing list: detect a duplicate and, failing
       that, end up holding the last node so we can append after it */
    tail = NULL;
    for (walk = from_state->out; walk != NULL; walk = walk->next_source) {
        if (walk->source == from_state && walk->dest == to_state) {
            if (DEBUG) fprintf(stderr, "WARNING: Adding a transition between two states where a transition already exists -- incrementing counts!\n");
            walk->count += count;
            return;
        }
        tail = walk;
    }

    /* Build the new transition */
    fresh = (trans *) kmalloc(sizeof(trans));
    fresh->source = from_state;
    fresh->dest = to_state;
    fresh->count = count;
    fresh->prob = prob;
    fresh->alpha = trans_uni_alpha;
    fresh->next_source = NULL;
    fresh->next_dest = NULL;

    /* Append to the source state's outgoing list */
    fresh->prev_source = tail;
    if (tail == NULL)
        from_state->out = fresh;
    else
        tail->next_source = fresh;

    /* Append to the destination state's incoming list */
    tail = to_state->in;
    if (tail != NULL) {
        while (tail->next_dest != NULL)
            tail = tail->next_dest;
    }
    fresh->prev_dest = tail;
    if (tail == NULL)
        to_state->in = fresh;
    else
        tail->next_dest = fresh;
}

/*
 * remove_trans - delete the transition from `from_state` to `to_state`.
 *
 * The transition is looked up in from_state's outgoing list; not finding
 * it is fatal. It is then unlinked from both from_state's outgoing list
 * and to_state's incoming list, and its memory is freed.
 */
void remove_trans(state *from_state, state *to_state) {
    static char rname[]="remove_trans";
    trans *victim, *walk, *before, *after;

    if (DEBUG) fprintf(stderr, "Removing transition between %d and %d\n", from_state->id, to_state->id);

    /* Locate the transition among the source state's outgoing edges */
    victim = NULL;
    for (walk = from_state->out; walk != NULL; walk = walk->next_source) {
        if (walk->source == from_state && walk->dest == to_state) {
            victim = walk;
            break;
        }
    }

    if (victim == NULL) quit(-1, "%s: did not find transition to remove...\n", rname);

    /* Unlink from the source state's outgoing list; a node without a
       predecessor is the list head, so the head pointer moves instead */
    before = victim->prev_source;
    after = victim->next_source;
    if (before != NULL)
        before->next_source = after;
    else
        from_state->out = after;
    if (after != NULL)
        after->prev_source = before;

    /* Same unlink for the destination state's incoming list */
    before = victim->prev_dest;
    after = victim->next_dest;
    if (before != NULL)
        before->next_dest = after;
    else
        to_state->in = after;
    if (after != NULL)
        after->prev_dest = before;

    /* Free space taken by transition */
    free(victim);
}

/*
 * combine_emissions - build a new multinomial holding the summed counts of
 * O1 and O2.
 *
 * The result is labelled "d<id>". Its path is a copy of O1's path when both
 * paths are equal, otherwise "none"; its prior is the shared prior when
 * both inputs point at the same one, otherwise NULL. The inputs must have
 * counts (num_tokens != -1) and agree on vocab_size and uni_alpha; any
 * violation is fatal. O1 and O2 are not modified.
 */
multinomial *combine_emissions(multinomial *O1, multinomial *O2, int id) {
    static char rname[]="combine_emissions";
    multinomial *merged;
    tc *inputs[2];
    tc *src, *hit, *node, *tail;
    char name_buf[100];
    int which;

    merged = (multinomial *) kmalloc(sizeof(multinomial));

    /* Use label from new state */
    sprintf(name_buf, "d%d", id);
    merged->label = strdup(name_buf);

    merged->path = strdup((strcmp(O1->path, O2->path) == 0) ? O1->path : "none");

    /* Compatibility checks on the two inputs */
    if (O1->num_tokens == -1 || O2->num_tokens == -1)
        quit(-1, "%s: merging a distribution that has logprob values instead of counts...\n", rname);

    if (O1->vocab_size != O2->vocab_size) quit(-1, "%s: vocab sizes are not equal!\n", rname);
    merged->vocab_size = O1->vocab_size;

    if (O1->uni_alpha != O2->uni_alpha) quit(-1, "%s: uniform alpha values are not equal!\n", rname);
    merged->uni_alpha = O2->uni_alpha;

    merged->total_count = 0;
    merged->lprobs = NULL;
    merged->prior_weight_adjustment = 1.0;
    merged->prior = (O1->prior == O2->prior) ? O1->prior : NULL;

    /* Merge both count lists, O1 first and then O2: an id already present
       in the result has its count increased, a new id is appended at the
       tail so first-seen order is kept */
    merged->counts = NULL;
    tail = NULL;
    inputs[0] = O1->counts;
    inputs[1] = O2->counts;
    for (which = 0; which < 2; which++) {
        for (src = inputs[which]; src != NULL; src = src->next) {

            for (hit = merged->counts; hit != NULL; hit = hit->next) {
                if (hit->id == src->id) break;
            }

            if (hit != NULL) {
                hit->count += src->count;
                continue;
            }

            node = (tc *) kmalloc(sizeof(tc));
            node->id = src->id;
            node->count = src->count;
            node->next = NULL;
            if (tail == NULL) merged->counts = node;
            else tail->next = node;
            tail = node;
        }
    }

    /* Derive type/token totals from the merged list */
    merged->num_types = 0;
    merged->num_tokens = 0;
    for (node = merged->counts; node != NULL; node = node->next) {
        merged->num_types++;
        merged->num_tokens += node->count;
    }

    return(merged);
}

/*
 * combine_trans - move every transition touching s1 or s2 onto state `new`.
 *
 * Four passes are made, in this order: incoming transitions of s1,
 * incoming of s2, outgoing of s1, outgoing of s2. For each transition:
 *   - an equivalent transition is added on `new` with the same count,
 *     alpha = trans_uni_alpha and prob = 0.0; when the other endpoint is
 *     s1 or s2 it becomes a self-loop on `new`. add_trans() accumulates
 *     the count if that edge already exists;
 *   - a copy of the original (source, dest, count, prob, alpha) is appended
 *     to the list at *p_trans_list, chained through next_source, so the
 *     merge can be undone later;
 *   - the original transition is removed (and freed) via remove_trans().
 *
 * NOTE(review): the saved copies leave next_dest, prev_source and prev_dest
 * unset, so only next_source may be followed on that list.
 * NOTE(review): `prev` starts as NULL, so if *p_trans_list is non-NULL on
 * entry the first saved copy is not linked to the existing list; callers
 * apparently must pass an empty (NULL) list -- confirm against callers.
 * The local `trans_list` is declared but never used.
 */
void combine_trans(state *new, state *s1, state *s2, trans **p_trans_list, double trans_uni_alpha) {
    static char rname[]="combine_trans";
    trans *current_t, *trans_list, *next_t, *temp, *prev;
    double prob;

    /* We'll keep a list of the original transitions involving states s1 and s2
	that we can use if we need to undo this merge */

    /* Process input transitions */
    /* prev tracks the tail of the saved-copy list across all four loops */
    prev = NULL;
    current_t = s1->in;
    while (current_t != NULL) {

        if (DEBUG) fprintf(stderr, "1: Transition exists between %d and %d\n", current_t->source->id, current_t->dest->id);

        /* Sanity check: everything on s1's incoming list must end at s1 */
        if (current_t->dest != s1)
            quit(-1, "%s: transition dest %d does not equal current dest state %d...\n", rname, current_t->dest->id, s1->id);

        prob = 0.0;
        /* A transition coming from s1 or s2 turns into a self-loop on new */
        if (current_t->source == s1 || current_t->source == s2) {
            add_trans(new, new, current_t->count, trans_uni_alpha, prob);
        }
        else {
            add_trans(current_t->source, new, current_t->count, trans_uni_alpha, prob);
        }

        /* Save a copy of the original transition for a possible undo */
        temp = (trans *) kmalloc(sizeof(trans));
        temp->source = current_t->source;
        temp->dest = s1;
        temp->count = current_t->count;
        temp->prob = current_t->prob;
        temp->alpha = current_t->alpha;
        temp->next_source = NULL;
        if (*p_trans_list == NULL) *p_trans_list = temp;
        if (prev != NULL) prev->next_source = temp;
        prev = temp;

        /* Grab the successor first: remove_trans() frees current_t */
        next_t = current_t->next_dest;
        remove_trans(current_t->source, s1);
        current_t = next_t;
    }

    /* Same processing for the transitions entering s2 */
    current_t = s2->in;
    while (current_t != NULL) {

        if (DEBUG) fprintf(stderr, "2: Transition exists between %d and %d\n", current_t->source->id, current_t->dest->id);
        if (current_t->dest != s2)
            quit(-1, "%s: transition dest %d does not equal current dest state %d...\n", rname, current_t->dest->id, s2->id);

        prob = 0.0;
        if (current_t->source == s1 || current_t->source == s2) {
            add_trans(new, new, current_t->count, trans_uni_alpha, prob);
        }
        else {
            add_trans(current_t->source, new, current_t->count, trans_uni_alpha, prob);
        }

        temp = (trans *) kmalloc(sizeof(trans));
        temp->source = current_t->source;
        temp->dest = s2;
        temp->count = current_t->count;
        temp->prob = current_t->prob;
        temp->alpha = current_t->alpha;
        temp->next_source = NULL;
        if (*p_trans_list == NULL) *p_trans_list = temp;
        if (prev != NULL) prev->next_source = temp;
        prev = temp;

        next_t = current_t->next_dest;
        remove_trans(current_t->source, s2);
        current_t = next_t;
    }

    /* Process output transitions of s1; the outgoing list is walked
       through next_source rather than next_dest */
    current_t = s1->out;
    while (current_t != NULL) {

        if (DEBUG) fprintf(stderr, "3: Transition exists between %d and %d\n", current_t->source->id, current_t->dest->id);

        /* Sanity check: everything on s1's outgoing list must start at s1 */
        if (current_t->source != s1)
            quit(-1, "%s: transition source %d does not equal current source state %d...\n", rname, current_t->source->id, s1->id);

        prob = 0.0;
        /* A transition going to s1 or s2 turns into a self-loop on new */
        if (current_t->dest == s1 || current_t->dest == s2) {
            add_trans(new, new, current_t->count, trans_uni_alpha, prob);
        }
        else {
            add_trans(new, current_t->dest, current_t->count, trans_uni_alpha, prob);
        }

        temp = (trans *) kmalloc(sizeof(trans));
        temp->source = s1;
        temp->dest = current_t->dest;
        temp->count = current_t->count;
        temp->prob = current_t->prob;
        temp->alpha = current_t->alpha;
        temp->next_source = NULL;
        if (*p_trans_list == NULL) *p_trans_list = temp;
        if (prev != NULL) prev->next_source = temp;
        prev = temp;

        next_t = current_t->next_source;
        remove_trans(s1, current_t->dest);
        current_t = next_t;
    }

    /* Same processing for the transitions leaving s2 */
    current_t = s2->out;
    while (current_t != NULL) {

        if (DEBUG) fprintf(stderr, "4: Transition exists between %d and %d\n", current_t->source->id, current_t->dest->id);

        if (current_t->source != s2)
            quit(-1, "%s: transition source %d does not equal current source state %d...\n", rname, current_t->source->id, s2->id);

        prob = 0.0;
        if (current_t->dest == s1 || current_t->dest == s2) {
            add_trans(new, new, current_t->count, trans_uni_alpha, prob);
        }
        else {
            add_trans(new, current_t->dest, current_t->count, trans_uni_alpha, prob);
        }

        temp = (trans *) kmalloc(sizeof(trans));
        temp->source = s2;
        temp->dest = current_t->dest;
        temp->count = current_t->count;
        temp->prob = current_t->prob;
        temp->alpha = current_t->alpha;
        temp->next_source = NULL;
        if (*p_trans_list == NULL) *p_trans_list = temp;
        if (prev != NULL) prev->next_source = temp;
        prev = temp;

        next_t = current_t->next_source;
        remove_trans(s2, current_t->dest);
        current_t = next_t;
    }

    return;
}

/*
 * free_state - release a state together with its emission distribution
 * and every transition attached to it.
 *
 * The distribution s1->O (if any) is freed along with its label, path,
 * lprobs and count list. All incoming and then all outgoing transitions
 * are removed through remove_trans(), which also unlinks them from the
 * state at the other end. Finally the state's label and the state itself
 * are freed.
 */
void free_state(state *s1) {
    multinomial *emis = s1->O;
    trans *edge, *following;
    tc *node, *after;

    if (emis != NULL) {
        free(emis->label);
        free(emis->path);
        free(emis->lprobs);             /* free(NULL) is a no-op */
        for (node = emis->counts; node != NULL; node = after) {
            after = node->next;
            free(node);
        }
        free(emis);
    }

    /* The successor is read before each removal because remove_trans()
       frees the transition */
    for (edge = s1->in; edge != NULL; edge = following) {
        following = edge->next_dest;
        remove_trans(edge->source, edge->dest);
    }

    for (edge = s1->out; edge != NULL; edge = following) {
        following = edge->next_source;
        remove_trans(edge->source, edge->dest);
    }

    free(s1->label);
    free(s1);
}

/*
 * calc_emis_alpha - alpha for word `id` in distribution O.
 *
 * Starts from O's uniform alpha, adds the prior's alpha for this id when
 * O has a prior, and divides the sum by prior_weight_adjustment.
 */
double calc_emis_alpha(multinomial *O, int id, double prior_weight_adjustment) {
    double total = O->uni_alpha;

    if (O->prior != NULL)
        total += O->prior->alphas[id];

    total /= prior_weight_adjustment;
    return(total);
}

/*
 * calc_trans_alpha - `alpha` divided by the state's transition prior
 * weight adjustment.
 */
double calc_trans_alpha(state *s1, double alpha) {
    double scaled = alpha / s1->trans_prior_weight_adjustment;

    return(scaled);
}


/*
 * read_vocab - register every word of a vocabulary file with
 * bow_word2int() and return the largest word id seen.
 *
 * Words are separated by spaces or newlines. A line whose first word
 * starts with '#' is skipped, and so is a line with no words at all.
 * Unless `closed_vocab` is set, the unknown word (unk_word) is added too.
 *
 * Returns the maximum id (-1 if nothing was registered). The vocab size
 * is inclusive: the value is max_id, not max_id+1.
 *
 * NOTE(review): kopen_r() hands back a popen() stream for ".Z"/".gz"
 * paths; such a stream should be closed with pclose(), not fclose().
 */
int read_vocab(char *vocab_file_path, int closed_vocab) {
    FILE *vocab_file;
    int max_id, word_id;
    char *word, line[3000];

    vocab_file = kopen_r(vocab_file_path);

    max_id = -1;

    while (fgets(line, sizeof line, vocab_file) != NULL) {
        word = strtok(line, " \n");

        /* strtok() returns NULL for a blank line; check that before
           looking at the first character */
        if (word == NULL || word[0] == '#') continue;

        while (word != NULL) {
            word_id = bow_word2int(word);
            if (max_id < word_id) max_id = word_id;

            if (DEBUG) fprintf(stderr, "word = %s, id = %d\n", word, word_id);

            word = strtok(NULL, " \n");
        }
    }
    fclose(vocab_file);

    /* Add the unknown word to the vocab if it isn't already there */
    if (!closed_vocab) {
        word_id = bow_word2int(unk_word);
        if (max_id < word_id) max_id = word_id;
    }

    /* Vocab size is inclusive because we return max_id. If we returned
       max_id+1 then it would not be inclusive */

    return(max_id);
}

/*
 * retrieve_count - count stored for word `id` in O's count list,
 * or 0.0 when the id is not in the list.
 */
double retrieve_count(multinomial *O, int id) {
    tc *node;

    for (node = O->counts; node != NULL; node = node->next) {
        if (node->id == id)
            return(node->count);
    }

    return(0.0);
}

/*
 * calc_lprob - log probability of word `id` under O.
 *
 * The numerator comes from get_count(O, id, mode). The result is
 * log(numerator) - log(O->total_count), or -99.0 when the numerator is
 * zero.
 */
float calc_lprob(multinomial *O, int id, int mode) {
    static char rname[]="calc_lprob";
    double mass;

    if (DEBUG) fprintf(stderr, " In %s, calling get_count\n", rname);
    mass = get_count(O, id, mode);

    if (mass != 0)
        return((float) (log(mass) - log(O->total_count)));

    return(-99.0);
}

/*
 * get_count - count for word `id` in O, adjusted according to `mode`.
 *
 *   mode 1      : count + alpha - 1
 *   mode 2      : count + alpha
 *   mode 3 or 4 : count (alpha is ignored)
 *
 * alpha is calc_emis_alpha(O, id, O->prior_weight_adjustment).
 * Fatal errors: O->total_count <= 0, an unrecognised mode, or a negative
 * result.
 */
double get_count(multinomial *O, int id, int mode) {
    static char rname[]="get_count";
    double alpha, complete_count, count;

    if (O->total_count <= 0)
        quit(-1, "%s: total counts have not been collected...\n", rname);

    count = retrieve_count(O, id);
    alpha = calc_emis_alpha(O, id, O->prior_weight_adjustment);

    if (DEBUG) fprintf(stderr, "count = %f, alpha = %f\n", count, alpha);

    switch (mode) {
    case 1:
        complete_count = count + alpha - 1;
        break;
    case 2:
        complete_count = count + alpha;
        break;
    case 3:
    case 4:
        complete_count = count;
        break;
    default:
        /* Previously an unknown mode left complete_count uninitialised
           and the garbage value was returned to the caller */
        quit(-1, "%s: error - unknown mode %d\n", rname, mode);
        return(0.0);    /* not reached: quit() exits */
    }

    if (complete_count < 0) quit(-1, "%s: error - count of %f for word id %d in multinomial %s\n", rname, complete_count, id, O->label);

    return(complete_count);
}


/*
 * kmalloc - malloc() wrapper that terminates the program through quit()
 * when the allocation fails, so callers never see NULL.
 */
char *kmalloc(size_t size) {
    static char rname[]="kmalloc";
    char *block = (char *) malloc(size);

    if (block == NULL)
        quit(-1, "%s: malloc failed - not enough memory...\n", rname);

    return(block);
}

/*
 * kopen_r - open `file_path` for reading, decompressing on the fly.
 *
 *   "*.Z"   : read through "zcat <path>"
 *   "*.gz"  : read through "cat <path> | gunzip"
 *   other   : plain fopen(path, "r")
 *
 * Any failure (missing file, command too long for the buffer, popen or
 * fopen error) is fatal via quit(). For compressed files the returned
 * stream comes from popen(), so it should be closed with pclose().
 *
 * NOTE(review): the path is pasted into a shell command unquoted; paths
 * with spaces or shell metacharacters will not work and must not come
 * from untrusted input.
 */
FILE *kopen_r(char *file_path) {
    static char rname[]="kopen_r";
    FILE *file;
    size_t lpath;
    int needed;
    struct stat file_stat;
    char cmd[256];
    const char *prefix = NULL, *suffix = NULL;

    /* Only look at a suffix when the path is long enough to have one;
       indexing with lpath-2 / lpath-3 on a shorter path reads before the
       start of the string */
    lpath = strlen(file_path);
    if (lpath >= 2 && strcmp(&file_path[lpath-2],".Z") == 0) {
        prefix = "zcat ";
        suffix = "";
    }
    else if (lpath >= 3 && strcmp(&file_path[lpath-3],".gz") == 0) {
        prefix = "cat ";
        suffix = " | gunzip";
    }

    if (prefix == NULL) {
        if ((file = fopen(file_path,"r")) == NULL)
            quit(-1, "%s: could not open file %s\n", rname, file_path);
        return(file);
    }

    if (stat(file_path, &file_stat) != 0) quit(-1,"%s: file '%s' not found\n",rname, file_path);

    /* Bounded formatting: a long path must not overflow the command buffer */
    needed = snprintf(cmd, sizeof cmd, "%s%s%s", prefix, file_path, suffix);
    if (needed < 0 || (size_t) needed >= sizeof cmd)
        quit(-1, "%s: path too long to build command for file %s\n", rname, file_path);

    if ((file = popen(cmd,"r")) == NULL)
        quit(-1, "%s: could not open file %s\n", rname, file_path);

    return(file);
}

/*
 * kopen_w - open `file_path` for writing (mode "w"); failure to open
 * is fatal via quit().
 */
FILE *kopen_w(char *file_path) {
    static char rname[]="kopen_w";
    FILE *out = fopen(file_path,"w");

    if (out == NULL)
       quit(-1, "%s: could not open file %s\n", rname, file_path);

    return(out);
}

/*
 * kopen_wgz - open a write stream whose data is gzip-compressed into
 * `file_path`, implemented as a pipe to "gzip > <path>".
 *
 * A command that does not fit the buffer, or a popen() failure, is fatal
 * via quit(). The stream comes from popen(), so it should be closed with
 * pclose().
 *
 * NOTE(review): the path is pasted into a shell command unquoted; it must
 * not contain spaces or shell metacharacters or come from untrusted input.
 */
FILE *kopen_wgz(char *file_path) {
    static char rname[]="kopen_wgz";
    FILE *file;
    char cmd[256];
    int needed;

    /* Bounded formatting: a long path must not overflow the command buffer */
    needed = snprintf(cmd, sizeof cmd, "gzip > %s", file_path);
    if (needed < 0 || (size_t) needed >= sizeof cmd)
       quit(-1, "%s: path too long to build command for file %s\n", rname, file_path);

    if ((file = popen(cmd,"w")) == NULL)
       quit(-1, "%s: could not open file %s\n", rname, file_path);

    return(file);
}

/*
 * kopen_a - open `file_path` for appending (mode "a"); failure to open
 * is fatal via quit().
 */
FILE *kopen_a(char *file_path) {
    static char rname[]="kopen_a";
    FILE *out = fopen(file_path,"a");

    if (out == NULL)
       quit(-1, "%s: could not open file %s\n", rname, file_path);

    return(out);
}

/* Borrowed from version 1 of the SLM toolkit */
/*
 * quit - print a printf-style message to stderr and terminate the program
 * with exit status `rc`. Never returns to the caller.
 */
int quit(int rc, char *msg, ...) {
   va_list args ;

   va_start(args,msg) ;
   vfprintf(stderr,msg,args) ;
   va_end(args) ;   /* va_end takes the va_list, not the format string */
   exit(rc) ;
}



