

//////////////////////////////////////////////////////////////////
//                                                              //
//           PLINK (c) 2005-2008 Shaun Purcell                  //
//                                                              //
// This file is distributed under the GNU General Public        //
// License, Version 2.  Please see the file COPYING for more    //
// details                                                      //
//                                                              //
//////////////////////////////////////////////////////////////////


#include <iostream>
#include <algorithm>
#include "plink.h"
#include "sets.h"
#include "options.h"
#include "helper.h"
#include "model.h"

extern Plink * PP;

// Wrap the given list of SNP sets (each inner vector lists SNP indices)
// and size all of the per-set working arrays to match it.
Set::Set(vector<vector<int> > & ss)
  : snpset(ss)
{
  this->sizeSets();
}

// Size the per-set working arrays from snpset.
//
// cur[s] gets one flag per SNP of set s; entries added by the resize
// are initialised to true. For set-based association/TDT tests (when
// par::hotel is not set) it also sizes, for each set i:
//   s_min[i], s_max[i] - the 0-based range [s_min, s_max) of "top-k"
//                        cumulative statistics that are kept
//   stat_set[i], pv_set[i], pv_maxG_set[i], pv_maxE_set[i]
//                      - one entry per kept statistic; stat_set entries
//                        (unless par::set_score) and pv_set entries hold
//                        par::replicates+1 values (index 0 = original)
void Set::sizeSets()
{

  cur.resize(snpset.size());
  for(int s=0;s<snpset.size();s++)
    cur[s].resize(snpset[s].size(),true);

  // Specific to SET-based tests

  if ( (par::assoc_test || par::TDT_test ) && par::set_test 
       && !par::hotel)
    {
      s_min.resize(snpset.size());
      s_max.resize(snpset.size());
      stat_set.resize(snpset.size());
      pv_set.resize(snpset.size());
      pv_maxG_set.resize(snpset.size());
      pv_maxE_set.resize(snpset.size());

      for(int i=0;i<snpset.size();i++)
        {

          // If no constraints given, then the number of tests ==
          // size of set. par::set_min is a 1-based count, so it is
          // shifted down by one to give a 0-based start index.

          if (par::set_min==-1 ) s_min[i] = 0;
          else if (par::set_min > snpset[i].size() )
            s_min[i] = snpset[i].size();
          else s_min[i] = par::set_min-1;

          if (par::set_max==-1 || par::set_max > snpset[i].size() ) 
            s_max[i] = snpset[i].size();
          else s_max[i] = par::set_max;

          // Compare this set's own bounds (not the whole vectors), so
          // that the range size below can never become negative
          if (s_min[i] > s_max[i]) s_min[i] = s_max[i];

          int s = (s_max[i] - s_min[i]);

          stat_set[i].resize(s); 
          pv_set[i].resize(s);
          pv_maxG_set[i].resize(s);
          pv_maxE_set[i].resize(s);

          // Per-replicate statistic storage is not needed in
          // par::set_score mode
          if ( ! par::set_score ) 
            {
              for (int j=0; j<s; j++)
                stat_set[i][j].resize(par::replicates+1);
            }

          // pv_set entries are sized regardless of par::set_score
          for (int j=0; j<s; j++)
            pv_set[i][j].resize(par::replicates+1);

        }
    }

}


//////////////////////////////////////////////////////
//                                                  //
//    Remove 0-sized sets                           //
//                                                  //
//////////////////////////////////////////////////////

// Drop every set that has no SNPs, keeping P.setname index-aligned with
// snpset, log how many were dropped, then re-size the per-set arrays.
void Set::pruneSets(Plink & P)
{
  int nRemoved = 0;

  // Walk from the back: erasing position k never moves any element
  // that has not been visited yet.
  for (int k = static_cast<int>(snpset.size()) - 1; k >= 0; --k)
    {
      if ( ! snpset[k].empty() )
        continue;

      snpset.erase( snpset.begin() + k );
      P.setname.erase( P.setname.begin() + k );
      ++nRemoved;
    }

  P.printLOG(int2str(nRemoved)+" sets removed (0 valid SNPs)\n");

  // Resize all the other set arrays
  sizeSets();
}


//////////////////////////////////////////////////////
//                                                  //
//    Prune sets based on multi-collinearity        //
//                                                  //
//////////////////////////////////////////////////////

