# -------------------------------------------------------------------------
#     Copyright (C) 2005-2011 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program
# -------------------------------------------------------------------------

# load libs
import math
import numpy
from numpy.linalg import solve as solveLinEq

# load stopper
from stopper import CHECK_FORCE_QUIT

# register essential objects and modules
import objects
import processing


# ENVELOPE FIT
# ------------

class envelopeFit():
    """Fit modeled profiles with exchanged atoms to acquired data.
    
    For every exchange count x within the (validated) scale, a compound is
    built from the formula string "<formula>(<negated loss>)<x>(<gain>)<x>".
    Its modeled profile is one column of a non-negative least-squares fit
    against the data.
    
    Fit results:
        composition: {x: fitted abundance}
        ncomposition: {x: abundance normalized to the sum of all abundances}
        average: abundance-weighted mean exchange count
        std: abundance-weighted standard deviation of the exchange count
        model: (numpy array) m/z - fitted intensity pairs
        error: None on success, 1 if there is no usable data in the relevant
            m/z range (no points, or no modeled profile overlaps the points)
    """
    
    def __init__(self, formula, charge, scale, loss='H', gain='H{2}'):
        """Set up the fit.
            formula: (str) base formula
            charge: (int) charge used for m/z calculation
            scale: (tuple of two ints) min and max exchange count, inclusive
            loss: (str) group removed per exchange
            gain: (str) group added per exchange
        """
        
        # error flag and fit results
        self._clearResults()
        
        # the loss group is stored negated so it can simply be appended to the formula
        lossCompound = objects.compound(loss)
        lossCompound.negate()
        self._lossFormula = lossCompound.formula()
        self._gainFormula = gain
        
        self.formula = formula
        self.charge = charge
        self.scale = self._validateScale(scale)
        self.mzRange = self._calcMZRange()
        
        self.fwhm = 0.1
        self.spectrum = []  # baseline-corrected spectrum (set by fitToSpectrum)
        self.data = []  # m/z - intensity points actually fitted
    # ----
    
    
    def fitToSpectrum(self, points, fwhm=0.1, forceFwhm=True, autoAlign=True, iterLimit=None, pickingHeight=0.85, relThreshold=0., baselineWindow=0.1, baselineSmooth=True, baselineOffset=0.25, baselineData=None):
        """Fit modeled profiles to spectrum using tmp peaklist.
            points: (numpy array) m/z - intensity pairs or raw spectrum
            fwhm: (float) default fwhm
            forceFwhm: (bool) use default fwhm
            autoAlign: (bool) automatic m/z shift
            iterLimit: (int) maximum number of iterations
            pickingHeight: (float) peak picking height for centroiding
            relThreshold: (float) relative intensity threshold
            baselineWindow: (float) noise calculation window in %/100
            baselineSmooth: (bool) smooth baseline
            baselineOffset: (float) baseline intensity offset in %/100
            baselineData: (list of [x, noiseLevel, noiseWidth]) precalculated baseline
        """
        
        # get spectrum baseline
        # ("is None": a precalculated numpy baseline must not be compared elementwise)
        if baselineData is None:
            baselineData = processing.baseline(
                points = points,
                window = baselineWindow,
                smooth = baselineSmooth,
                offset = baselineOffset
            )
        
        # crop points to relevant m/z range
        i1 = processing.findIndex(points, self.mzRange[0], dim=2)
        i2 = processing.findIndex(points, self.mzRange[1], dim=2)
        points = points[i1:i2]
        
        # get peaklist from spectrum points
        peaklist = processing.labelScan(
            points = points,
            pickingHeight = pickingHeight,
            relThreshold = relThreshold,
            baselineData = baselineData
        )
        
        # correct spectrum baseline
        self.spectrum = processing.correctBaseline(points, baselineData=baselineData)
        
        # fit to peaklist
        self.fitToPeaklist(
            peaklist = peaklist,
            fwhm = fwhm,
            forceFwhm = forceFwhm,
            autoAlign = autoAlign,
            iterLimit = iterLimit
        )
    # ----
    
    
    def fitToPeaklist(self, peaklist, fwhm=0.1, forceFwhm=True, autoAlign=True, iterLimit=None):
        """Fit modeled profiles to peaklist.
            peaklist: (mspy.peaklist) peak list
            fwhm: (float) default fwhm
            forceFwhm: (bool) use default fwhm
            autoAlign: (bool) automatic m/z shift
            iterLimit: (int) maximum number of iterations
        """
        
        # check peaklist object
        if not isinstance(peaklist, objects.peaklist):
            peaklist = objects.peaklist(peaklist)
        
        # crop peaklist to relevant m/z range
        low, high = self.mzRange
        peaklist = objects.peaklist([peak for peak in peaklist if low <= peak.mz <= high])
        
        # get fwhm from basepeak
        if not forceFwhm and peaklist.basepeak and peaklist.basepeak.fwhm:
            fwhm = peaklist.basepeak.fwhm
        
        # fit to points
        self.data = numpy.array([(p.mz, p.intensity) for p in peaklist])
        if not autoAlign:
            self.fitToPoints(self.data, fwhm=fwhm, iterLimit=iterLimit)
        else:
            # a quick preliminary fit provides the composition used to
            # re-calibrate the data, then the final fit runs on aligned data
            # ("//" keeps the limit an int on both Python 2 and 3)
            self.fitToPoints(self.data, fwhm=fwhm, iterLimit=self.scale[1] // 2)
            self._alignData()
            self.fitToPoints(self.data, fwhm=fwhm, iterLimit=iterLimit)
    # ----
    
    
    def fitToPoints(self, points, fwhm=0.1, iterLimit=None):
        """Fit modeled profiles to given points.
            points: (numpy array or list) m/z - intensity pairs
            fwhm: (float) default fwhm
            iterLimit: (int) maximum number of iterations
        On failure self.error is set to 1 and all results stay cleared.
        """
        
        self.fwhm = fwhm
        
        # forget results (and error) of any previous fit
        self._clearResults()
        
        # crop points to relevant m/z range
        i1 = processing.findIndex(points, self.mzRange[0], dim=2)
        i2 = processing.findIndex(points, self.mzRange[1], dim=2)
        self.data = numpy.array(points[i1:i2])
        
        # check data
        if len(self.data) == 0:
            self.error = 1
            return
        
        # split data to raster and intensities (float copies)
        raster = self.data[:, 0].astype(float)
        intensities = self.data[:, 1].astype(float)
        
        # model profiles
        models, exchanged = self._makeModels(raster)
        if len(models) == 0:
            # no modeled profile overlaps the data -> nothing to fit
            self.error = 1
            return
        
        # fit data to model
        self.params = self._leastSquare(intensities, models, iterLimit=iterLimit)
        
        # calc compositions (all-zero fit leaves normalized values at zero)
        total = sum(self.params)
        f = 1./total if total else 0.
        self.composition = dict(zip(exchanged, self.params))
        self.ncomposition = dict((x, f*value) for x, value in zip(exchanged, self.params))
        
        # calc average exchange
        self.average = sum(x * w for x, w in self.ncomposition.items())
        
        # calc standard deviation
        err = sum(w * (x - self.average)**2 for x, w in self.ncomposition.items())
        self.std = math.sqrt(err)
        
        # calc fitted points
        fitted = numpy.sum(models * [[x] for x in self.params], axis=0)
        self.model = numpy.column_stack((raster, fitted))
    # ----
    
    
    def envelope(self, points=5):
        """Return calculated envelope for current composition.
            points: (int) passed to processing.profile as its points argument
        Returns [] if no fit has been made yet.
        """
        
        # check composition
        if self.composition is None:
            return []
        
        # get isotopes for all profiles, scaled by fitted abundance
        isotopes = []
        for x, abundance in self.composition.items():
            pattern = self._exchangedCompound(x).pattern(fwhm=self.fwhm, charge=self.charge)
            isotopes += [(p[0], p[1]*abundance) for p in pattern]
        
        # make profile from isotopes
        return processing.profile(isotopes, fwhm=self.fwhm, points=points)
    # ----
    
    
    def _clearResults(self):
        """Reset error flag and all fit results."""
        
        self.error = None # 1 - no data in relevant mass range
        self.model = []
        self.params = None
        self.composition = None
        self.ncomposition = None
        self.average = None
        self.std = None
    # ----
    
    
    def _exchangedCompound(self, count):
        """Return compound for the formula with `count` loss/gain exchanges."""
        
        item = "%s(%s)%d(%s)%d" % (self.formula, self._lossFormula, count, self._gainFormula, count)
        return objects.compound(item)
    # ----
    
    
    def _validateScale(self, scale):
        """Check if compounds are valid for given scale.
        Returns scale truncated just before the first exchange count whose
        compound is reported invalid for the current charge.
        NOTE(review): if scale[0] itself is invalid, the returned range is empty
        (scale[0], scale[0]-1).
        """
        
        for x in range(scale[0], scale[1]+1):
            if not self._exchangedCompound(x).isvalid(charge=self.charge):
                return (scale[0], x-1)
        
        return scale
    # ----
    
    
    def _calcMZRange(self):
        """Calculate relevant m/z range for current formula and scale.
        Spans from the first pattern peak of the lowest exchange to the last
        pattern peak of the highest exchange, widened by 0.1 % on each side.
        """
        
        # set additional space
        space = .001
        
        # get m/z range
        minX = self._exchangedCompound(self.scale[0]).pattern(fwhm=0.5, charge=self.charge)[0][0]
        maxX = self._exchangedCompound(self.scale[1]).pattern(fwhm=0.5, charge=self.charge)[-1][0]
        
        return (minX - minX*space, maxX + maxX*space)
    # ----
    
    
    def _makeModels(self, raster):
        """Make individual profiles to fit.
            raster: (1-D numpy array) m/z values of the data, ascending
        Returns (models, exchanged): a 2-D array with one non-zero profile per
        row, and the exchange count of each row.
        """
        
        # get peak width: exp(-d**2/width**2) has FWHM = 2*sqrt(ln 2)*width ~ 1.66*width
        width = self.fwhm/1.66
        
        # get raster
        rasterMin = raster[0] - self.fwhm
        rasterMax = raster[-1] + self.fwhm
        
        # calculate profiles
        models = []
        exchanged = []
        for x in range(self.scale[0], self.scale[1]+1):
            
            CHECK_FORCE_QUIT()
            
            # generate new formula with H/D exchange
            compound = self._exchangedCompound(x)
            mz = compound.mz(self.charge)
            
            # skip compounds outside current m/z raster
            if not (mz[0] < rasterMax and mz[1] > rasterMin):
                continue
            
            # calculate isotopic pattern
            pattern = compound.pattern(fwhm=self.fwhm, charge=self.charge)
            
            # calculate intensities by adding peak gaussians (within +-5 widths)
            model = numpy.zeros(raster.size, float)
            for peak in pattern:
                i1 = processing.findIndex(raster, (peak[0]-5*width), dim=1)
                i2 = processing.findIndex(raster, (peak[0]+5*width), dim=1)
                model[i1:i2] += peak[1] * numpy.exp(-((raster[i1:i2]-peak[0])**2) / width**2)
            
            # store data
            if model.any():
                models.append(model)
                exchanged.append(x)
        
        # convert profiles to matrix
        return numpy.array(models), exchanged
    # ----
    
    
    @staticmethod
    def _matchCalibrants(theoretical, data, tolerance):
        """Pair each theoretical m/z with its most intense data point in tolerance.
            theoretical: (list of float) theoretical peak m/z values
            data: (sequence of (m/z, intensity)) measured points, ascending m/z
            tolerance: (float) maximum absolute m/z difference
        Returns list of (measured m/z, theoretical m/z); peaks without a
        matching point are skipped. On equal intensity the first point wins.
        """
        
        calibrants = []
        for peakMZ in theoretical:
            best = None
            for point in data:
                delta = point[0] - peakMZ
                
                # data are sorted, nothing further can match
                if delta > tolerance:
                    break
                
                if abs(delta) <= tolerance and (best is None or point[1] > best[1]):
                    best = (point[0], point[1])
            
            if best is not None:
                calibrants.append((best[0], peakMZ))
        
        return calibrants
    # ----
    
    
    def _alignData(self):
        """Re-calibrate data using theoretical envelope.
        Updates the m/z values of self.data and self.spectrum in place.
        Does nothing without a composition or with fewer than 2 calibrants.
        """
        
        # check composition
        if not self.composition:
            return
        
        # make theoretical profile and label its peaks
        profile = self.envelope(points=5)
        peaklist = processing.labelScan(profile, pickingHeight=0.95, relThreshold=0.01)
        
        # find peaks within tolerance (one calibrant per theoretical peak)
        calibrants = self._matchCalibrants([peak.mz for peak in peaklist], self.data, self.fwhm/1.5)
        
        # calc calibration
        if len(calibrants) > 3:
            model, params, chi = processing.calibration(calibrants, model='quadratic')
        elif len(calibrants) > 1:
            model, params, chi = processing.calibration(calibrants, model='linear')
        else:
            return
        
        # apply calibration to data
        for i in range(len(self.data)):
            self.data[i][0] = model(params, self.data[i][0])
        for i in range(len(self.spectrum)):
            self.spectrum[i][0] = model(params, self.spectrum[i][0])
    # ----
    
    
    def _leastSquare(self, data, models, iterLimit=None, chiLimit=1e-3):
        """Least-square fitting. Adapted from the original code by Konrad Hinsen.
        Levenberg-Marquardt style damped fit of non-negative weights.
            data: (1-D numpy array) intensities to fit (not modified)
            models: (2-D numpy array) one modeled profile per row
            iterLimit: (int) maximum number of iterations, falsy for none
            chiLimit: (float) stop when chi-square improves by less than this
        Returns numpy array of fitted weights, one per row of models, in the
        intensity units of data. The best accepted parameters are returned.
        """
        
        # normalize data to 0-100 so starting values and chiLimit are scale independent
        maximum = numpy.max(data)
        normf = 100./maximum if maximum > 0 else 1.
        data = data * normf
        
        params = numpy.array([50.] * len(models))
        identity = numpy.identity(len(params))
        chisq, alpha = self._chiSquare(data, models, params)
        damping = 0.001
        
        niter = 0
        while True:
            
            CHECK_FORCE_QUIT()
            
            niter += 1
            delta = solveLinEq(alpha + damping*numpy.diagonal(alpha)*identity, -0.5*numpy.array(chisq[1]))
            
            # weights cannot be negative
            nextParams = numpy.maximum(params + delta, 0.)
            
            nextChisq, nextAlpha = self._chiSquare(data, models, nextParams)
            if nextChisq[0] > chisq[0]:
                # worse fit: reject step, increase damping
                damping = 5.*damping
            elif chisq[0] - nextChisq[0] < chiLimit:
                # converged: accept the last (non-worsening) step
                params = nextParams
                break
            else:
                damping = 0.5*damping
                params = nextParams
                chisq = nextChisq
                alpha = nextAlpha
            
            if iterLimit and niter >= iterLimit:
                break
        
        return params / normf
    # ----
    
    
    def _chiSquare(self, data, models, params):
        """Calculate fitting chi-square for current parameter set.
            data: (1-D numpy array) intensities
            models: (2-D numpy array) one modeled profile per row
            params: (sequence of float) weight of each model
        Returns ([chi-square, gradient], alpha) where gradient is the
        derivative of chi-square with respect to params and alpha is
        models . models^T (the curvature matrix of this linear model).
        """
        
        models = numpy.asarray(models, dtype=float)
        
        # calculate differences and chi-square value between calculated and real data
        differences = numpy.sum(models * [[x] for x in params], axis=0) - data
        chisq_value = numpy.sum(differences**2)
        
        # d(chi^2)/d(param_i) = sum_x 2*diff[x]*models[i][x]
        chisq_deriv = 2. * numpy.dot(models, differences)
        
        # alpha[i][j] = sum_x models[i][x]*models[j][x]
        alpha = numpy.dot(models, models.T)
        
        return [chisq_value, chisq_deriv], alpha
    # ----
    
    