/* mainSVMtrain.cpp
 *
 * Copyright (C) 2014 Mahmoud Ghandi
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include<iostream>

#include "global.h"
#include "globalvar.h"

#include "Sequence.h"
#include "CountKLmers.h"
#include "CountKLmersGeneral.h"
#include "CountKLmersH.h"
//#include "CountKLmers3G.h"
#include "CalcWmML.h"
#include "MLEstimKLmers.h"
//#include "MLEstimKLmers3G.h"
#include "MLEstimKLmersLog.h"
#include "KLmer.h"
#include "SequenceNames.h"
#include "EstimLogRatio.h"
#include "LTree.h"
#include "LTreef.h"
#include "LTreeS.h"
#include "LList.h"
#include "SVMtrain.h"

using namespace std;

#define NMAXNSEQUENCES 1000001
//#include "stdafx.h"

int main(int argc, char * argv[]) //mainSVMtrain
{
	char *kernelFN = NULL; 
	char *posSeqFN = NULL;
	char *negSeqFN = NULL;
	char *outFN = NULL; 
	int niter20=5; 
    
	int maxslen =1000000; 
    
    
	int i; 
	i=1; 
	int nfp = 0; // fixed parameters read
	int perr = 0; 
	
	while ((i<argc)&& (!perr))
	{
		if (argv[i][0]=='-')
		{
            //	[-n <niter20>]\n");
			if (stringcompare(argv[i], "-n",2))
			{
				niter20 = atoi(argv[i+1]); 
				i++; 
			}
			else
			{
				printf("\n parameter not recognized: %s \n", argv[i]); 
				perr = 1; 
			}
		}
		else
		{
			if (nfp==0) // first param 
			{
				kernelFN=argv[i];
			}
			else if (nfp==1)
			{
				posSeqFN=argv[i];
			}
			else if (nfp==2)
			{
				negSeqFN=argv[i];
			}
			else if (nfp==3)
			{
				outFN=argv[i];
			}
			else
			{
				printf("\n parameter not recognized: %s \n", argv[i]); 
				perr = 1; 
			}
			nfp++; 
		}
		i++; 
	}

	if (nfp!=4) { perr = 1; }
    
	if (perr)
	{
		printf("\n");
		printf(" Usage: ./SVMtrain [-n <niter20>] <kernel_file> <pos_seqfile> <neg_seqfile> <out_prefix>\n"); 
		printf("\n");
		printf("  given kernel matrix, computes lambdas for each sequence.\n");
		printf("  (using iterative method of \"A discriminative framework for detecting\n");
		printf("  remote protein homologies. Jaakkola T, Diekhans M, Haussler D., 2000\")\n");
		printf("\n");
		printf(" Arguments:\n");
		printf("  kernel_file: kernel matrix file generated by gkmKernel\n");
		printf("  pos_seqfile: positive sequence file used to generate the kernelFN\n");
		printf("  neg_seqfile: negative sequence file used to generate the kernelFN\n");
		printf("  out_prefix: prefix of output file names. There are two output files;\n");
		printf("              one for alphas ({PREFIX}_svalpha.out) and the other for\n");
		printf("              the corresponding sequences ({PREFIX}_svseq.fa).\n");
		printf("\n");
		printf(" Options:\n");
		printf("  niter20: number of iterations divided by 20, default=%d\n",niter20);
		printf("\n");
		printf(" Examples:\n");
		printf("  ./SVMtrain -n 10 kernel posseq.fa negseq.fa svmtrain\n");
		printf("  ./SVMtrain kernel posseq.fa negseq.fa svmtrain\n");
		printf("\n");
		return 0; 
	} 

	CSequence *sgi = new CSequence(100000);

	char **seqname = new char*[NMAXNSEQUENCES]; 
	char **seqs = new char*[NMAXNSEQUENCES];
	int npos, nneg, nseqs; 

	//read positive sequence file
	nseqs = 0;
	FILE *sfi = fopen(posSeqFN, "r"); 
	if (sfi == NULL)
	{
		perror ("error occurred while opening a file");
		return 0;
	}

	while (!feof(sfi))
	{
		sgi->readFsa(sfi,true);
		if(sgi->getLength()>0)
		{
			seqname[nseqs] = new char[100]; 
			sprintf(seqname[nseqs], sgi->getName()); 
			seqs[nseqs] = new char[sgi->getLength()+1];
			sprintf(seqs[nseqs], sgi->getSeq());
			nseqs++;
		}
	}
	fclose(sfi); 
	npos = nseqs;

	//read negative sequence file
	sfi = fopen(negSeqFN, "r"); 
	while (!feof(sfi))
	{
		sgi->readFsa(sfi,true);
		if(sgi->getLength()>0)
		{
			seqname[nseqs] = new char[100]; 
			sprintf(seqname[nseqs], sgi->getName()); 
			seqs[nseqs] = new char[sgi->getLength()+1];
			sprintf(seqs[nseqs], sgi->getSeq());
			nseqs++;
		}
	}
	fclose(sfi); 
	nneg = nseqs - npos;
	char *sline = new char[maxslen+2]; 
	int N = nseqs;

	printf("npos=%d, nneg=%d, N=%d\n",npos, nneg, N);
    
	double **kernel = new double *[N]; 
	for(i=0;i<N;i++)
	{
		kernel[i] = new double [N]; 
	}

	FILE *fi = fopen(kernelFN,"r"); 
	if (fi == NULL)
	{
		perror ("error occurred while opening a file");
		return 0;
	}

	for(i=0;i<N;i++)
	{
		if (TALK){
			printf("%d.",i+1);
		}
		/*
		fgets(sline, maxslen,fi); 
		kernel[i][i]=1.0; 
		char *sl = sline; 
		while ((*sl==' ')||(*sl=='\t')){sl++;}
		while (!((*sl==' ')||(*sl=='\t'))) {if (*sl==0) {printf("\nerror reading kernel, line %d\n",i); return(1);} sl++;} //class
        
		for(int j=0;j<i; j++)
		{
			while ((*sl==' ')||(*sl=='\t')){sl++;}
			while (!((*sl==' ')||(*sl=='\t'))) {if (*sl==0) {printf("\nerror reading kernel, line %d\n",i); return(1);} sl++;} //go forward one token
			while ((*sl==' ')||(*sl=='\t')){sl++;}
			sscanf(sl, "%lf", &(kernel[i][j])); 
			kernel[j][i] = kernel[i][j]; 
			sl++; 
		}
		*/
		for(int j=0;j<=i; j++)
		{
			int ret = fscanf(fi, "%lf", &(kernel[i][j])); 
			if ((ret == EOF) || (ret != 1)) {printf("\nerror reading kernel, line %d\n",i); return(1);}
			kernel[j][i] = kernel[i][j]; 
		}
	}
	printf("\n");
	fclose(fi); 
    
	CSVMtrain *svmtr= new CSVMtrain(); 
    
	double *lambdas = new double [N]; 
    
	svmtr->niter20 = niter20; 
	svmtr->train(kernel, npos, nneg, lambdas); 
    
	char *alphaFN = new char[strlen(outFN)+30];
	char *svFN = new char[strlen(outFN)+30];
	sprintf(alphaFN, "%s_svalpha.out", outFN);
	sprintf(svFN, "%s_svseq.fa", outFN);

	FILE *fo_alpha = fopen(alphaFN, "w"); 
	FILE *fo_sv = fopen(svFN, "w");
	if (fo_alpha == NULL)
	{
		perror ("error occurred while opening a file");
		return 0;
	}
	if (fo_sv == NULL)
	{
		perror ("error occurred while opening a file");
		return 0;
	}
	for(i=0;i<npos;i++)
	{
		if (lambdas[i]>1e-10)
		{
            fprintf(fo_alpha, "%s\t%e\n", seqname[i], lambdas[i]); 
			fprintf(fo_sv, ">%s\n%s", seqname[i], seqs[i]);
		}
	}
	for(i=npos;i<N;i++)
	{
		if (lambdas[i]>1e-10)
		{
            fprintf(fo_alpha, "%s\t%e\n", seqname[i], -lambdas[i]); 
			fprintf(fo_sv, ">%s\n%s", seqname[i], seqs[i]);

		}
	}
    
	fclose(fo_alpha); 
	fclose(fo_sv);
    
	delete svmtr;
	delete []lambdas; 
    
	delete []sline; 
	for(i=0;i<N;i++)
	{
		delete []kernel[i]; 
		delete []seqname[i]; 
		delete []seqs[i];
	}
	delete []kernel;
	delete []seqname; 
	delete []seqs;
	
	delete alphaFN;
	delete svFN;

	return 0; 
}