// For each set, compute the covariance matrix of its SNPs and apply
// variance-inflation-factor pruning with threshold VIF_threshold. Every
// SNP rejected by vif_prune() has its cur[s][j] flag cleared; flags are
// only ever cleared here, never set.
//
// If disp is true, write the SNPs still flagged in each set to
// <output>.set.in and the cleared ones to <output>.set.out. Both files
// use the NAME / SNP lines / END layout.
void Set::pruneMC(Plink & P, bool disp,double VIF_threshold)
{

  P.printLOG("Pruning sets based on variance inflation factor\n");

  for (int s=0; s<snpset.size(); s++)
    {

      if (!par::silent)
        cout << s+1 << " of " << snpset.size() << " sets pruned           \r";

      int nss = snpset[s].size();

      // Local copy of this set's SNP indices for the covariance routine
      vector<int> nSNP( snpset[s] );

      // Covariance matrix of the set's SNPs (full sample);
      // this routine uses the 'flag' variable
      vector<vector<double> > var = calcSetCovarianceMatrix(nSNP);

      // Perform VIF pruning: p[i] == false means SNP i is pruned out,
      // which is recorded by clearing cur[s][i]
      vector<bool> p = vif_prune(var,VIF_threshold);

      for (int i=0; i<nss; i++)
        if (!p[i]) cur[s][i]=false;

    }

  if (!par::silent)
    cout << "\n";

  if (disp)
    {
      ofstream SET1, SET2;
      string f = par::output_file_name + ".set.in";

      P.printLOG("Writing pruned-in set file to [ " + f + " ]\n");
      SET1.open(f.c_str(),ios::out);

      f = par::output_file_name + ".set.out";
      P.printLOG("Writing pruned-out set file to [ " + f + " ]\n");
      SET2.open(f.c_str(),ios::out);

      for (int s=0; s<snpset.size(); s++)
        {

          int nss = snpset[s].size();

          SET1 << P.setname[s] << "\n";
          SET2 << P.setname[s] << "\n";

          // Each SNP goes to exactly one of the two files
          for (int j=0; j<nss; j++)
            {
              if (cur[s][j])
                SET1 << P.locus[snpset[s][j]]->name << "\n";
              else
                SET2 << P.locus[snpset[s][j]]->name << "\n";
            }

          SET1 << "END\n\n";
          SET2 << "END\n\n";
        }

      SET1.close();
      SET2.close();
    }
}


//////////////////////////////////////////////////////
//                                                  //
//    Remove SNPs not in any set                    //
//                                                  //
//////////////////////////////////////////////////////

// Delete from P every SNP that is not a member of at least one set, then
// re-read the set file because SNP indices change after the deletion.
void Set::dropNotSet(Plink & P)
{
  // Start with every SNP marked for removal, then rescue set members
  vector<bool> drop(P.nl_all,true);

  for (vector<vector<int> >::const_iterator grp = snpset.begin();
       grp != snpset.end(); ++grp)
    for (vector<int>::const_iterator snp = grp->begin();
         snp != grp->end(); ++snp)
      drop[*snp] = false;

  P.deleteSNPs(drop);

  // SNP numbering has now changed, so the sets must be rebuilt
  P.readSet();
}


//////////////////////////////////////////////////////
//                                                  //
//    Create map of SNP number to set codes         //
//                                                  //
//////////////////////////////////////////////////////

// Rebuild setMapping: SNP index -> indices of every set containing it.
void Set::initialiseSetMapping()
{
  setMapping.clear();

  for (int grp = 0; grp < snpset.size(); grp++)
    for (int k = 0; k < snpset[grp].size(); k++)
      {
        // operator[] creates an empty entry the first time a SNP is
        // seen, so first and later sightings are handled alike
        setMapping[ snpset[grp][k] ].insert(grp);
      }
}


//////////////////////////////////////////////////////
//                                                  //
//    Sum-statistic scoring (original)              //
//                                                  //
//////////////////////////////////////////////////////

// Original-data pass of the sum-statistic set test.
//
// For each set, rank its SNPs by the statistic in `original` (largest
// first, after sorting with SETSORT's ordering). The mean of the top
// k+1 values is stored in stat_set[i][k-s_min[i]][0] for every k in
// [s_min[i], s_max[i]). The names of the SNPs contributing at those
// positions are appended to setsort as one vector per set.
void Set::cumulativeSetSum_WITHLABELS(Plink & P, vector<double> & original)
{
  for (int i=0; i<snpset.size(); i++)
    {
      // Gather (statistic, SNP name) pairs for this set
      vector<SETSORT> ranked;
      for (int j=0; j<snpset[i].size(); j++)
        {
          int snp = snpset[i][j];
          SETSORT entry;
          entry.chisq = original[snp];
          entry.name = P.locus[snp]->name;
          ranked.push_back(entry);
        }

      // Ascending sort; the loop below reads it from the back
      sort(ranked.begin(), ranked.end());

      double runningSum = 0;
      vector<string> labels;
      int n = ranked.size();

      for (int k = 0; k < n; k++)
        {
          const SETSORT & top = ranked[n-1-k];
          runningSum += top.chisq;

          // Only positions inside [s_min, s_max) are recorded
          if ( k < s_min[i] || k >= s_max[i] )
            continue;

          stat_set[i][k-s_min[i]][0] = runningSum/(double)(k+1);
          labels.push_back(top.name);
        }

      // And save
      setsort.push_back(labels);
    }
}


