//LabPlot : RegressionListDialog.cc

#include "RegressionListDialog.h"
#include "regression.h"

#include <vector>

#ifdef HAVE_GSL
#include <gsl/gsl_fit.h>
#include <gsl/gsl_multifit.h>
#endif

using namespace std;

// Builds the regression dialog: a "Parameter" tab (region, model, weight,
// number of points, range, info/residual switches) and a "Style" tab for the
// graph that will hold the regression result.
RegressionListDialog::RegressionListDialog(MainWin *mw, const char *name)
	: ListDialog(mw, name)
{
	setCaption(i18n("Regression Dialog"));

	Plot *plot = p->getPlot(p->API());

	QTabWidget *tabs = new QTabWidget(vbox);
	QVBox *paramtab = new QVBox(tabs);

	// row 1 : optional region; pre-checked when the plot's region bounds differ
	QHBox *row = new QHBox(paramtab);
	regioncb = new QCheckBox(i18n("use Region "),row);
	regioncb->setChecked(plot->RegionMin() != plot->RegionMax());
	new QLabel(i18n("( From "),row);
	regionminle = new KLineEdit(QString::number(plot->RegionMin()),row);
	regionminle->setValidator(new QDoubleValidator(regionminle));
	new QLabel(i18n(" To "),row);
	regionmaxle = new KLineEdit(QString::number(plot->RegionMax()),row);
	regionmaxle->setValidator(new QDoubleValidator(regionmaxle));
	new QLabel(i18n(" )"),row);

	// row 2 : regression model (modelitems is a 0-terminated list)
	row = new QHBox(paramtab);
	new QLabel(i18n("Model : "),row);
	modelcb = new KComboBox(row);
	for (int m=0; modelitems[m] != 0; m++)
		modelcb->insertItem(i18n(modelitems[m]));
	modelcb->setCurrentItem(0);

	// row 3 : weight type (weightitems is a 0-terminated list)
	row = new QHBox(paramtab);
	new QLabel(i18n("Weight : "),row);
	weightcb = new KComboBox(row);
	for (int w=0; weightitems[w] != 0; w++)
		weightcb->insertItem(i18n(weightitems[w]));
	weightcb->setCurrentItem(0);
	QObject::connect(weightcb,SIGNAL(activated(int)),SLOT(weightChanged()));

	// row 4 : weight function; read-only until changed by weightChanged()
	row = new QHBox(paramtab);
	new QLabel(i18n("Weight Function : "),row);
	weightle = new KLineEdit(i18n("equal"),row);
	weightle->setReadOnly(true);

	// row 5 : number of points; defaults to 100, or to the size of the
	// first graph when that one is a 2d graph
	row = new QHBox(paramtab);
	new QLabel(i18n("Number of Points for regression function : "),row);
	GraphList *graphs = plot->getGraphList();
	int points = 100;
	if (graphs->getStruct(0) == GRAPH2D)
		points = graphs->getGraph2D(0)->Number();
	numberle = new KLineEdit(QString::number(points),row);
	numberle->setValidator(new QIntValidator(numberle));

	// row 6 : range of the regression function, preset from the first plot range
	row = new QHBox(paramtab);
	new QLabel(i18n("Range of regression function : "),row);
	LRange *plotranges = plot->Ranges();
	minle = new KLineEdit(QString::number(plotranges[0].rMin()),row);
	minle->setValidator(new QDoubleValidator(minle));
	new QLabel(i18n(" .. "),row);
	maxle = new KLineEdit(QString::number(plotranges[0].rMax()),row);
	maxle->setValidator(new QDoubleValidator(maxle));

	// row 7 : output options
	row = new QHBox(paramtab);
	infocb = new QCheckBox(i18n("Show Info"),row);
	infocb->setChecked(true);
	rescb = new QCheckBox(i18n("Show Residuals"),row);
	rescb->setChecked(false);

	// style tab : surface plots get the surface style page, all others the simple one
	Style *style=0;
	Symbol *symbol=0;
	QVBox *styletab;
	if(p->getPlot(p->API())->Type() == PSURFACE)
		styletab = surfaceStyle(tabs,true);
	else
		styletab = simpleStyle(tabs, style, symbol);

	tabs->addTab(paramtab,i18n("Parameter"));
	tabs->addTab(styletab,i18n("Style"));

	QObject::connect(ok,SIGNAL(clicked()),SLOT(ok_clicked()));
	QObject::connect(apply,SIGNAL(clicked()),SLOT(apply_clicked()));

	// start at the smallest size that fits the content
	setMinimumWidth(vbox->minimumSizeHint().width());
	setMinimumHeight(gbox->minimumSizeHint().height()+vbox->minimumSizeHint().height());
	resize(minimumSize());
}

