#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "mex.h"
#include "svm.h"
#include "svm_model_matlab.h"
#include "IntersectionClassifier.h"

// Compatibility shim: MEX API versions below 0x07030000 presumably do not
// provide the mwIndex type used for sparse-matrix indices (mxGetIr/mxGetJc),
// so fall back to plain int there. NOTE(review): confirm against the oldest
// MATLAB release this is built for.
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif 

// Size in bytes of the buffer mexFunction copies the 'libsvm_options'
// string into (CMD_LEN/2 is also the size of its argv[] token array).
#define CMD_LEN 2048
void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
{
	int i, j, low, high;
	mwIndex *ir, *jc;
	double *samples;

	ir = mxGetIr(prhs);
	jc = mxGetJc(prhs);
	samples = mxGetPr(prhs);

	// each column is one instance
	j = 0;
	low = jc[index], high = jc[index+1];
	for(i=low;i<high;i++)
	{
		x[j].index = ir[i] + 1;
		x[j].value = samples[i];
		j++;
 	}
	x[j].index = -1;
}

/*
 * Prints the fastpredict usage text to the MATLAB console.
 * Despite its name it does not exit: callers follow it with fake_answer()
 * and return themselves.
 */
void exit_with_help()
{
	static const char help_text[] =
		" Usage:"
		" [exact_values, pwconst_values, pwlinear_values,[times]] = ...\n"
		" \tfastpredict(testing_label_vector, testing_instance_matrix, model,'libsvm_options')\n"
		" \n"
		" Output:\n"
		"   exact_values    : predictions using binary search\n"
		"   pwconst_values  : approximation using piecewise constant function\n"
		"   pwlinear_values : approximation using piecewise linear function\n"
		"   [times]         : running times \n"
		" \n"
		" \n"
		" libsvm_options:\n"
		"   -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0);\n"
		"   -v verbose flag         : 0 or 1 (default 0);\n"
		"   -n number of bins       : [2,...] (default 100);\n"
		"\n"
		" (Note: Only SVC and 2 class classifier is supported)\n"
		" \n";

	// Print through "%s" so the text is never interpreted as a format string.
	mexPrintf("%s", help_text);
}

/*
 * Sets the three main outputs plhs[0..2] to empty 0x0 double matrices.
 * Used on the error paths so MATLAB receives defined (empty) results.
 */
static void fake_answer(mxArray *plhs[])
{
	for(int out = 0; out < 3; out++)
		plhs[out] = mxCreateDoubleMatrix(0, 0, mxREAL);
}

/** 
 * Modified by subhransu maji (smaji@cs.berkeley.edu) to output the dec_values/prob_estimates 
 * using 3 different techniques.
 * exact compution using binary search 
 * piecewise constant approximation
 * piecewise linear approximation
 * ***CAVEATS*** Only 2 class and SVC are supported. Data needs to be in dense format
 */
void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int num_bins, const int predict_probability, const int verbose_flag)