//////////////////////////////////////////////////////
//                                                  //
//    Sum-statistic scoring (permutation)           //
//                                                  //
//////////////////////////////////////////////////////

// Permutation pass of the sum-statistic set test.
//
// For each set, sort its SNPs' statistics from `perm`. For every k in
// [s_min[i], s_max[i]), store the mean of the k+1 largest values in
// stat_set[i][k-s_min[i]][p], where p is the replicate index.
void Set::cumulativeSetSum_WITHOUTLABELS(vector<double> & perm, int p)
{
  vector<double> chisq;   // buffer reused across sets

  for (int i=0; i<snpset.size(); i++)
    {
      chisq.clear();

      const vector<int> & members = snpset[i];
      for (int j=0; j<members.size(); j++)
        chisq.push_back(perm[members[j]]);

      // Ascending; walk it from the largest value downwards
      sort(chisq.begin(), chisq.end());

      double total = 0;
      int last = static_cast<int>(chisq.size()) - 1;

      for (int k=0; k<s_max[i]; k++)
        {
          total += chisq[last-k];
          if (k >= s_min[i])
            stat_set[i][k-s_min[i]][p] = total/(double)(k+1);
        }
    }
}


//////////////////////////////////////////////////////
//                                                  //
//    Sum-statistic empirical p-value calculation   //
//                                                  //
//////////////////////////////////////////////////////

void Set::empiricalSetPValues()
{
  
  int R = par::replicates;


  //////////////////////////////////////////////////
  // Basic p-values, for original and each replicate

  // For the j'th SNP of the i'th SET, calculate how many times
  // the other permutations exceed it (permutations 0 to R, where
  // 0 is the original result)
  
  for (int p0=0;p0<=R;p0++)   // index 
    for (int p1=0;p1<=R;p1++) // all other perms (including self)
      for (int i=0;i<stat_set.size();i++) 
	for (int j=0;j<stat_set[i].size();j++)
	  if (stat_set[i][j][p1] >= stat_set[i][j][p0] ) pv_set[i][j][p0]++; 
  
  // Find best p-values per rep (overall, per set)

  for (int p=0;p<=R;p++)
    {
      
      double maxE_set = 1;
      vector<double> maxG_set(pv_set.size(),1);
      
      // Consider each score
      for (int i=0;i<pv_set.size();i++) 
	for (int j=0;j<pv_set[i].size();j++)
	  {
	    // Make into p-value (will include self: i.e. N+1)
	    pv_set[i][j][p] /= R+1;
	    
	    if (pv_set[i][j][p] < maxG_set[i]) maxG_set[i] = pv_set[i][j][p];
	    if (pv_set[i][j][p] < maxE_set) maxE_set = pv_set[i][j][p];
	  }
      
      // Score max values
      for (int i=0;i<pv_set.size();i++) 
	for (int j=0;j<pv_set[i].size();j++)
	  {
	    if (maxG_set[i] <= pv_set[i][j][0]) pv_maxG_set[i][j]++;
	    if (maxE_set    <= pv_set[i][j][0]) pv_maxE_set[i][j]++;
	  }
    }
  
}




////////////////////////////////////////////////////////////////////
//                                                                //
//   Score-profile based test                                     //
//                                                                //
////////////////////////////////////////////////////////////////////

void Set::profileTestSNPInformation(int l, double odds)
{

  // If we are passed a SNP here, it is because it significant
  // at the specified par::set_score_p threshold
  
  // We have to ask: does this SNP belong to one or more sets?  
  // If so, store the SNP number and odds ratio, for each set (i.e. 
  // build up a profile to score; we do not need to save allele, as 
  // it is always with reference to the minor one
  
  map<int, set<int> >::iterator si = setMapping.find(l);

  if ( si == setMapping.end() ) 
    {
      return;
    }

  set<int>::iterator li = si->second.begin();
  
  while ( li != si->second.end() )
    {
      profileSNPs[ *li ].push_back( l );
      profileScore[ *li ].push_back( odds );
      ++li;
    }

}




