/* socket.c */

#include "general.h"

#include <stddef.h>

/* Listening server socket: created, bound and put into listening state by
   rainbow_socket_init(); rainbow_serve() accepts client connections on it. */
static int rainbow_sockfd;

/*
 * rainbow_socket_init: create the listening server socket and store it in
 * rainbow_sockfd.
 *
 * socket_name     - when use_unix_socket is non-zero, the filesystem path of
 *                   an AF_UNIX socket; otherwise a decimal port number for an
 *                   AF_INET socket bound to INADDR_ANY.
 * use_unix_socket - non-zero selects AF_UNIX, zero selects AF_INET.
 *
 * The socket is a SOCK_STREAM socket with a listen backlog of 5.
 *
 * Every failure is fatal: a diagnostic goes to stderr and the process exits
 * with EXIT_FAILURE. Explicit checks are used instead of assert() so that
 * they remain in NDEBUG builds.
 */
void rainbow_socket_init (const char *socket_name, int use_unix_socket) {
    static char rname[]="rainbow_socket_init";
    struct sockaddr_un un_addr;
    struct sockaddr_in in_addr;
    struct sockaddr *sap;
    socklen_t servlen;
    int type;

    if (socket_name == NULL || socket_name[0] == '\0') {
        fprintf(stderr, "%s: missing socket name\n", rname);
        exit(EXIT_FAILURE);
    }

    type = use_unix_socket ? AF_UNIX : AF_INET;

    if (type == AF_UNIX) {
        size_t len = strlen(socket_name);

        /* sun_path is a fixed-size array; reject names that do not fit
           (including the terminating NUL) instead of overflowing it. */
        if (len >= sizeof(un_addr.sun_path)) {
            fprintf(stderr, "%s: socket path too long: %s\n", rname, socket_name);
            exit(EXIT_FAILURE);
        }
        memset(&un_addr, 0, sizeof(un_addr));
        un_addr.sun_family = AF_UNIX;
        memcpy(un_addr.sun_path, socket_name, len + 1);
        /* Address length = header up to sun_path + path + NUL; offsetof
           does not assume sun_path directly follows sun_family. */
        servlen = (socklen_t)(offsetof(struct sockaddr_un, sun_path) + len + 1);
        sap = (struct sockaddr *)&un_addr;
    }
    else {
        char *end;
        long port = strtol(socket_name, &end, 10);

        /* Unlike atoi(), reject trailing garbage and out-of-range values
           rather than silently binding to port 0. strtol() clamps overflow
           to LONG_MAX/LONG_MIN, which the range test also rejects. */
        if (end == socket_name || *end != '\0' || port < 0 || port > 65535) {
            fprintf(stderr, "%s: invalid port number: %s\n", rname, socket_name);
            exit(EXIT_FAILURE);
        }
        memset(&in_addr, 0, sizeof(in_addr));
        in_addr.sin_family = AF_INET;
        in_addr.sin_port = htons((in_port_t)port);
        in_addr.sin_addr.s_addr = htonl(INADDR_ANY);
        servlen = sizeof(in_addr);
        sap = (struct sockaddr *)&in_addr;
    }

    rainbow_sockfd = socket(type, SOCK_STREAM, 0);
    if (rainbow_sockfd < 0) {
        perror(rname);
        exit(EXIT_FAILURE);
    }

    if (bind(rainbow_sockfd, sap, servlen) < 0) {
        perror(rname);
        exit(EXIT_FAILURE);
    }

    if (listen(rainbow_sockfd, 5) < 0) {
        perror(rname);
        exit(EXIT_FAILURE);
    }
}


/* hmm_query() is defined later in this file. Declaring it here means the
   call below is checked against the real prototype; under an implicit
   declaration the float argument would be passed as a double. This is
   redundant but harmless if general.h already declares it. */
int hmm_query(FILE *in, FILE *out, state *hmm, int obs_with_id, int run_vit, int print_probs, int run_forward, int vit_details, int forward_details, int punc_trans, int max_class, state **states, int num_states, int read_label, float trans_weight, int print_state_id);

