# -------------------------------------------------------------------------
#     Copyright (C) 2005-2010 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program
# -------------------------------------------------------------------------

# load libs
import threading
import wx

# load modules
from ids import *
import config
import mwx
import images
import mspy
import doc


# FLOATING PANEL WITH MATCH TOOLS
# -------------------------------

class panelMatch(wx.MiniFrame):
    """Match tools."""
    
    def __init__(self, parent, tool='errors'):
        """Build the floating match panel and show the requested tool.
        
        parent: owner window; its updateMatches() and calibrateByMatches()
            methods are called later by this panel
        tool: view shown first, 'errors' (error plot) or 'summary' (summary list)
        """
        
        wx.MiniFrame.__init__(self, parent, -1, 'Match Data', size=(400, 300), style=wx.DEFAULT_FRAME_STYLE & ~ (wx.RESIZE_BOX | wx.MAXIMIZE_BOX))
        
        self.parent = parent
        
        # worker thread while a match is running, otherwise None
        self.processing = None
        
        self.currentTool = tool
        self.currentDataType = None
        self.currentData = None
        self.currentPeaklist = None
        
        # created here as well, so the attribute exists before the first setData()
        self.currentSearchInfo = None
        
        self.currentSummary = None
        self.currentErrors = []
        self.currentReferences = []
        
        # make gui items
        self.makeGUI()
        
        # same binding style as used for the toolbar widgets
        self.Bind(wx.EVT_CLOSE, self.onClose)
        
        # select default tool
        self.onToolSelected(tool=self.currentTool)
    # ----
    
    
    def makeGUI(self):
        """Create all child widgets and stack them in the main sizer."""
        
        # build children
        toolbarPanel = self.makeToolbar()
        self.makeErrorCanvas()
        self.makeSummaryList()
        gaugePanel = self.makeGaugePanel()
        
        # stack vertically; sizer slots are 0=toolbar, 1=error plot, 2=summary list, 3=gauge
        layout = wx.BoxSizer(wx.VERTICAL)
        self.mainSizer = layout
        layout.Add(toolbarPanel, 0, wx.EXPAND, 0)
        layout.Add(self.errorCanvas, 1, wx.EXPAND, 0)
        layout.Add(self.summaryList, 1, wx.EXPAND|wx.ALL, mwx.LISTCTRL_NO_SPACE)
        layout.Add(gaugePanel, 0, wx.EXPAND, 0)
        
        # hide everything except the toolbar
        for slot in (1, 2, 3):
            layout.Hide(slot)
        
        # fit the frame and lock the fitted size as its minimum
        self.SetMinSize((100,100))
        self.Layout()
        layout.Fit(self)
        self.SetSizer(layout)
        self.SetMinSize(self.GetSize())
    # ----
    
    
    def makeToolbar(self):
        """Build the toolbar panel and return it.
        
        The interactive widgets (tool buttons, tolerance field, unit radios,
        charge checkbox, Match and Calibrate buttons) are stored on self.
        """
        
        bar = mwx.bgrPanel(self, -1, images.lib['bgrToolbarNoBorder'], size=(-1, mwx.TOOLBAR_HEIGHT))
        smallFont = wx.SMALL_FONT
        centred = wx.ALIGN_CENTER_VERTICAL
        
        # view-switching buttons
        self.errors_butt = wx.BitmapButton(bar, ID_matchErrors, images.lib['matchErrorsOff'], size=mwx.TOOLBAR_TOOLSIZE, style=wx.BORDER_NONE)
        self.errors_butt.SetToolTip(wx.ToolTip("Error plot"))
        self.errors_butt.Bind(wx.EVT_BUTTON, self.onToolSelected)
        
        self.summary_butt = wx.BitmapButton(bar, ID_matchSummary, images.lib['matchSummaryOff'], size=mwx.TOOLBAR_TOOLSIZE, style=wx.BORDER_NONE)
        self.summary_butt.SetToolTip(wx.ToolTip("Match summary"))
        self.summary_butt.Bind(wx.EVT_BUTTON, self.onToolSelected)
        
        # tolerance value
        toleranceLabel = wx.StaticText(bar, -1, "Tolerance:")
        toleranceLabel.SetFont(smallFont)
        
        self.tolerance_value = wx.TextCtrl(bar, -1, str(config.match['tolerance']), size=(60, -1), validator=mwx.validator('floatPos'))
        
        # tolerance units; Da is preselected, ppm takes over if configured
        self.unitsDa_radio = wx.RadioButton(bar, -1, "Da", style=wx.RB_GROUP)
        self.unitsDa_radio.SetFont(smallFont)
        self.unitsDa_radio.SetValue(True)
        self.unitsDa_radio.Bind(wx.EVT_RADIOBUTTON, self.onUnitsChanged)
        
        self.unitsPpm_radio = wx.RadioButton(bar, -1, "ppm")
        self.unitsPpm_radio.SetFont(smallFont)
        self.unitsPpm_radio.Bind(wx.EVT_RADIOBUTTON, self.onUnitsChanged)
        self.unitsPpm_radio.SetValue((config.match['units'] == 'ppm'))
        
        # charge handling
        self.ignoreCharge_check = wx.CheckBox(bar, -1, "Ignore charge")
        self.ignoreCharge_check.SetFont(smallFont)
        self.ignoreCharge_check.SetValue(config.match['ignoreCharge'])
        
        # action buttons
        self.match_butt = wx.Button(bar, -1, "Match", size=(-1, mwx.SMALL_BUTTON_HEIGHT))
        self.match_butt.SetFont(smallFont)
        self.match_butt.Bind(wx.EVT_BUTTON, self.onMatch)
        
        self.calibrate_butt = wx.Button(bar, -1, "Calibrate", size=(-1, mwx.SMALL_BUTTON_HEIGHT))
        self.calibrate_butt.SetFont(smallFont)
        self.calibrate_butt.Bind(wx.EVT_BUTTON, self.onCalibrate)
        
        # lay everything out in a single row
        row = wx.BoxSizer(wx.HORIZONTAL)
        row.AddSpacer(mwx.TOOLBAR_LSPACE)
        row.Add(self.errors_butt, 0, centred)
        row.Add(self.summary_butt, 0, centred|wx.LEFT, mwx.BUTTON_SIZE_CORRECTION)
        row.AddSpacer(20)
        row.Add(toleranceLabel, 0, centred|wx.RIGHT, 5)
        row.Add(self.tolerance_value, 0, centred)
        row.AddSpacer(10)
        row.Add(self.unitsDa_radio, 0, centred|wx.RIGHT, 5)
        row.Add(self.unitsPpm_radio, 0, centred)
        row.AddSpacer(20)
        row.Add(self.ignoreCharge_check, 0, centred)
        row.AddStretchSpacer()
        row.AddSpacer(20)
        row.Add(self.match_butt, 0, centred|wx.RIGHT, 10)
        row.Add(self.calibrate_butt, 0, centred)
        row.AddSpacer(mwx.TOOLBAR_RSPACE)
        
        column = wx.BoxSizer(wx.VERTICAL)
        column.Add(row, 1, wx.EXPAND)
        
        bar.SetSizer(column)
        column.Fit(bar)
        
        return bar
    # ----
    
    
    def makeErrorCanvas(self):
        """Make plot canvas and set default parameters."""
        
        # create the canvas and draw an empty container
        self.errorCanvas = canvas = mspy.plot.canvas(self, size=(-1, 220), style=mwx.PLOTCANVAS_STYLE_PANEL)
        canvas.draw(mspy.plot.container([]))
        
        # default properties, applied one call at a time in this order
        defaults = (
            ('xLabel', 'm/z'),
            ('yLabel', 'error in %s' % config.match['units']),
            ('showZero', True),
            ('showLegend', False),
            ('showPosBar', True),
            ('posBarHeight', 6),
            ('showGel', False),
            ('showCurTracker', True),
            ('checkLimits', True),
            ('autoScaleY', False),
            ('xPosDigits', config.main['mzDigits']),
            ('yPosDigits', 2),
            ('reverseDrawing', True),
        )
        for name, value in defaults:
            canvas.setProperties(**{name: value})
        
        # axis font uses the size configured for the spectrum view
        axisFont = wx.Font(config.spectrum['axisFontSize'], wx.SWISS, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, 0)
        canvas.setProperties(axisFont=axisFont)
        
        # redraw with the new properties
        canvas.draw(mspy.plot.container([]))
    # ----
    
    
    def makeSummaryList(self):
        """Create the two-column list control used for the match summary."""
        
        # create the control
        self.summaryList = listCtrl = mwx.sortListCtrl(self, -1, size=(621, 200), style=mwx.LISTCTRL_STYLE_SINGLE)
        listCtrl.SetFont(wx.SMALL_FONT)
        listCtrl.setAltColour(mwx.LISTCTRL_ALTCOLOUR)
        
        # column titles and widths
        columns = (("parameter", 250), ("value", 350))
        
        # add all columns first, then size them
        for index, column in enumerate(columns):
            listCtrl.InsertColumn(index, column[0], wx.LIST_FORMAT_LEFT)
        for index, column in enumerate(columns):
            listCtrl.SetColumnWidth(index, column[1])
    # ----
    
    
    def makeGaugePanel(self):
        """Build the panel holding the processing gauge and return it."""
        
        holder = wx.Panel(self, -1)
        
        # the gauge itself is kept on self so it can be pulsed and reset later
        self.gauge = mwx.gauge(holder, -1)
        
        # single item with a border on all sides
        box = wx.BoxSizer(wx.VERTICAL)
        box.Add(self.gauge, 0, wx.EXPAND|wx.ALL, mwx.GAUGE_SPACE)
        
        box.Fit(holder)
        holder.SetSizer(box)
        
        return holder
    # ----
    
    
    def onClose(self, evt):
        """Close the frame unless a match is still running."""
        
        # nothing running, frame can go
        if self.processing is None:
            self.Destroy()
            return
        
        # refuse to close while processing
        wx.Bell()
    # ----
    
    
    def onToolSelected(self, evt=None, tool=None):
        """Switch the visible view between the error plot and the summary list.
        
        When called with an event, the tool is taken from the event id (the
        summary button selects 'summary', anything else selects 'errors').
        Without an event the 'tool' argument is used.
        """
        
        # an event overrides the 'tool' argument
        if evt is not None:
            if evt and evt.GetId() == ID_matchSummary:
                tool = 'summary'
            else:
                tool = 'errors'
        
        self.currentTool = tool
        
        sizer = self.mainSizer
        
        # reset: both views hidden, both icons off
        for slot in (1, 2):
            sizer.Hide(slot)
        self.errors_butt.SetBitmapLabel(images.lib['matchErrorsOff'])
        self.summary_butt.SetBitmapLabel(images.lib['matchSummaryOff'])
        
        # pick sizer slot, button and icon for the requested tool
        selected = None
        if tool == 'errors':
            selected = (1, self.errors_butt, 'matchErrorsOn')
        elif tool == 'summary':
            selected = (2, self.summary_butt, 'matchSummaryOn')
        
        if selected is not None:
            slot, button, icon = selected
            sizer.Show(slot)
            button.SetBitmapLabel(images.lib[icon])
        
        # refit the frame to the visible view
        self.SetMinSize((-1,-1))
        sizer.Fit(self)
        self.Layout()
        fitted = self.GetSize()
        
        # resize by one pixel and back (presumably to force a relayout), then lock the size
        self.SetSize((fitted[0]+1, fitted[1]))
        self.SetSize(fitted)
        self.SetMinSize(fitted)
    # ----
    
    
    def onProcessing(self, status=True):
        """Show or hide the processing gauge.
        
        status: True makes the frame modal and shows the gauge; False ends
            the modal state, hides the gauge and clears self.processing
        """
        
        self.gauge.SetValue(0)
        
        if status:
            self.MakeModal(True)
            self.mainSizer.Show(3)
        else:
            self.MakeModal(False)
            self.mainSizer.Hide(3)
            self.processing = None
        
        # fit layout
        self.Layout()
        self.mainSizer.Fit(self)
        
        # let pending events run; this is best effort only, so failures of
        # the yield are ignored, but KeyboardInterrupt/SystemExit still propagate
        try:
            wx.Yield()
        except Exception:
            pass
    # ----
    
    
    def onMatch(self, evt=None):
        """Match current data to the current peaklist and refresh the views.
        
        The matching itself runs in a worker thread (runMatch) while this
        method keeps pulsing the gauge. Without data, or with invalid
        parameters, the previous results are cleared instead.
        """
        
        # check processing
        if self.processing:
            return
        
        # check data first, params are only read if there is something to match
        hasData = not (self.currentDataType is None or self.currentData is None or self.currentPeaklist is None)
        if not hasData or not self.getParams():
            self.currentSummary = None
            self.currentErrors = []
            self.currentReferences = []
            self.updateErrorCanvas()
            self.updateMatchSummary()
            
            # getParams() signals its own failure
            if not hasData:
                wx.Bell()
            return
        
        # show processing gauge
        self.onProcessing(True)
        self.match_butt.Enable(False)
        self.calibrate_butt.Enable(False)
        
        try:
            
            # do processing
            self.processing = threading.Thread(target=self.runMatch)
            self.processing.start()
            
            # pulse gauge while working
            while self.processing and self.processing.isAlive():
                self.gauge.pulse()
            
            # update gui
            self.parent.updateMatches(self.currentDataType)
            self.updateErrorCanvas()
            self.updateMatchSummary()
        
        finally:
            
            # always leave the processing state, otherwise the frame would
            # stay modal with disabled buttons and could not be closed
            self.onProcessing(False)
            self.match_butt.Enable(True)
            self.calibrate_butt.Enable(True)
    # ----
    
    
    def onCalibrate(self, evt):
        """Hand the current match references over to the parent for calibration."""
        
        references = self.currentReferences
        
        if references:
            # show calibration panel
            self.parent.calibrateByMatches(references)
        else:
            # nothing matched yet
            wx.Bell()
    # ----
    
    
    def onUnitsChanged(self, evt=None):
        """Store the selected error units and convert the plotted errors.
        
        Each entry of self.currentErrors is an [m/z, error] pair; the error
        is converted between Da and ppm in place, then the plot is redrawn.
        """
        
        # get units
        previous = config.match['units']
        if self.unitsDa_radio.GetValue():
            units = 'Da'
        else:
            units = 'ppm'
        config.match['units'] = units
        
        # recalc errors only on a real change, so a repeated event for the
        # same unit cannot convert the values twice
        if units != previous:
            for point in self.currentErrors:
                
                # float divisor keeps the scale correct for integer m/z values
                scale = point[0] / 1000000.0
                if units == 'ppm':
                    point[1] = point[1] / scale
                elif units == 'Da':
                    point[1] = point[1] * scale
        
        # update plot
        self.updateErrorCanvas()
    # ----
    
    
    def setData(self, dataType, data, peaklist, searchInfo=None):
        """Store new data to match and drop all previous match results.
        
        dataType: 'digest', 'fragment' or 'compounds'; any other value only
            gets the generic window title
        data: items to be matched
        peaklist: peaklist to match against
        searchInfo: additional search info, stored as given
        """
        
        # remember the new input
        self.currentDataType = dataType
        self.currentData = data
        self.currentPeaklist = peaklist
        self.currentSearchInfo = searchInfo
        
        # previous results no longer apply
        self.currentSummary = None
        self.currentErrors = []
        self.currentReferences = []
        
        # window title follows the data type
        title = 'Match Data'
        knownTitles = (
            ('digest', 'Match Peptides'),
            ('fragment', 'Match Fragments'),
            ('compounds', 'Match Compounds'),
        )
        for kind, text in knownTitles:
            if self.currentDataType == kind:
                title = text
                break
        self.SetTitle(title)
        
        # redraw the empty plot and summary
        self.updateErrorCanvas()
        self.updateMatchSummary()
    # ----
    
    
    def getParams(self):
        """Read tolerance, units and charge handling from the toolbar into config.match.
        
        Returns True on success. If the tolerance is not a number, a bell
        is sounded, config.match is left untouched and False is returned.
        """
        
        # only the number conversion can fail, so only that is guarded
        try:
            tolerance = float(self.tolerance_value.GetValue())
        except (TypeError, ValueError):
            wx.Bell()
            return False
        
        config.match['tolerance'] = tolerance
        
        if self.unitsDa_radio.GetValue():
            config.match['units'] = 'Da'
        else:
            config.match['units'] = 'ppm'
        
        if self.ignoreCharge_check.GetValue():
            config.match['ignoreCharge'] = 1
        else:
            config.match['ignoreCharge'] = 0
        
        return True
    # ----
    
    
    def runMatch(self):
        """Match current data to the current peaklist.
        
        Every data item gets its error column reset and its last field
        replaced by the list of new match objects. self.currentErrors gets an
        [m/z, error] pair and self.currentReferences a [label, theoretical
        mass, m/z] entry for each match. The error column finally keeps the
        error with the smallest absolute value and the summary is rebuilt.
        
        For an unknown data type nothing is matched and the results are cleared.
        """
        
        # set columns
        if self.currentDataType == 'digest':
            massCol = 2
            chargeCol = 3
            errorCol = 5
            matchObject = doc.match
        elif self.currentDataType == 'fragment':
            massCol = 3
            chargeCol = 4
            errorCol = 6
            matchObject = doc.match
        elif self.currentDataType == 'compounds':
            massCol = 1
            chargeCol = 2
            errorCol = 5
            matchObject = doc.annotation
        else:
            # setData() accepts any type, but the column layout is known for
            # these three only; leave cleanly instead of failing in the thread
            self.currentSummary = None
            self.currentErrors = []
            self.currentReferences = []
            return
        
        # clear previous match
        for item in self.currentData:
            item[errorCol] = None
            item[-1] = []
        
        # match data
        self.currentErrors = []
        self.currentReferences = []
        
        # read settings once, they do not change per peak
        units = config.match['units']
        tolerance = config.match['tolerance']
        ignoreCharge = config.match['ignoreCharge']
        labelFormat = 'Peak %%0.%df' % config.main['mzDigits']
        
        for peak in self.currentPeaklist:
            
            # use monoisotopic peaks and peaks with unknown charge only
            if not (peak.isotope == 0 or peak.charge is None):
                continue
            
            for item in self.currentData:
                
                # charge must fit unless it is unknown or ignored
                if not (peak.charge is None or ignoreCharge or (peak.charge and item[chargeCol] == abs(peak.charge))):
                    continue
                
                # check tolerance
                error = mspy.delta(peak.mz, item[massCol], units)
                if abs(error) > tolerance:
                    continue
                
                # create new match object
                match = matchObject(label='', mz=peak.mz, intensity=peak.intensity, baseline=peak.baseline, theoretical=item[massCol])
                item[-1].append(match)
                
                # errors and references
                label = labelFormat % peak.mz
                self.currentErrors.append([peak.mz, error])
                self.currentReferences.append([label, item[massCol], peak.mz])
        
        # show best error only
        for item in self.currentData:
            for match in item[-1]:
                error = match.delta(units)
                if item[errorCol] is None or abs(item[errorCol]) > abs(error):
                    item[errorCol] = error
        
        # get match summary
        self.makeMatchSummary()
    # ----
    
    
    def clear(self):
        """Drop current data and all match results, then redraw the empty views."""
        
        # input data
        self.currentDataType = None
        self.currentData = None
        self.currentPeaklist = None
        
        # search info belongs to the data set by setData(), so it goes as well
        self.currentSearchInfo = None
        
        # match results
        self.currentSummary = None
        self.currentErrors = []
        self.currentReferences = []
        
        self.updateErrorCanvas()
        self.updateMatchSummary()
    # ----
    
    
    def updateErrorCanvas(self):
        """Redraw the error plot from current errors and peaklist."""
        
        plotItems = mspy.plot.container([])
        
        # error points; self.currentErrors is sorted in place
        self.currentErrors.sort()
        errorPoints = mspy.plot.points(self.currentErrors, pointColour=(0,255,0), showPoints=True, showLines=False)
        plotItems.append(errorPoints)
        
        # peaklist rescaled to the error range
        if self.currentPeaklist:
            scaledPeaks = self.makeCurrentPeaklist()
            scan = mspy.scan(peaks=scaledPeaks)
            plotItems.append(mspy.plot.spectrum(scan, tickColour=(170,170,170), showLabels=False))
        
        # axis label follows the current units
        self.errorCanvas.setProperties(yLabel='error in %s' % (config.match['units']))
        
        self.errorCanvas.draw(plotItems)
    # ----
    
    
    def updateMatchSummary(self):
        """Refill the summary list from self.currentSummary."""
        
        listCtrl = self.summaryList
        summary = self.currentSummary
        
        # drop old rows and hand the new data over for sorting
        listCtrl.DeleteAllItems()
        listCtrl.setDataMap(summary)
        
        if summary:
            
            # one row per entry: label in the first column, value in the second
            for row, entry in enumerate(summary):
                listCtrl.InsertStringItem(row, entry[0])
                listCtrl.SetStringItem(row, 1, str(entry[1]))
                listCtrl.SetItemData(row, row)
            
            # refresh row colours and scroll back to the top
            listCtrl.updateItemsBackground()
            listCtrl.EnsureVisible(0)
    # ----
    
    
    def makeCurrentPeaklist(self):
        """Convert peaklist for current error range.
        
        Returns a new mspy.peaklist in which the peak heights are rescaled so
        that the base peak spans the range of the current errors (0 to 1 if
        there are none) and can be drawn behind the error points.
        """
        
        # get error range
        minY = 0.0
        maxY = 1.0
        if self.currentErrors:
            errors = [point[1] for point in self.currentErrors]
            minY = min(errors)
            maxY = max(errors)
            
            # a single error value gives no range, so widen it by 10 % to both
            # sides; abs() keeps the order for negative errors and a zero
            # error falls back to the default range size of 1
            if minY == maxY:
                pad = abs(minY) * 0.1
                if not pad:
                    pad = 0.5
                minY -= pad
                maxY += pad
            
            # keep some space below the lowest error
            minY -= 0.05 * abs(maxY - minY)
        
        # NOTE(review): presumably basePeak is None for an empty peaklist - confirm
        basePeak = self.currentPeaklist.basePeak
        if basePeak is None:
            return mspy.peaklist([])
        
        # get scale factor; a base peak without height cannot be scaled,
        # so all peaks stay on the baseline
        height = basePeak.intensity - basePeak.baseline
        if height:
            f = float(abs(maxY - minY)) / height
        else:
            f = 0.0
        
        # convert peaklist
        peaklist = []
        for peak in self.currentPeaklist:
            intensity = ((peak.intensity - peak.baseline) * f) + minY
            peaklist.append(mspy.peak(mz=peak.mz, intensity=intensity, baseline=minY))
        
        # convert to mspy.peaklist
        return mspy.peaklist(peaklist)
    # ----
    
    
    def makeMatchSummary(self):
        """Make summary info for current match.
        
        Fills self.currentSummary with (label, value) rows: number of searched
        items (known data types only), number of searched peaks and number of
        matched items. For 'digest' the sequence length and coverage follow,
        for 'fragment' the sequence length and the matched positions of each
        ion series. Rows needing the sequence length are skipped if the search
        info does not provide it.
        """
        
        self.currentSummary = summary = []
        dataType = self.currentDataType
        
        # get searched items
        value = '%d' % len(self.currentData)
        if dataType == 'digest':
            summary.append(('Number of peptides searched', value))
        elif dataType == 'fragment':
            summary.append(('Number of fragments searched', value))
        elif dataType == 'compounds':
            summary.append(('Number of compounds searched', value))
        
        # get searched peaks (same selection as used for matching)
        sumPeaklist = 0
        for peak in self.currentPeaklist:
            if peak.isotope == 0 or peak.charge is None:
                sumPeaklist += 1
        summary.append(('Number of peaks searched', '%d' % sumPeaklist))
        
        # get matched items
        # NOTE(review): this counts data items having a match although the
        # row is labelled as peaks - confirm the intent
        matched = [item for item in self.currentData if item[-1]]
        sumMatched = len(matched)
        
        # float arithmetic, so the percentage is rounded and not truncated
        percent = 100.0 * sumMatched / max(1, sumPeaklist)
        summary.append(('Number of peaks matched', '%d (%0.f %s)' % (sumMatched, percent, '%')))
        
        # get sequence length; search info is optional in setData()
        try:
            sequenceLength = self.currentSearchInfo['sequenceLength']
        except (TypeError, KeyError):
            sequenceLength = None
        
        # get sequence coverage
        if dataType == 'digest':
            if sequenceLength is not None:
                sumPeptides = [item[0] for item in matched]
                coverage = mspy.coverage(sumPeptides, sequenceLength)
                summary.append(('Sequence length', sequenceLength))
                summary.append(('Sequence coverage', '%0.f %s' % (coverage, '%')))
        
        # get ion series
        elif dataType == 'fragment':
            if sequenceLength is not None:
                summary.append(('Sequence length', sequenceLength))
            
            # collect matched positions for each series, skipping items whose
            # series name starts with 'int'
            series = {}
            for item in self.currentData:
                if item[0][:3] == 'int':
                    continue
                positions = series.setdefault(item[0], [])
                if item[-1]:
                    positions.append(item[1])
            
            for serie in sorted(series):
                value = ', '.join(str(n) for n in sorted(series[serie]))
                label = 'Ion serie "%s" matches' % serie
                summary.append((label, value))
    # ----
    

