import re
import time

from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python.failure import Failure

from txzookeeper.client import ConnectionTimeoutException

from juju.errors import NoConnection, InvalidHost, InvalidUser
from juju.state.security import SecurityPolicyConnection
from juju.state.sshforward import forward_port, ClientTunnelProtocol

from .utils import PortWatcher, get_open_port

# Matches a "hostname:port" server address: group 1 is the host (any run of
# non-whitespace characters), group 2 the numeric port. A raw string keeps
# the regex escapes from being read as (invalid) string escapes.
SERVER_RE = re.compile(r"^(\S+):(\d+)$")


def called_aware_deferred_chain(result, d):
    """Forward a callback result to another deferred, if it can take it.

    This is a variation on the standard chainDeferred that first checks
    whether the target deferred has already fired, and that inspects the
    value to decide what to do with it:

     - A non true value is never forwarded, and nothing is returned.

     - If the target deferred has already been called, the value is
       returned untouched so it stays in the current callback chain.

     - Otherwise the value is handed to the target deferred and consumed
       (nothing is returned): a failure goes to the target's errback,
       anything else to its callback.

    @param result the callback value of the current deferred stack.
    @type ANY

    @param d Chain deferred to invoke
    @type C{twisted.internet.defer.Deferred}

    """
    # Non true values are dropped rather than forwarded.
    if not result:
        return None

    if not d.called:
        # Forwarding consumes the value: nothing is handed back to the
        # current callback chain.
        if isinstance(result, Failure):
            forward = d.errback
        else:
            forward = d.callback
        forward(result)
        return None

    # The target has already fired, leave the value in the current chain.
    return result