/*
 * rainbow_serve: accept one client connection on rainbow_sockfd and answer
 * its queries with hmm_query() until that returns non-zero, then close the
 * connection.
 *
 * All parameters are forwarded unchanged to hmm_query() for every query.
 *
 * A failed accept() is fatal (stderr diagnostic, exit(EXIT_FAILURE)).
 * If the per-connection stdio streams cannot be set up, the connection is
 * dropped and the function returns.
 */
void rainbow_serve (state *hmm, int obs_with_id, int run_vit, int print_probs, int run_forward, int vit_details, int forward_details, int punc_trans, int max_class, state **states, int num_states, int read_label, float trans_weight, int print_state_id) {
    static char rname[]="rainbow_serve";
    struct sockaddr_storage cli_addr;   /* large enough for any address family */
    socklen_t clilen;                   /* accept() takes a socklen_t *, not int * */
    int newsockfd, outfd, stop;
    FILE *in, *out;

    clilen = sizeof(cli_addr);
    newsockfd = accept(rainbow_sockfd, (struct sockaddr *)&cli_addr, &clilen);
    if (newsockfd < 0) {
        perror(rname);
        exit(EXIT_FAILURE);
    }

    /* Separate read and write streams are used. Because fclose() closes the
       stream's descriptor, the write stream gets its own dup()ed descriptor.
       That way each descriptor is closed exactly once, and the same fd is
       never closed twice (which could close an unrelated, reused fd). */
    in = fdopen(newsockfd, "r");
    if (in == NULL) {
        perror(rname);
        close(newsockfd);
        return;
    }
    outfd = dup(newsockfd);
    if (outfd < 0) {
        perror(rname);
        fclose(in);
        return;
    }
    out = fdopen(outfd, "w");
    if (out == NULL) {
        perror(rname);
        close(outfd);
        fclose(in);
        return;
    }

    stop = 0;
    while (!stop) {
        stop = hmm_query(in, out, hmm, obs_with_id, run_vit, print_probs, run_forward, vit_details, forward_details, punc_trans, max_class, states, num_states, read_label, trans_weight, print_state_id);
    }

    /* Close the write stream first so any buffered output is flushed.
       Each fclose() also closes its own descriptor, so no close() is needed. */
    fclose(out);
    fclose(in);
}

/*
 * hmm_query: handle a single query read from `in`, writing results to `out`.
 *
 * One observation string is read with read_one_string(). For that string,
 * the enabled analyses are printed in this order:
 *   max_class   - max class for each observation (find_max_path)
 *   run_vit     - Viterbi path (find_vit_path)
 *   run_forward - forward probability (calc_forward_prob)
 * Afterwards the string is freed and `out` is flushed.
 *
 * Returns 1 when read_one_string() returns NULL (no query read), else 0.
 */
int hmm_query(FILE *in, FILE *out, state *hmm, int obs_with_id, int run_vit, int print_probs, int run_forward, int vit_details, int forward_details, int punc_trans, int max_class, state **states, int num_states, int read_label, float trans_weight, int print_state_id) {
    static char rname[]="hmm_query";
    double elprob = 0.0;    /* passed by address to calc_forward_prob */
    double tlprob = 0.0;    /* passed by address to calc_forward_prob */
    shead *query = read_one_string(in, obs_with_id, read_label, 0);

    /* Nothing read: tell the caller to stop serving this connection. */
    if (query == NULL) {
        return (1);
    }

    if (max_class) {
        path_head *class_path = find_max_path(query, states, num_states, read_label);

        print_path_to_file(class_path, out, query, obs_with_id, print_state_id, print_probs);
        free_path(class_path);
    }

    if (run_vit) {
        path_head *best_path = find_vit_path(query, hmm, vit_details, punc_trans, read_label, trans_weight);

        print_path_to_file(best_path, out, query, obs_with_id, print_state_id, print_probs);
        free_path(best_path);
    }

    if (run_forward) {
        double forward = calc_forward_prob(query, hmm, forward_details, punc_trans, read_label, trans_weight, &elprob, &tlprob);

        print_forward_prob_to_file(forward, out, query, obs_with_id);
    }

    free_string(query);
    fflush(out);
    return (0);
}



