#
# Copyright 2009 Canonical Ltd.
#
# Written by:
#     Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
#     Sidnei da Silva <sidnei.da.silva@canonical.com>
#
# This file is part of the Image Store Proxy.
#
# This program is free software: you can redistribute it and/or modify it 
# under the terms of the GNU General Public License version 3, as published 
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but 
# WITHOUT ANY WARRANTY; without even the implied warranties of 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR 
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import commands
import hashlib
import logging
import os
import re
import shutil
import subprocess
import time

from urllib import quote

from imagestore.lib.fetch import fetch, FetchError
from imagestore.lib.service import (
    ServiceTask, ServiceError, ThreadedService, taskHandlerInThread)


# Read buffer size used when hashing downloaded files.
CHUNK_SIZE = 1 << 16

# Downloads older than this many seconds are purged when the service starts.
# (The previous value, 10080, is one week in *minutes*, not seconds.)
REMOVE_FILE_DELAY = 7 * 24 * 60 * 60  # 1 week in seconds

# Accepted checksum format: hexadecimal digits only.  "\Z" is used instead
# of "$" because "$" also matches just before a trailing newline, which would
# let a newline slip into the file name built from the checksum.
HASH_RE = re.compile(r"^[a-fA-F0-9]+\Z")


def verifySHA256(filePath, checksum):
    """Return True if the SHA256 digest of the file at filePath is checksum.

    The file is read in binary mode, CHUNK_SIZE bytes at a time, so large
    files are never loaded into memory at once and no newline translation
    can alter the digest.  The comparison ignores case: HASH_RE accepts
    upper-case hex digits, while hexdigest() always returns lower-case.
    """
    sha256 = hashlib.sha256()
    handle = open(filePath, 'rb')
    try:
        while True:
            data = handle.read(CHUNK_SIZE)
            if not data:
                break
            sha256.update(data)
    finally:
        handle.close()
    return sha256.hexdigest() == checksum.lower()


class DownloadServiceError(ServiceError):
    """Error raised by the download service when a download task fails."""


class Progress(object):
    """Progress tracker for a single download.

    Instances are callable with four values (download total, download
    current, upload total, upload current).  The call returns 1 once
    cancel() has been invoked, which asks the caller to abort the transfer.
    """

    def __init__(self):
        self._abort = False
        # NOTE(review): nothing in this file sets this to True; confirm
        # whether isActive() is meaningful.
        self._active = False
        self._partialSize = 0
        self._expectedSize = 0
        self._currentSize = 0
        self._totalSize = 0
        self._lastLog = 0

    def cancel(self):
        """Ask for the download to be aborted at the next progress report."""
        self._abort = True

    def wasCancelled(self):
        """Tell whether cancel() has been called."""
        return self._abort

    def setPartialSize(self, partialSize):
        """Remember how much of the file was already downloaded, if any.

        This offset is added to the sizes reported to __call__.
        """
        self._partialSize = partialSize

    def setExpectedSize(self, size):
        """Remember the expected size of the file being downloaded.

        While the total size is still 0, it is initialised to this value too.
        """
        self._expectedSize = size
        if self._totalSize == 0:
            self._totalSize = size

    def getCurrentSize(self):
        """Return how many bytes of the file have been downloaded so far."""
        return self._currentSize

    def getTotalSize(self):
        """Return the total size of the file being downloaded."""
        return self._totalSize

    def isActive(self):
        """Return the active flag."""
        return self._active

    def __call__(self, downTotal, downCurrent, upTotal, upCurrent):
        """Handle a progress report; return 1 to abort the transfer.

        Upload values are ignored.  When a download total is reported, both
        sizes are offset by the partial size.  Otherwise the expected size
        serves as the total, but only once some data has arrived.  A debug
        message is logged at most once every five seconds.
        """
        sizes = None
        if downTotal:
            sizes = (self._partialSize + downCurrent,
                     self._partialSize + downTotal)
        elif self._expectedSize and downCurrent:
            # XXX This branch is untested.
            sizes = (self._partialSize + downCurrent, self._expectedSize)
        # XXX The case with neither a reported nor an expected total (or no
        # data yet) is untested as well; sizes are simply left unchanged.

        if sizes is not None:
            self._currentSize, self._totalSize = sizes
            now = time.time()
            if now - 5 > self._lastLog:
                self._lastLog = now
                logging.debug("Download in progress (current=%d, total=%d)" %
                              (self._currentSize, self._totalSize))

        if self._abort:
            # Returning a non-zero integer value cancels the download.
            return 1
        return None