class SSHClient(SecurityPolicyConnection):
    """
    A zookeeper client which will internally handle an SSH tunnel
    to connect to the remote host.

    connect() starts an ssh port forwarding process towards the remote
    server, waits for the forwarded local port to open and then connects
    the underlying client to that local port. close() tears down both the
    client connection and the ssh process.
    """

    # Login name handed to forward_port for the ssh connection.
    remote_user = "ubuntu"
    # Handle of the ssh port forwarding process, None when no tunnel is up.
    _process = None

    # TODO We'll need things like set_ssh_username and set_ssh_keypath
    # in this method at some point, configuring the forwarding which
    # happens in connect().

    def _internal_connect(self, server=None, timeout=None, share=False):
        """Connect to the remote host provided via an ssh port forward.

        An SSH process is fired with port forwarding established from a
        free local port (picked with get_open_port) to the remote server.
        Once that local port is open the zookeeper client connects to it.

        @param server: Remote host to connect to, specified as
            hostname:port. self._servers is used when it is not given.
        @type string

        @param timeout: A timeout interval in seconds. It bounds the wait
            for the forwarded port, and what is left of it afterwards is
            given to the zookeeper connect.
            NOTE(review): the None default is not usable as is, the
            elapsed time arithmetic in _cb_tunnel_established needs a
            number.
        @type float

        @param share: Passed through to forward_port.
        @type bool

        @return: A deferred firing with the result of the parent class
            connect on success. It errbacks with the tunnel error, with
            ConnectionTimeoutException if the port watch fails, or with
            the zookeeper connect failure.
        """
        hostname, port = self._parse_servers(server or self._servers)
        started = time.time()

        # A deferred denoting when an error on the ssh tunnel occurs.
        tunnel_error_deferred = Deferred()

        # A deferred denoting when the tunnel has been established.
        tunnel_established_deferred = Deferred()

        # A deferred denoting the client has connected.
        client_connected_deferred = Deferred()

        protocol = ClientTunnelProtocol(self, tunnel_error_deferred)

        # Determine which port we'll be using.
        local_port = get_open_port()

        # Close the tunnel if any errors.
        client_connected_deferred.addErrback(self._eb_tunnel)

        # Setup an ssh process for port forwarding.
        self._process = forward_port(
            self.remote_user, local_port, hostname, int(port),
            process_protocol=protocol, share=share)

        # Zookeeper connect after the port forward is established.
        tunnel_established_deferred.addCallback(
            self._cb_tunnel_established, "localhost:%d" % local_port,
            timeout, started)

        # If the zk connect succeeds, propagate to connect deferred.
        tunnel_established_deferred.addCallback(
            called_aware_deferred_chain, client_connected_deferred)

        # Wait for the port to open.
        port_watcher = PortWatcher("localhost", local_port, timeout)
        port_open_deferred = port_watcher.async_wait()

        # Convert port watching errors to connection timeout.
        def port_watch_error(failure):
            raise ConnectionTimeoutException("could not connect")
        port_open_deferred.addErrback(port_watch_error)

        # On success, the tunnel is established.
        port_open_deferred.addCallback(
            called_aware_deferred_chain, tunnel_established_deferred)

        # If there is an error, then the port watch is done.
        client_connected_deferred.addErrback(port_watcher.stop)

        # Propagate errors if we're not connected. The chain helper only
        # forwards them while client_connected_deferred has not fired yet.
        tunnel_error_deferred.addErrback(
            called_aware_deferred_chain, client_connected_deferred)
        tunnel_established_deferred.addErrback(
            called_aware_deferred_chain, client_connected_deferred)
        port_open_deferred.addErrback(
            called_aware_deferred_chain, client_connected_deferred)

        return client_connected_deferred

    def _parse_servers(self, servers):
        """Extract a server host and port.

        @param servers: Server address, specified as hostname:port
        @type string

        @return: A (hostname, port) tuple, both items being strings.

        @raise ValueError: If no address is given, or if it is not of
            the hostname:port form.
        """
        match = SERVER_RE.match(servers) if servers else None
        if match is None:
            # Fail with an explicit message instead of an attribute error
            # on the missing match object.
            raise ValueError(
                "Invalid server address %r, expected hostname:port" % (
                    servers,))
        return match.groups()

    def _eb_tunnel(self, failure):
        """Close the client and tunnel on error, passing the failure on."""
        self.close()
        return failure

    def _cb_tunnel_established(self, result, host, timeout, start_time):
        """After the tunnel is established, connect the zookeeper client.

        @param result: Callback value of the tunnel deferred, unused.

        @param host: Local end of the tunnel, specified as hostname:port
        @type string

        @param timeout: The overall timeout interval in seconds.
        @type float

        @param start_time: time.time() value taken when the connection
            attempt began.
        @type float

        @return: The result of the parent class connect.

        @raise ConnectionTimeoutException: If setting up the tunnel used
            up the whole timeout.
        """
        # Only the part of the timeout not spent on the tunnel remains
        # available for the zookeeper connect.
        now = time.time()
        new_timeout = timeout - (now - start_time)
        if new_timeout <= 0:
            raise ConnectionTimeoutException(
                "could not connect before timeout")
        return super(SSHClient, self).connect(host, new_timeout)

    @inlineCallbacks
    def connect(self, server=None, timeout=30, share=False):
        """Probe ZK is accessible via ssh tunnel, return client on success.

        Connection attempts failing with NoConnection are retried until
        the timeout expires.

        @param server: Remote host to connect to, specified as
            hostname:port
        @type string

        @param timeout: Overall time allowed, in seconds, across all the
            attempts.
        @type float

        @param share: Passed through to forward_port.
        @type bool

        @return: A deferred firing with this client once connected.

        @raise InvalidHost: Re-raised as is, without any retry.
        @raise InvalidUser: Re-raised as is, without any retry.
        @raise ConnectionTimeoutException: If no connection could be made
            before the timeout.
        """
        until = time.time() + timeout
        num_retries = 0
        while time.time() < until:
            num_retries += 1
            try:
                yield self._internal_connect(
                    server, timeout=until - time.time(), share=share)
            except ConnectionTimeoutException:
                # Reraises implicitly, but with the number of retries
                # (see the outside of this loop); this circumstance
                # would occur if the port watcher timed out before we
                # got anything from the tunnel
                break
            except (InvalidHost, InvalidUser):
                # No point in retrying if the host itself is invalid,
                # or if the user doesn't have a login
                self.close()
                raise
            except NoConnection:
                # Otherwise retry after ssh tunnel forwarding failures
                self.close()
            else:
                returnValue(self)
        self.close()
        # we raise ConnectionTimeoutException (rather than one of our own, with
        # the same meaning) to maintain ZookeeperClient interface
        raise ConnectionTimeoutException(
            "could not connect before timeout after %d retries" % num_retries)

    def close(self):
        """Close the zookeeper connection, and the associated ssh tunnel.

        The ssh teardown can be repeated safely, and copes with an ssh
        process that has already exited on its own.
        """
        # Local import, so that the module level imports stay untouched.
        from twisted.internet.error import ProcessExitedAlready

        super(SSHClient, self).close()

        # Forget the handle first: a failure below must not leave a stale
        # process reference behind for the next close().
        process, self._process = self._process, None
        if process is None:
            return

        try:
            process.signalProcess("TERM")
        except ProcessExitedAlready:
            # NOTE(review): assumes forward_port returns a twisted process
            # transport, whose signalProcess raises this once the process
            # is gone (e.g. after a failed tunnel). Nothing left to signal.
            pass
        process.loseConnection()
