/*****************************************************************
* Unipro UGENE - Integrated Bioinformatics Suite
* Copyright (C) 2008,2009 Unipro, Russia (http://ugene.unipro.ru)
* All Rights Reserved
* 
*     This source code is distributed under the terms of the
*     GNU General Public License. See the files COPYING and LICENSE
*     for details.
*****************************************************************/

#include "MSAUtils.h"

#include <core_api/DNAAlphabet.h>
#include <gobjects/DNASequenceObject.h>
#include <gobjects/MAlignmentObject.h>
#include <util_text/TextUtils.h>

namespace GB2 {

void MSAUtils::updateConsensus(const MAlignment& msa, QByteArray& cons, MSAConsensusType ctype) {
    LRegion r(0, msa.getLength());
    updateConsensus(msa, r, cons, ctype);
}

void MSAUtils::updateConsensus(const MAlignment& msa, const LRegion& region, QByteArray& cons, MSAConsensusType ctype) {
    QList<LRegion> l;
    l.append(region);
    updateConsensus(msa, l, cons, ctype);
}

void MSAUtils::updateConsensus(const MAlignment& msa, const QList<LRegion>& regions, QByteArray& cons, MSAConsensusType ctype) {
    if (msa.isEmpty()) {
        return;
    }
    int aliLen = msa.getLength();
    if (cons.length()!=aliLen) {
        cons.resize(aliLen);
    }
    foreach(const LRegion& r, regions) {
        for (int i = r.startPos, n = r.endPos(); i < n ; i++) {
            cons[i] = getConsensusChar(msa, i, ctype);
        }
    }
}

char MSAUtils::getConsensusChar(const MAlignment& msa, int pos, MSAConsensusType ctype) {
    if (ctype == MSAConsensusType_Simple || (ctype == MSAConsensusType_ClustalW && !msa.alphabet->isAmino())) {
        //todo: use '+' if there is some conservatism
        char  defChar = (ctype == MSAConsensusType_Simple) ? MAlignment_GapChar : ' ';
        char pc = msa.alignedSeqs.first().sequence[pos];
        if (pc == MAlignment_GapChar) {
            pc = defChar;
        }
        for (int s = 1, nSeq = msa.getNumSequences(); s < nSeq; s++) {
            char c = msa.alignedSeqs[s].sequence[pos];
            if (c != pc) {
                pc = defChar;
                break;
            }
        }
        char res = (pc == defChar) ? defChar : (ctype == MSAConsensusType_Simple ? pc : '*');
        return res;
    } else if (ctype == MSAConsensusType_ClustalW) {
        /* From ClustalW doc:
        '*' indicates positions which have a single, fully conserved residue
        ':' indicates that one of the following 'strong' groups is fully conserved:
        STA, NEQK, NHQK, NDEQ, QHRK, MILV, MILF, HY, FYW, 
        '.' indicates that one of the following 'weaker' groups is fully conserved:
        CSA, ATV, SAG, STNK, STPA, SGND, SNDEQK, NDEQHK, NEQHRK, FVLIM, HFY
        */
        static QByteArray strongGroups[] = {"STA", "NEQK", "NHQK", "NDEQ", "QHRK", "MILV", "MILF", "HY", "FYW"};
        static QByteArray weakGroups[]   = {"CSA", "ATV", "SAG", "STNK", "STPA", "SGND", "SNDEQK", "NDEQHK", "NEQHRK", "FVLIM", "HFY"};
        static int maxStrongGroupLen = 4;
        static int maxWeakGroupLen = 6;

        QByteArray currentGroup; //TODO: optimize 'currentGroup' related code!
        for (int s = 0, nSeq = msa.getNumSequences(); s < nSeq; s++) {
            char c = msa.alignedSeqs[s].sequence[pos];
            if (!currentGroup.contains(c)) {
                currentGroup.append(c);
            }
        }
        char consChar = MAlignment_GapChar;
        if (currentGroup.size() == 1) {
            consChar = (currentGroup[0] == MAlignment_GapChar) ? ' ' : '*';
        } else  {
            bool ok = false;
            int currentLen = currentGroup.length();
            const char* currentGroupData = currentGroup.data();
            //check strong groups
            if (currentLen <= maxStrongGroupLen) {
                for (int sgi=0, sgn = sizeof(strongGroups) / sizeof(QByteArray); sgi < sgn && !ok; sgi++) {
                    bool matches = true;
                    const QByteArray& sgroup = strongGroups[sgi];
                    for (int j=0; j < currentLen && matches; j++) {
                        char c = currentGroupData[j];
                        matches = sgroup.contains(c);
                    }
                    ok = matches;
                }
                if (ok) {
                    consChar = ':';
                }
            } 

            //check weak groups
            if (!ok && currentLen <= maxWeakGroupLen) {
                for (int wgi=0, wgn = sizeof(weakGroups) / sizeof(QByteArray); wgi < wgn && !ok; wgi++) {
                    bool matches = true;
                    const QByteArray& wgroup = weakGroups[wgi];
                    for (int j=0; j < currentLen && matches; j++) {
                        char c = currentGroupData[j];
                        matches = wgroup.contains(c);
                    }
                    ok = matches;
                }
                if (ok) {
                    consChar = '.';
                }
            } 
            //use default
            if (!ok) {
                consChar = ' ';
            }
        } //amino
        return consChar;
    } else {
        assert(ctype == MSAConsensusType_Jalview);
        int cnt = 0;
        char ch = MAlignment_GapChar;
        getConsensusCharAndCount(msa, pos, ch, cnt);
        return ch;
    }
}

void MSAUtils::getConsensusCharAndCount(const MAlignment& msa, int pos, char& ch, int& cnt) {
    QVector<QPair<int, char> > freqs(32);
    int nSeq = msa.getNumSequences();
    for (int seq = 0; seq < nSeq; seq++) {
        uchar c = (uchar)msa.charAt(seq, pos);
        if (c >= 'A' && c <= 'Z') {
            int idx = c - 'A';
            freqs[idx].first++;
            freqs[idx].second = c;
        }
    }
    qSort(freqs);
    int p1 = freqs[freqs.size()-1].first;
    int p2 = freqs[freqs.size()-2].first;
    if (p1 == 0 || (p1 == 1 && nSeq > 1)) {
        ch = MAlignment_GapChar;
        cnt = 0;
    } else {
        int c1 = freqs[freqs.size()-1].second;
        ch = p2 == p1 ? '+' : c1;
        cnt = p1;
    }
}

// Builds an HTML table (for a tooltip) describing the letters in column 'pos' of 'msa'.
// There is one row per upper-case letter ('A'..'Z'), most frequent first. Each row shows the
// letter, its share of all sequences in percent and its absolute count. Rows stop at the first
// letter whose share is below 'minReportPercent' or which does not occur at all. At most
// 'maxReportChars' rows are emitted. A final "..." row is added only if another qualifying
// letter had to be left out because of that limit.
// Returns an empty string if no row was emitted or the alignment has no sequences.
QString MSAUtils::getConsensusPercentTip(const MAlignment& msa, int pos, int minReportPercent, int maxReportChars) {
    QVector<QPair<int, char> > freqs(32);
    assert(pos>=0 && pos < msa.getLength());
    int nSeq = msa.getNumSequences();
    assert(nSeq > 0);
    if (nSeq == 0) {
        return QString();
    }
    for (int seq = 0; seq < nSeq; seq++) {
        uchar c = (uchar)msa.charAt(seq, pos);
        if (c >= 'A' && c <= 'Z') {
            int idx = c - 'A';
            freqs[idx].first++;
            freqs[idx].second = c;
        }
    }
    // Ascending by (count, letter): the most frequent letters end up at the back.
    qSort(freqs);
    double percentK = 100.0 / nSeq;

    QString res = "<table cellspacing=7>";
    int i = 0;
    bool truncated = false; // set when a qualifying letter did not fit into maxReportChars rows
    for (int nFreqs = freqs.size(); i < nFreqs; i++) {
        const QPair<int, char>& entry = freqs.at(nFreqs - i - 1);
        int p = entry.first;
        double percent = p * percentK;
        // For p >= 1 the share is at least 100/nSeq, so the second test only rejects absent letters.
        if (percent < minReportPercent || percent < 1.0 / nSeq) {
            break;
        }
        if (i == maxReportChars) {
            // This letter qualifies but there is no room for it: the table is truncated.
            truncated = true;
            break;
        }
        int c = entry.second;
        res = res + "<tr><td><b>" + QChar(c) + "</b></td>";
        res = res + "<td align=right>" + QString::number(percent, 'f', 1) + "%</td>";
        res = res + "<td align=right>" + QString::number(p) + "</td>";
        res = res + "</tr>";
    }
    if (i == 0) {
        return "";
    }
    if (truncated) {
        res+="<tr><td colspan=3>...</td></tr>";
    }
    res+="</table>";
    return res;
}

// Returns true if 'pat' matches 'seq' starting at 'startPos' when MAlignment_GapChar
// characters in 'seq' are skipped. Every character of 'pat' must be matched. If 'seq' ends
// (or 'startPos' is already past its end) before the whole pattern is consumed, the result
// is false. An empty 'pat' always matches.
bool MSAUtils::equalsIgnoreGaps(const QByteArray& seq, int startPos, const QByteArray& pat) {
    int sLen = seq.size();
    int pLen = pat.size();
    int i = startPos;
    int j = 0; // number of pattern characters matched so far
    for (; i < sLen && j < pLen; i++, j++) {
        char c1 = seq[i];
        char c2 = pat[j];
        // Skip gaps in the sequence; if only gaps remain, c1 stays a gap and fails the comparison.
        while (c1 == MAlignment_GapChar && ++i < sLen) {
            c1 = seq[i];
        }
        if (c1 != c2) {
            return false;
        }
    }
    // The loop can also stop because the sequence ran out: that is only a match if the
    // whole pattern has been consumed.
    return j == pLen;
}


// Builds a multiple alignment from 'list'. Each DNASequenceObject becomes one row, named after
// the object. The alignment alphabet is derived with DNAAlphabet::deriveCommonAlphabet across
// all sequences. Processing stops with a message in 'err' in three cases:
//   - an object is not a DNASequenceObject,
//   - two alphabets have no common alphabet,
//   - a sequence is longer than MAX_ALI_LEN.
// In that case the returned alignment is cleared. Otherwise it is normalized with normalizeModel().
MAlignment MSAUtils::seq2ma( const QList<GObject*>& list, QString& err ) {
    MAlignment ma(MA_OBJECT_NAME);
    foreach(GObject* obj, list) {
        DNASequenceObject* dnaObj = qobject_cast<DNASequenceObject*>(obj);
        if (dnaObj == NULL) {
            // qobject_cast yields NULL for objects of any other type: report it instead of crashing below.
            err = tr("Object is not a sequence");
            break;
        }
        const DNASequence& seq = dnaObj->getDNASequence();
        if (ma.alphabet == NULL) {
            ma.alphabet = dnaObj->getAlphabet();
        } else {
            ma.alphabet = DNAAlphabet::deriveCommonAlphabet(ma.alphabet, dnaObj->getAlphabet());
            if (ma.alphabet == NULL) {
                err = tr("Sequences have different alphabets.");
                break;
            }
        }
        if (seq.length() > MAX_ALI_LEN) {
            err = tr("Sequence is too large for alignment: %1").arg(seq.getName());
            break;
        }
        ma.alignedSeqs.append(MAlignmentItem(dnaObj->getGObjectName(), seq.seq));
    }

    if (err.isEmpty()) {
        ma.normalizeModel();
    } else {
        ma.clear();
    }

    return ma;
}

// Converts every row of 'ma' into a DNASequence built from the row name, the row data and the
// alignment alphabet. When 'trimGaps' is true, MAlignment_GapChar characters are removed from
// each resulting sequence.
QList<DNASequence> MSAUtils::ma2seq(const MAlignment& ma, bool trimGaps) {
    QBitArray gapsMap = TextUtils::createBitMap(MAlignment_GapChar);
    QList<DNASequence> result;
    for (int row = 0, nRows = ma.alignedSeqs.size(); row < nRows; ++row) {
        const MAlignmentItem& item = ma.alignedSeqs.at(row);
        DNASequence dna(item.name, item.sequence, ma.alphabet);
        if (trimGaps) {
            const int remaining = TextUtils::remove(dna.seq.data(), dna.length(), gapsMap);
            dna.seq.resize(remaining);
        }
        result.append(dna);
    }
    return result;
}

// Packs the four most frequent upper-case letters of column 'pos' into one 32-bit value,
// one byte per letter, with byte 0 holding the most frequent letter.
// Each byte is laid out as (range << 5) | (letter - 'A'), where 'range' is the index of the
// first threshold in 'mask4' that the letter's percentage reaches. 'mask4' is read at
// indices 0..3 (presumably thresholds in descending order - TODO confirm). The range is 4,
// with letter bits 0, if no threshold is reached or the letter does not occur at all.
// Percentages are relative to all sequences when 'gapsAffectPercents' is true, otherwise to
// the number of sequences that have a letter at 'pos'.
quint32 MSAUtils::packConsensusCharsToInt(const MAlignment& msa, int pos, const int* mask4, bool gapsAffectPercents) {
    QVector<QPair<int, char> > freqs(32);
    int numNoGaps = 0;
    int nSeq = msa.getNumSequences();
    for (int seq = 0; seq < nSeq; seq++) {
        uchar c = (uchar)msa.charAt(seq, pos);
        if (c >= 'A' && c <= 'Z') {
            int idx = c - 'A';
            freqs[idx].first++;
            freqs[idx].second = c;
            numNoGaps++;
        }
    }
    // Ascending by (count, letter): the most frequent letters end up at the back.
    qSort(freqs);
    if (!gapsAffectPercents && numNoGaps == 0) {
        // Every byte has all 3 range bits set (7, i.e. ">= 4": no threshold reached) and letter bits 0.
        return 0xE0E0E0E0;
    }
    const int percentBase = gapsAffectPercents ? nSeq : numNoGaps;
    // A letter with count > 0 implies percentBase > 0, so percentK is only used when it is finite.
    const double percentK = (percentBase > 0) ? 100.0 / percentBase : 0.0;
    quint32 res = 0;
    for (int i = 0; i < 4; i++) {
        const QPair<int, char>& entry = freqs.at(freqs.size() - i - 1);
        quint32 rangeBits = 4;
        // Absent letters (count 0, letter byte 0) always get range 4. Otherwise a threshold <= 0
        // would select them, and 'charVal - 'A'' below would underflow.
        if (entry.first > 0) {
            int p = int(entry.first * percentK);
            rangeBits = (p >= mask4[0]) ?  0 : 
                        (p >= mask4[1]) ?  1 :
                        (p >= mask4[2]) ?  2 : 
                        (p >= mask4[3]) ?  3 : 4;
        }
        quint32 charVal = rangeBits == 4  ? 'A' : quint32(entry.second);
        quint32 maskedVal = (rangeBits << 5) | (charVal - 'A'); //3 bits for range, 5 for symbol
        assert(maskedVal <= 255);
        res |= maskedVal << (8 * i);
    }
    return res;
}

// Splits 'val' into its four bytes, starting with the lowest one (the layout produced by
// packConsensusCharsToInt). For byte i:
//   maskPos[i] - the upper 3 bits of the byte;
//   charVal[i] - 'A' plus the lower 5 bits, or 0 when maskPos[i] >= 4.
// Both output arrays must have room for 4 elements.
void MSAUtils::unpackConsensusCharsFromInt(quint32 val, char* charVal, int* maskPos) {
    quint32 rest = val;
    for (int slot = 0; slot < 4; ++slot, rest >>= 8) {
        const quint32 packed = rest & 0xFFu;
        const int range = int(packed >> 5);
        maskPos[slot] = range;
        if (range < 4) {
            charVal[slot] = char((packed & 0x1Fu) + 'A');
        } else {
            charVal[slot] = 0;
        }
    }
}


void MSAUtils::getColumnFreqs(const MAlignment& ma, int pos, QVector<int>& freqsByChar, int& nonGapChars) {
    assert(freqsByChar.size() == 256);
    assert(ma.isNormalized());
    freqsByChar.fill(0);
    nonGapChars = 0;
    int* freqs = freqsByChar.data();
    int nSeq = ma.getNumSequences();
    for (int seq = 0; seq < nSeq; seq++) {
        uchar c = (uchar)ma.charAt(seq, pos);
        freqs[c]++;
        if (c!=MAlignment_GapChar) {
            nonGapChars++;
        }
    }
}

}//namespace