class DownloadServiceTask(ServiceTask):
    """Common base class for the tasks handled by the download service."""


class DownloadFileTask(DownloadServiceTask):
    """Task asking the download service to fetch and verify one file.

    url is where the file is fetched from, size is its expected size (also
    used to seed the progress tracker), and sha256 is the expected hex
    SHA256 checksum of its contents.
    """

    def __init__(self, url, size, sha256):
        self.url, self.size, self.sha256 = url, size, sha256
        tracker = Progress()
        tracker.setExpectedSize(size)
        self.progress = tracker


class DownloadService(ThreadedService):

    def __init__(self, reactor, basePath):
        self._basePath = basePath
        ThreadedService.__init__(self, reactor)

    def start(self):
        limit = time.time() - REMOVE_FILE_DELAY
        for name in os.listdir(self._basePath):
            path = os.path.join(self._basePath, name)
            try:
                mtime = os.path.getmtime(path)
                if mtime < limit:
                    logging.info("Removing old download: %s" % (path,))
                    os.unlink(path)
                else:
                    logging.debug("Preserving old download: %s" % (path,))
            except (IOError, OSError), e:
                logging.error("Couldn't remove old download: %s" % str(e))
        ThreadedService.start(self)

    @taskHandlerInThread(DownloadFileTask)
    def _downloadFile(self, task):
        logging.info("Download starting for %s" % (task.url,))

        def progress(partialSize, totalSize):
            task.progress.setPartialSize(partialSize)
            task.progress.setExpectedSize(totalSize)
            return task.progress

        if not HASH_RE.match(task.sha256):
            raise DownloadServiceError("Invalid checksum: %s" % (task.sha256,))

        localPath = os.path.join(self._basePath, task.sha256)
        for extension in (".tar.gz", ".gz"):
            if task.url.endswith(extension):
                localPath += extension
                break

        try:
            fetch(task.url, size=task.size,
                  resume=True, local_path=localPath,
                  progress=progress)
        except FetchError, e:
            logging.error("Download failed: %s" % (e,))
            raise DownloadServiceError(str(e))
        else:
            logging.info("Download finished.")

        logging.debug("Verifying checksum of downloaded file.")
        if not verifySHA256(localPath, task.sha256):
            message = "Checksum mismatch on downloaded file (expected %s)" \
                      % (task.sha256,)
            logging.error(message)
            os.unlink(localPath)
            raise DownloadServiceError(message)
        logging.info("Checksum of downloaded file matches.")

        if localPath.endswith(".tar.gz"):
            logging.debug("Uncompressing downloaded file (tar.gz).")
            extractPath = localPath + ".extract"
            if os.path.exists(extractPath):
                shutil.rmtree(extractPath)
            os.mkdir(extractPath)
            status, output = commands.getstatusoutput(
                "tar -xzv -f %s -C %s" % (localPath, extractPath))
            if status != 0:
                message = "Uncompression of file failed:\n%s" % (output,)
                logging.error(message)
                os.unlink(localPath)
                raise DownloadServiceError(message)
            extractPathContents = []
            for root, dirs, files in os.walk(extractPath):
                extractPathContents.extend(os.path.join(root, file)
                                           for file in files)
            if len(extractPathContents) != 1:
                raise DownloadServiceError(
                    "Uncompression of file failed: "
                    "tar.gz contains more than one file")
            os.rename(os.path.join(extractPath, extractPathContents[0]),
                      localPath[:-7])
            shutil.rmtree(extractPath)
            os.unlink(localPath)
            logging.debug("Finished uncompressing downloaded file.")
        elif localPath.endswith(".gz"):
            logging.debug("Uncompressing downloaded file (gz).")
            status, output = commands.getstatusoutput("gunzip %s" % localPath)
            if status != 0:
                message = "Uncompression of file failed:\n%s" % (output,)
                logging.error(message)
                os.unlink(localPath)
                raise DownloadServiceError(message)
            logging.debug("Finished uncompressing downloaded file.")
        localPath = localPath.split(".")[0]

        return localPath