{
  //read the data and get the parameters
  int label_vector_row_num, label_vector_col_num;
  int feature_number, testing_instance_number;
  int instance_index;
  double *ptr_instance; 
  double *ptr_time=NULL;

  double *ptr_prob_estimates0, *ptr_prob_estimates1, *ptr_prob_estimates2; //ouputs
  double *ptr_dec_values0, *ptr_dec_values1, *ptr_dec_values2; //outputs

  double *prob_estimates0=NULL, *prob_estimates1=NULL,*prob_estimates2=NULL; //per entry estimates

  clock_t start,end;
  struct svm_node *x = NULL;

  int svm_type=svm_get_svm_type(model);
  
  if(num_bins < 2 ){
    mexPrintf("-n num_bins should be >= 2 \n");
    fake_answer(plhs);
    return;
  }
  
  if(svm_type != C_SVC){
    mexPrintf("Only SVC Classifier Supported..\n");
    fake_answer(plhs);
    return;
  }
  int nr_class=svm_get_nr_class(model);

  if(nr_class !=2){
    mexPrintf("Only Binary Classifier Supported..\n");
    fake_answer(plhs);
    return;
  }

  // prhs[1] = testing instance matrix
  feature_number = mxGetN(prhs[1]);
  testing_instance_number = mxGetM(prhs[1]);
  label_vector_row_num = mxGetM(prhs[0]);
  label_vector_col_num = mxGetN(prhs[0]);

  if(label_vector_row_num!=testing_instance_number){
    
      mexPrintf("# of labels (# of column in 1st argument) does not match # of instances (# of rows in 2nd argument).\n");
      fake_answer(plhs);
      return;
  }
  if(label_vector_col_num!=1){
    
      mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
      fake_answer(plhs);
      return;
  }
  //the decision values(or prob estimates) exact, pwconst, pwlinear
  int dim2=predict_probability? nr_class:1;

  plhs[0] = mxCreateDoubleMatrix(testing_instance_number, dim2, mxREAL);
  plhs[1] = mxCreateDoubleMatrix(testing_instance_number, dim2, mxREAL);
  plhs[2] = mxCreateDoubleMatrix(testing_instance_number, dim2, mxREAL);

  //place holder for time
  if(nlhs == 4) {
    plhs[3] = mxCreateDoubleMatrix(1,4,mxREAL);
    ptr_time = mxGetPr(plhs[3]);
  }
  //the pointer for dec values
  ptr_dec_values0 = mxGetPr(plhs[0]);
  ptr_dec_values1 = mxGetPr(plhs[1]);
  ptr_dec_values2 = mxGetPr(plhs[2]);
  //pointer for the prob values (same)
  ptr_prob_estimates0 = mxGetPr(plhs[0]);
  ptr_prob_estimates1 = mxGetPr(plhs[1]);
  ptr_prob_estimates2 = mxGetPr(plhs[2]);
  
  ptr_instance = mxGetPr(prhs[1]);

  //make a data matrix out of the ptr_dec_values
  double **data = (double**)malloc(testing_instance_number*sizeof(double*));
  if(mxIsSparse(prhs[1])){
    mexPrintf("Sparse Data Format is not supported yet..\n");
    fake_answer(plhs);
    return;
  }

  for(instance_index=0;instance_index<testing_instance_number;instance_index++){
    data[instance_index] = (double *)malloc(feature_number*sizeof(double));
    for(int i=0;i<feature_number;i++){
      data[instance_index][i] = ptr_instance[testing_instance_number*i+instance_index];
    }
  }

  start=clock();
  IntersectionClassifier * intC = new IntersectionClassifier(model,num_bins, feature_number);
  end=clock();
  
  if(nlhs==4)
    ptr_time[0] = (double)(end-start)/CLOCKS_PER_SEC;

  if(verbose_flag){
    printf("======================================\n");
    printf(" Method              \t|\tTime\n");
    printf("======================================\n");
    printf(" precomp             \t|\t"); fflush(stdout);
    mexPrintf("%.4fs\n",(double)(end-start)/CLOCKS_PER_SEC);
  }

  /**
  mexPrintf(" normal   \t|\t"); fflush(stdout);
  start=clock();
  double * normalVals   = intC->predict(model,data,testing_instance_number);
  end=clock();
  mexPrintf("%.4f\n",(double)(end-start)/CLOCKS_PER_SEC);
  */

  start=clock();
  double * linHashVals = intC->linHashPredict(model,data,testing_instance_number);
  end=clock();
  if(nlhs==4)
    ptr_time[1] = (double)(end-start)/CLOCKS_PER_SEC;
  if(verbose_flag){
    mexPrintf(" binary search     \t|\t"); fflush(stdout);
    mexPrintf("%.4fs\n",(double)(end-start)/CLOCKS_PER_SEC);
  } 
  
  start=clock();
  double *pwConstVals = intC->pwConstPredict(model,data,testing_instance_number);
  end=clock();
  if(nlhs==4)
    ptr_time[2] = (double)(end-start)/CLOCKS_PER_SEC;
  if(verbose_flag){
    mexPrintf(" piecewise constant \t|\t"); fflush(stdout);
    mexPrintf("%.4fs\n",(double)(end-start)/CLOCKS_PER_SEC);
  }
  
  start=clock();
  double *pwLinVals = intC->pwLinPredict(model,data,testing_instance_number);
  end=clock();
  if(nlhs==4)
    ptr_time[3] = (double)(end-start)/CLOCKS_PER_SEC;
  if(verbose_flag){
    mexPrintf(" piecewise linear \t|\t"); fflush(stdout);
    mexPrintf("%.4fs\n",(double)(end-start)/CLOCKS_PER_SEC);
  }

  if(predict_probability){
    
    prob_estimates0 = (double *) malloc(nr_class*sizeof(double));
    prob_estimates1 = (double *) malloc(nr_class*sizeof(double));
    prob_estimates2 = (double *) malloc(nr_class*sizeof(double));
    //the returned values should be probability estimates
    x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node));
    for(instance_index=0;instance_index<testing_instance_number;instance_index++){

      double v0,v1,v2; //target probability values for the three functions
      if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED){
	mexPrintf("Sparse Reading of data is not supported!");
	fake_answer(plhs);
	return;
      }
      else{
	for(int i=0;i<feature_number;i++){
	  
	  x[i].index = i+1;
	  x[i].value = ptr_instance[testing_instance_number*i+instance_index];
	}
	x[feature_number].index = -1;
      }

      v0 = svm_predict_probability2(model, x, prob_estimates0,&linHashVals[instance_index]);
      v1 = svm_predict_probability2(model, x, prob_estimates1,&pwConstVals[instance_index]);
      v2 = svm_predict_probability2(model, x, prob_estimates2,&pwLinVals[instance_index]);
      //copy the answer
      for(int i=0;i<nr_class;i++){
	ptr_prob_estimates0[instance_index + i* testing_instance_number] = prob_estimates0[i];
	ptr_prob_estimates1[instance_index + i* testing_instance_number] = prob_estimates1[i];
	ptr_prob_estimates2[instance_index + i* testing_instance_number] = prob_estimates2[i];
      }
    }
  }
  else{
    //copy the answer
    for(instance_index=0;instance_index<testing_instance_number;instance_index++){
      ptr_dec_values0[instance_index] = linHashVals[instance_index];
      ptr_dec_values1[instance_index] = pwConstVals[instance_index];
      ptr_dec_values2[instance_index] = pwLinVals[instance_index];
    }
  }
  //free
  delete intC;
  for(int i =0; i < testing_instance_number; i++)
    free(data[i]);
  free(data);
  if(prob_estimates0 != NULL){
    free(prob_estimates0);
    free(prob_estimates1);
    free(prob_estimates2);
    free(x);
  }
  free(linHashVals);
  free(pwConstVals);
  free(pwLinVals);
  //free(normalVals);
}
//outputs the predictions
void mexFunction( int nlhs, mxArray *plhs[],
		 int nrhs, const mxArray *prhs[] )
{
	int prob_estimate_flag = 0;
	int verbose_flag = 0; 
	int num_bins = 100 ; //default number of bins
	struct svm_model *model;

	if(nrhs > 4 || nrhs < 3)
	{
		exit_with_help();
		fake_answer(plhs);
		return;
	}
	if(mxIsStruct(prhs[2]))
	{
		const char *error_msg;
		if(nrhs==4)
		{
			int i, argc = 1;
			char cmd[CMD_LEN], *argv[CMD_LEN/2];

			// put options in argv[]
			mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
			if((argv[argc] = strtok(cmd, " ")) != NULL)
				while((argv[++argc] = strtok(NULL, " ")) != NULL)
					;

			for(i=1;i<argc;i++)
			{
				if(argv[i][0] != '-') break;
				if(++i>=argc)
				{
					exit_with_help();
					fake_answer(plhs);
					return;
				}
				switch(argv[i-1][1])
				{
					case 'b':
						prob_estimate_flag = atoi(argv[i]);
						break;
						//MAJI begin
					case 'v':
						verbose_flag = atoi(argv[i]);
						break;
					case 'n':
						num_bins = atoi(argv[i]);
						break;
						//MAJI end
					default:
						mexPrintf("unknown option\n");
						exit_with_help();
						fake_answer(plhs);
						return;
				}
			}
		}

		model = (struct svm_model *) malloc(sizeof(struct svm_model));
		error_msg = matlab_matrix_to_model(model, prhs[2]);
		if(error_msg)
		{
			mexPrintf("Error: can't read model: %s\n", error_msg);
			svm_destroy_model(model);
			fake_answer(plhs);
			return;
		}
		if(prob_estimate_flag)
			if(svm_check_probability_model(model)==0)
			{
				mexPrintf("Model does not support probabiliy estimates\n");
				fake_answer(plhs);
				svm_destroy_model(model);
				return;
			}

		predict(nlhs, plhs, prhs, model, num_bins, prob_estimate_flag,verbose_flag);
		svm_destroy_model(model);
	}
	else
	{
		mexPrintf("model file should be a struct array\n");
		fake_answer(plhs);
	}

	return;
}