// Score-profile set test, run once per replicate.
//
// For each set, build a per-individual score from the SNPs recorded by
// profileTestSNPInformation(). The mean score is placed in
// person->clist[0], the phenotype is regressed on it via
// PP->glmAssoc(), and the model's test statistic is collected.
//
// Returns one value per set:
//   0  if no SNPs were recorded for the set,
//   -1 if the fitted model is not valid,
//   the model statistic otherwise.
//
// Side effects:
//   - s_min[i] is overwritten with the number of recorded SNPs for set i.
//   - person->missing is temporarily changed and then restored from
//     person->flag after each fit.
//   - PP->model is deleted after each fit and set to null.
//   - The recorded profiles are cleared before returning.
vector_t Set::profileTestScore()
{

  vector_t results;

  for (int i=0; i<snpset.size(); i++)
    {

      vector_t profile;
      vector<int> count;

      map<int,double> scores;
      map<int,bool> allele1;

      // One score per SNP (map::insert keeps the first entry if a SNP
      // repeats); the allele flag is false for every SNP
      for (int j=0; j<profileSNPs[i].size(); j++)
        {
          scores.insert(make_pair( profileSNPs[i][j], profileScore[i][j] ));
          allele1.insert(make_pair( profileSNPs[i][j], false ));
        }

      // Record set size (# significant SNPs; use set_min to store this)
      s_min[i] = profileSNPs[i].size();

      ///////////////////////////////
      // Any significant SNPs?

      if ( scores.size() == 0 ) 
        {
          // Record a null score
          results.push_back( 0 );
          continue;
        }

      ////////////////////////////////
      // Calculate actual profile

      PP->calculateProfile(scores, allele1, profile, count);

      ///////////////////////////////////////////////
      // Save as the covariate the mean score (i.e. averaged over the
      // number of seen SNPs). Individuals with no scored SNPs, or
      // flagged as originally missing, are excluded from this fit.

      for (int k=0; k < PP->n; k++)
        {
          Individual * person = PP->sample[k];

          if ( count[k] == 0 || person->flag ) 
            person->missing = true;
          else
            {
              person->clist[0] = profile[k] / (double)count[k];
              person->missing = false;
            }
        }

      ////////////////////////////////
      // Regress phenotype on profile

      PP->glmAssoc(false,*PP->pperm);

      //////////////////////////////////////////////
      // Reset original missing status. A distinct name is used for the
      // iterator: redeclaring the for-loop variable 'i' in the outermost
      // block of the loop body is ill-formed C++.

      vector<Individual*>::iterator pi = PP->sample.begin();
      while ( pi != PP->sample.end() )
        {
          (*pi)->missing = (*pi)->flag;
          ++pi;
        }

      ////////////////////////////////////////////////
      // Save test statistic for permutation purposes

      double statistic = PP->model->getStatistic();

      PP->model->validParameters();

      if ( ! PP->model->isValid() ) 
        statistic = -1;

      results.push_back( statistic );

      ////////////////////////////////////////////////
      // Clear up GLM model; null the pointer so it does not dangle

      delete PP->model;
      PP->model = 0;

    }

  // Finally, important to clear the profile scores now,
  // so that the next permutation starts from scratch

  profileSNPs.clear();
  profileScore.clear();

  profileSNPs.resize( snpset.size() );
  profileScore.resize( snpset.size() );

  return results;
}
 

// One-time setup for the score-profile set test.
//
// Builds the SNP -> set lookup and empties the per-set profiles. It
// configures the GLM to run without a main SNP term and with a single
// pseudo-covariate "PROFILE". User-specified covariates are rejected.
// Each individual's original missing status is saved in person->flag.
void Set::profileTestInitialise()
{
  PP->printLOG("Initalising profile-based set test\n");

  // Which set(s) a given SNP belongs to
  initialiseSetMapping();

  // Start every set with an empty profile
  profileSNPs.clear();
  profileScore.clear();
  profileSNPs.resize( snpset.size() );
  profileScore.resize( snpset.size() );

  // GLM set-up: the profile is fitted without a main SNP effect
  par::assoc_glm_without_main_snp = true;

  if ( PP->clistname.size() != 0 )
    error("Cannot specify covariates with --set-score");

  // Remember the original missing status in 'flag'
  for (int k = 0; k < PP->sample.size(); k++)
    PP->sample[k]->flag = PP->sample[k]->missing;

  // Register a single pretend covariate to carry the profile score
  par::clist = true;
  par::clist_number = 1;

  PP->clistname.resize(1);
  PP->clistname[0] = "PROFILE";

  for (int k = 0; k < PP->n; k++)
    PP->sample[k]->clist.resize(1);
}