// Slot connected to weightcb's activated(int): shows the name of the selected
// weight in the weight function field and makes that field editable only
// for the user-defined weight.
void RegressionListDialog::weightChanged() {
	kdDebug()<<"RegressionListDialog::weightChanged()"<<endl;

	const int selected = weightcb->currentItem();
	weightle->setText(i18n(weightitems[selected]));
	weightle->setReadOnly(selected != WUSER);
}

// Selects the user-defined weight and sets its expression to w.
// The weight function field is made editable explicitly, matching what
// weightChanged() does for WUSER: changing the combo box item from code is
// not expected to trigger the activated(int) signal that slot is connected to.
void RegressionListDialog::setWeightFunction(QString w) { 
	weightcb->setCurrentItem(WUSER); 
	weightle->setReadOnly(false);
	weightle->setText(w); 
}

#ifdef HAVE_GSL
namespace {

// Evaluates coef[0] + coef[1]*x + ... + coef[n-1]*x^(n-1).
double evalPolynomial(const std::vector<double> &coef, double x) {
	double y = 0;
	const int terms = static_cast<int>(coef.size());
	for (int j = 0; j < terms; j++)
		y += coef[j]*pow(x,j);
	return y;
}

// Weight of one data point (x, y, error e) for every weight type except the
// user-defined one. Reciprocal weights are 0 where the denominator vanishes;
// the error weight is 0 unless e > 1e-15. Unknown types yield 0.
double builtinWeight(int type, double x, double y, double e) {
	if (type == WEQUAL)
		return 1.0;
	if (type == WY)
		return y;
	if (type == WYY)
		return y*y;
	if (type == W1Y)
		return (y != 0) ? 1.0/y : 0.0;
	if (type == W1YY)
		return (y != 0) ? 1.0/(y*y) : 0.0;
	if (type == WX)
		return x;
	if (type == WXX)
		return x*x;
	if (type == W1X)
		return (x != 0) ? 1.0/x : 0.0;
	if (type == W1XX)
		return (x != 0) ? 1.0/(x*x) : 0.0;
	if (type == WERROR)
		return (e > 1.0e-15) ? 1.0/(e*e) : 0.0;
	return 0.0;
}

}	// namespace
#endif

// Fits the selected model (straight line or polynomial) to the unmasked
// points of the graph selected in the list, optionally restricted to the
// region, and adds the fitted curve (or the residuals) as a new 2d graph.
//
// Returns
//    0 : done (also returned when built without GSL, after an error box)
//   -2 : no graph available / no graph selected
//   -3 : the user-defined weight function could not be parsed
//   -4 : the requested number of points is not positive
//   -5 : fewer usable data points than fit parameters
int RegressionListDialog::apply_clicked() {
#ifdef HAVE_GSL
	Plot *plot = p->getPlot(p->API());
	GraphList *gl = plot->getGraphList();
	if(gl->Number()==0) {
		KMessageBox::error(this,i18n("No graph found!"));
		return -2;
	}

	// the index of the graph is derived from the position of the selected list item
	QListViewItem *current = lv->currentItem();
	if(current == 0 || current->height() <= 0) {
		KMessageBox::error(this,i18n("No graph found!"));
		return -2;
	}
	const int item = (int) (lv->itemPos(current)/current->height());

	const GRAPHType st = gl->getStruct(item);
	QString info;

	// 2d : x-y , 3d : x-y-dy , 4d : x-y-dx-dy
	// TODO : matrix, x-y-z, image
	if (st == GRAPH2D || st == GRAPH3D || st == GRAPH4D ) {
		const bool residuals = rescb->isChecked();
		const int samples = numberle->text().toInt();
		if(!residuals && samples < 1) {
			KMessageBox::error(this,i18n("The number of points must be positive!"));
			return -4;
		}

		// x, y and error value of every unmasked point, independent of the graph type
		Graph *graph = gl->getGraph(item);
		const int number = graph->Number();
		std::vector<double> xs, ys, errs;
		if(st == GRAPH2D) {
			Graph2D *g = gl->getGraph2D(item);
			for(int i=0;i<number;i++) {
				if(g->Data()[i].Masked())
					continue;
				xs.push_back(g->Data()[i].X());
				ys.push_back(g->Data()[i].Y());
				errs.push_back(1.0);	// no error column
			}
		}
		else if (st == GRAPH3D) {
			Graph3D *g = gl->getGraph3D(item);
			for(int i=0;i<number;i++) {
				if(g->Data()[i].Masked())
					continue;
				xs.push_back(g->Data()[i].X());
				ys.push_back(g->Data()[i].Y());
				errs.push_back(g->Data()[i].Z());
			}
		}
		else {	// GRAPH4D
			Graph4D *g = gl->getGraph4D(item);
			const bool g4type = g->GType();
			for(int i=0;i<number;i++) {
				if(g->Data()[i].Masked())
					continue;
				xs.push_back(g->Data()[i].X());
				ys.push_back(g->Data()[i].Y());
				if(g4type)	// x-y-dy1-dy2
					errs.push_back(g->Data()[i].Z()+g->Data()[i].T());
				else	// x-y-dx-dy
					errs.push_back(g->Data()[i].T());
			}
		}

		// keep only the points inside the region and compute their weights
		const int wtype = weightcb->currentItem();
		const bool useRegion = regioncb->isChecked();
		const double regionMin = regionminle->text().toDouble();
		const double regionMax = regionmaxle->text().toDouble();
		const QString userWeight = weightle->text();

		std::vector<double> fx, fy, fw;
		double wsum = 0;
		for (unsigned int i=0;i<xs.size();i++) {
			const double x = xs[i], y = ys[i];
			if(useRegion && !(x > regionMin && x < regionMax))
				continue;

			double w = 0;
			if(wtype == WUSER) {
				QString tmp = mw->parseExpression(userWeight, x, 23);  // "x"
				tmp = mw->parseExpression(tmp, y, 24);  // "y"

				w = parse(const_cast<char *>(tmp.latin1()));
				if(parse_errors()>0) {
					KMessageBox::error(mw, i18n("Parse Error!\n Please check the given weight function."));
					return -3;
				}

				if(!finite(w))
					w=0;
			}
			else
				w = builtinWeight(wtype, x, y, errs[i]);

			fx.push_back(x);
			fy.push_back(y);
			fw.push_back(w);
			wsum += w;
		}
		const int N = static_cast<int>(fx.size());

		// number of fit parameters : 2 for the line, item+2 for the polynomials
		const bool linear = (modelcb->currentItem() == MLINEAR);
		const int order = linear ? 2 : modelcb->currentItem()+2;	// item -> order
		if(N < order) {
			KMessageBox::error(this,i18n("Not enough data points for the regression!"));
			return -5;
		}

		QString fun;
		std::vector<double> coef(order, 0.0);
		if (linear) {
			double c0, c1, cov00, cov01, cov11, chisq;

			// dont normalize weighting here
			if (wtype == WEQUAL)
				gsl_fit_linear (&fx[0], 1, &fy[0], 1, N,&c0,&c1,&cov00,&cov01,&cov11,&chisq);
			else
				gsl_fit_wlinear (&fx[0], 1, &fw[0], 1, &fy[0], 1, N,&c0,&c1,&cov00,&cov01,&cov11,&chisq);
			coef[0] = c0;
			coef[1] = c1;

			// info
			info += i18n("best fit:")+" y = "+QString::number(c0)+" + "+QString::number(c1)+" x\n";
			info += i18n("covariance matrix:")+"\n";
			info += "  [ " +QString::number(cov00)+", "+QString::number(cov01)+"\n";

			info += "    " +QString::number(cov01)+", "+QString::number(cov11)+" ]\n";
			info += "chi^2 = " + QString::number(chisq);

			fun = "y = "+QString::number(c0)+" + "+QString::number(c1)+" x";
		}
		else {	// polynomial
			gsl_matrix *X = gsl_matrix_alloc (N, order);
			gsl_vector *yy = gsl_vector_alloc (N);
			gsl_vector *w = gsl_vector_alloc (N);
			gsl_vector *c = gsl_vector_alloc (order);
			gsl_matrix *cov = gsl_matrix_alloc (order, order);

			for (int i = 0; i < N; i++) {
				for (int j=0;j<order;j++)
					gsl_matrix_set(X,i,j,pow(fx[i],j));

				gsl_vector_set (yy, i, fy[i]);
				// weights are normalized to sum 1 (skipped if the sum vanishes)
				gsl_vector_set (w, i, (wsum != 0) ? fw[i]/wsum : fw[i]);
			}

			double chisq;
			gsl_multifit_linear_workspace * work = gsl_multifit_linear_alloc (N, order);
			if (wtype == WEQUAL)
				gsl_multifit_linear (X, yy, c, cov,&chisq, work);
			else
				gsl_multifit_wlinear (X, w, yy, c, cov,&chisq, work);
			gsl_multifit_linear_free(work);

			for (int j=0;j<order;j++)
				coef[j] = gsl_vector_get(c,j);

			info += i18n("best fit:")+" y = "+QString::number(coef[0])+" + "+QString::number(coef[1])+" x ";
			for (int i=2;i<order;i++)
				info += " + "+ QString::number(coef[i]) + " x^" + QString::number(i);
			info += "\n"+i18n("covariance matrix:")+"\n";
			for (int j=0;j<order;j++) {
				for (int i=0;i<order;i++)
					info += QString::number(gsl_matrix_get(cov,j,i))+",";
				info +="\n";
			}
			info += "chi^2 = " + QString::number(chisq);

			gsl_matrix_free(X);
			gsl_vector_free(yy);
			gsl_vector_free(w);
			gsl_vector_free(c);
			gsl_matrix_free(cov);

			fun = i18n("%1th order regression").arg(QString::number(order-1));
		}

		// points of the new graph. The buffer is allocated only now, with its
		// final size, so that none of the error paths above can leak it.
		const int count = residuals ? N : samples;
		Point *ptr = new Point[count];
		if(residuals) {
			// residuals are taken at the abscissae of the fitted data points
			for (int i = 0;i<count;i++)
				ptr[i].setPoint(fx[i], fy[i]-evalPolynomial(coef, fx[i]));
		}
		else {
			const double rangemin=minle->text().toDouble();
			const double rangemax=maxle->text().toDouble();
			// the polynomial models distribute the points according to the axis scale
			Axis *axis = linear ? 0 : plot->getAxis(0);

			for (int i = 0;i<count;i++) {
				double x = rangemin;	// single point : avoid 0/0 in the formulas below
				if(count > 1) {
					if(linear)
						x = rangemin+i*(rangemax-rangemin)/(double)(count-1);
					else {
						switch(axis->Scale()) {
						case LOG10:
							x = rangemin+pow((double)i,10)*(rangemax-rangemin)/pow((double)(count-1),10);
							break;
						case LOG2:
							x = rangemin+pow((double)i,2)*(rangemax-rangemin)/pow((double)(count-1),2);
							break;
						case LN:
							x = rangemin+pow((double)i,M_E)*(rangemax-rangemin)/pow((double)(count-1),M_E);
							break;
						case LINEAR:
						case SQRT:
						case SX2:
						default:
							x = rangemin+i*(rangemax-rangemin)/(double)(count-1);
							break;
						}
					}
				}

				ptr[i].setPoint(x, evalPolynomial(coef, x));
			}
		}

		if(residuals)
			fun.prepend(i18n("residuals of "));

		double xmin=0, xmax=1, ymin=0, ymax=1;
		mw->calculateRanges2D(ptr,count,&xmin,&xmax,&ymin,&ymax);

		// create the new graph
		LRange range[2];
		range[0] = LRange(xmin,xmax);
		range[1] = LRange(ymin,ymax);

		Style *style = new Style(cb2->currentItem(),color->color(),filled->isChecked(),fcolor->color(),
			widthle->text().toInt(),pencb->currentItem(),brushcb->currentItem());
		Symbol *symbol = new Symbol((SType)symbolcb->currentItem(),scolor->color(),ssize->text().toInt(),
			(FType)symbolfillcb->currentItem(),sfcolor->color(),sbrushcb->currentItem());

		// NOTE(review): ptr, style and symbol are handed to the graph and not
		// released here - presumably Graph2D keeps them; confirm in Graph2D.
		Graph2D *ng = new Graph2D(fun,fun,range,SSPREADSHEET,P2D,style,symbol,ptr,count);
		mw->addGraph2D(ng,sheetcb->currentItem());
	}

	updateList();

	result=info;
	if (infocb->isChecked())
		KMessageBox::information(0,info);
#else
		KMessageBox::error(this, i18n("Sorry. Your installation doesn't support the GSL!"));
#endif

	return 0;
}
