"""
(c) 2004 Dave Smith and Fabio Forno
This SOCKSv5 module is mostly based on Dave Smith's Proxy65 Jabber component; Dave kindly gave me permission to use it in twibber.
Original source http://proxy65.jabberstudio.org/
"""

from twisted.internet import reactor, protocol, defer, error
from twisted.python import failure, dispatch
import struct, types
pack=struct.pack
unpack=struct.unpack
import sha, sys

# States of the handshake parsers in S5Server and S5Client (kept in
# self.state). Both classes start in STATE_INITIAL.
STATE_INITIAL = 0
STATE_AUTH    = 1
STATE_REQUEST = 2
STATE_READY   = 3
STATE_AUTH_USERPASS = 4
STATE_LAST    = 5

# Extra state numbered past STATE_LAST.
# NOTE(review): not referenced anywhere else in the code visible here.
STATE_CONNECT_PENDING = STATE_LAST + 1

# Protocol version byte placed at the start of the messages this module sends.
SOCKS5_VER = 0x05

# Address type codes (fourth byte of requests and replies).
ADDR_IPV4 = 0x01
ADDR_DOMAINNAME = 0x03
ADDR_IPV6 = 0x04

# Request command codes; S5Server only accepts CMD_CONNECT.
CMD_CONNECT = 0x01
CMD_BIND = 0x02
CMD_UDPASSOC = 0x03

# Authentication method codes exchanged during negotiation.
# AUTHMECH_INVALID is what the server answers when no offered method is usable.
AUTHMECH_ANON = 0x00
AUTHMECH_USERPASS = 0x02
AUTHMECH_INVALID = 0xFF

# Reply codes (second byte of the server's reply to a request).
REPLY_SUCCESS = 0x00
REPLY_GENERAL_FAILURE = 0x01
REPLY_CONN_NOT_ALLOWED = 0x02
REPLY_NETWORK_UNREACHABLE = 0x03
REPLY_HOST_UNREACHABLE = 0x04
REPLY_CONN_REFUSED = 0x05
REPLY_TTL_EXPIRED = 0x06
REPLY_CMD_NOT_SUPPORTED = 0x07
REPLY_ADDR_NOT_SUPPORTED = 0x08

# Event names published through the server factory's publishEvent().
EVT_CONNECT_DONE="socksv5_connect_done"
EVT_CONNECT_ERROR="socksv5_connect_error"

def spam(*args, **kw):
    """No-op handler: accept any positional/keyword arguments and ignore them."""
    return None

class S5Server(protocol.Protocol):
    """Server side of a minimal SOCKS5 endpoint that streams a queued file.

    The "domain name" carried by the CONNECT request is not resolved: it is
    used as a key into ``self.factory.files``. When the key is known, a
    success reply is sent and, if ``self.factory.autostart`` is true, the
    file registered under that key is written to the peer through the
    transport's producer interface.
    """

    def __init__(self):
        """Set up the parser state and the protocol options."""
        self.state = STATE_INITIAL
        self.buf = ''            # received data not parsed yet
        self.bufsize = 163484    # largest chunk read from the file per write
        self.authMechs = [AUTHMECH_ANON]
        self.addr_types = [ADDR_DOMAINNAME]
        self.commands = [CMD_CONNECT]
        self.finished = 0        # set to 1 once the file has been fully sent

    def _parseNegotiation(self):
        """Parse the method-selection message and answer with the chosen method.

        If the message is still incomplete the buffer is left untouched and
        the parse is retried when more data arrives.
        """
        try:
            # Parse out data
            ver, nmethod = unpack('!BB', self.buf[:2])
            methods = unpack('%dB' % nmethod, self.buf[2:nmethod + 2])

            # Ensure version is correct
            if ver != 5:
                self.transport.write(pack('!BB', SOCKS5_VER, AUTHMECH_INVALID))
                self.transport.loseConnection()
                return

            # Trim off front of the buffer
            self.buf = self.buf[nmethod + 2:]

            # Check for supported auth mechs
            for m in self.authMechs:
                if m in methods:
                    # Update internal state, according to selected method
                    if m == AUTHMECH_ANON:
                        self.state = STATE_REQUEST
                    elif m == AUTHMECH_USERPASS:
                        self.state = STATE_AUTH_USERPASS
                    # Complete negotiation w/ this method
                    self.transport.write(pack('!BB', SOCKS5_VER, m))
                    return

            # No supported mechs found, notify client and close the connection
            self.transport.write(pack('!BB', SOCKS5_VER, AUTHMECH_INVALID))
            self.transport.loseConnection()
        except struct.error:
            # Not enough data yet; wait for the next dataReceived call.
            pass

    def _parseUserPass(self):
        """Parse a username/password message and report the outcome.

        NOTE(review): ``self.authenticateUserPass`` is not defined in this
        class; presumably a subclass enabling AUTHMECH_USERPASS provides it.
        """
        try:
            # Parse out data. Every field is read through a slice so that a
            # truncated message raises struct.error (handled below) instead
            # of an uncaught IndexError.
            ver, ulen = unpack('BB', self.buf[:2])
            uname, = unpack('%ds' % ulen, self.buf[2:ulen + 2])
            plen, = unpack('B', self.buf[ulen + 2:ulen + 3])
            password, = unpack('%ds' % plen,
                               self.buf[ulen + 3:ulen + 3 + plen])
            # Trim off front of the buffer
            self.buf = self.buf[3 + ulen + plen:]
            # Authenticate the user
            if self.authenticateUserPass(uname, password):
                # Signal success
                self.state = STATE_REQUEST
                self.transport.write(pack('!BB', SOCKS5_VER, 0x00))
            else:
                # Signal failure
                self.transport.write(pack('!BB', SOCKS5_VER, 0x01))
                self.transport.loseConnection()
        except struct.error:
            # Not enough data yet; wait for the next dataReceived call.
            pass

    def _parseRequest(self):
        """Parse the client request; only CONNECT to a domain name is accepted.

        Any other address type or command gets an error reply and the
        connection is closed. An incomplete request is left in the buffer.
        """
        try:
            # Parse out the fixed header
            ver, cmd, rsvd, atype = unpack('!BBBB', self.buf[:4])

            # We expect an address of type DOMAIN_NAME
            if atype != ADDR_DOMAINNAME:
                self.sendErrorReply(REPLY_ADDR_NOT_SUPPORTED)
                return

            # Address: one length byte, the name, then a 2-byte port.
            # Slice exactly that many bytes: struct.unpack needs an exact
            # length, so trailing data must not be handed to it, and a short
            # buffer must raise struct.error rather than IndexError.
            alen, = unpack('!B', self.buf[4:5])
            addr, port = unpack('!%dsH' % alen, self.buf[5:7 + alen])
            self.buf = self.buf[7 + alen:]

            # We expect a CONNECT
            if cmd != CMD_CONNECT:
                # Send a not supported error
                self.sendErrorReply(REPLY_CMD_NOT_SUPPORTED)
                return

            self._processConnect(addr)

        except struct.error:
            # Not enough data yet; wait for the next dataReceived call.
            return None

    def _processConnect(self, addr):
        """Answer a CONNECT for *addr*.

        If *addr* is a key of the factory's file pool, send a success reply
        echoing the address (port 0) and call connectMade(); otherwise
        publish EVT_CONNECT_ERROR and send a general-failure reply.
        """
        f = self.factory
        if addr in f.files:
            # send the reply
            answer = pack("!BBBB", SOCKS5_VER, REPLY_SUCCESS, 0x0,
                          ADDR_DOMAINNAME)
            answer += pack("!B", len(addr)) + addr + pack("!H", 0x0)
            self.transport.write(answer)
            # Handshake is over: from now on dataReceived must discard any
            # incoming bytes instead of re-parsing them as a new request.
            self.state = STATE_READY
            self.connectMade(addr)
        else:
            f.publishEvent(EVT_CONNECT_ERROR, addr)
            self.sendErrorReply(REPLY_GENERAL_FAILURE)  # XXX correct?

    def sendErrorReply(self, errorcode):
        """Send a reply carrying *errorcode* and close the connection.

        The bound address in the reply is a zeroed IPv4 address and port.
        """
        result = pack('!BBBBIH', SOCKS5_VER, errorcode, 0, 1, 0, 0)
        self.transport.write(result)
        self.transport.loseConnection()

    def resumeProducing(self):
        """Write the next chunk of the file (at most ``bufsize`` bytes).

        When nothing is left to read, unregister from the transport, close
        the connection, drop the file from the pool and fire its deferred
        with the address.
        """
        chunk = min(self.bytes, self.bufsize)
        self.bytes -= chunk
        data = self.file.read(chunk)
        if data:
            self.transport.write(data)
        else:
            self.finished = 1
            f = self.factory
            self.transport.unregisterProducer()
            self.transport.loseConnection()
            d = f.files.getDeferred(self.addr)
            f.files.delete(self.addr)
            d.callback(self.addr)  # notify file sent with success

    def pauseProducing(self):
        """Ignored."""
        pass  # XXX I don't know why but sometimes it is called

    def stopProducing(self):
        """Close the file being sent."""
        self.file.close()

    def connectMade(self, addr):
        """Publish EVT_CONNECT_DONE and start streaming if autostart is set."""
        f = self.factory
        f.publishEvent(EVT_CONNECT_DONE, addr)
        if f.autostart:
            self.startStream(addr)

    def startStream(self, addr):
        """Begin sending the file registered under *addr*.

        The pool entry is read as (file, offset, byte count); the file is
        positioned at the offset and this protocol is registered with the
        transport as a producer (streaming flag 0).
        """
        f = self.factory
        filedata = f.files.get(addr)
        self.file = filedata[0]
        self.file.seek(filedata[1])
        self.bytes = filedata[2]
        self.addr = addr
        self.transport.registerProducer(self, 0)

    def dataReceived(self, buf):
        """Accumulate incoming data and run the parser for the current state.

        The state checks are sequential (not elif) so that several handshake
        messages arriving in one chunk are all processed. Once the state is
        STATE_READY incoming data is discarded.
        """
        if self.state == STATE_READY:
            # Nothing is expected from the peer any more; drop everything.
            self.buf = ""
            return

        self.buf = self.buf + buf
        if self.state == STATE_INITIAL:
            self._parseNegotiation()
        if self.state == STATE_AUTH_USERPASS:
            self._parseUserPass()
        if self.state == STATE_REQUEST:
            self._parseRequest()
       
class FilePool:
    """Registry of files waiting to be sent, keyed by address string.

    Each entry is a ``(file, offset, nbytes)`` tuple: S5Server.startStream
    seeks the file to ``offset`` and sends ``nbytes`` bytes. A Deferred is
    kept per address; S5Server fires it when the file has been sent.
    """

    def __init__(self):
        """Create an empty pool."""
        self._files = {}
        self._deferreds = {}

    def makeAddress(self, sid, jid1, jid2):
        """Return the hex SHA-1 digest of ``sid + jid1 + jid2``."""
        return sha.new(sid + jid1 + jid2).hexdigest()

    def _byteRange(self, file, start, end):
        """Return ``(start, nbytes)`` for the slice ``[start, end)`` of *file*.

        ``end == -1`` means "up to the end of the file"; the size is measured
        by seeking to the end, and the file is then rewound to position 0.
        The count is clamped at 0 because a negative count handed to
        ``file.read`` would read the whole file.
        """
        if end == -1:
            file.seek(0, 2)
            end = file.tell()
            file.seek(0)
        return (start, max(end - start, 0))

    def insert(self, addr, file, start=0, end=-1):
        """Register *file* under *addr* and return its completion Deferred.

        *start* is the offset of the first byte to send and *end* the offset
        just past the last one (-1: end of file). With the defaults the
        whole file is sent.
        """
        d = defer.Deferred()
        self._files[addr] = (file,) + self._byteRange(file, start, end)
        self._deferreds[addr] = d
        return d

    def get(self, addr):
        """Return the ``(file, offset, nbytes)`` entry; KeyError if unknown."""
        return self._files[addr]

    def getDeferred(self, addr):
        """Return the Deferred created by insert(); KeyError if unknown."""
        return self._deferreds[addr]

    def delete(self, addr):
        """Close the file registered under *addr* and forget the entry."""
        f = self._files[addr]  # make sure we close the file
        f[0].close()
        del self._files[addr]
        del self._deferreds[addr]

    def __contains__(self, addr):
        """Tell whether a file is registered under *addr*."""
        return addr in self._files
    
class S5Client(protocol.Protocol):
    """Client side of the SOCKS5 handshake used to download a file.

    After the handshake, everything received is written to
    ``self.factory.ofile``. When the connection is lost, the factory's
    ``d_file_received`` deferred is fired: callback with the output file if
    the request had been accepted, errback otherwise.
    """

    def __init__(self):
        """Set up the parser state."""
        self.state = STATE_INITIAL
        self.buf = ''
        self.authMechs = [AUTHMECH_ANON]  # , AUTHMECH_USERPASS]
        self.failed = 1  # cleared once the server accepts the request

    def connectionMade(self):
        """Send the method-selection message listing ``self.authMechs``."""
        # write the negotiation string
        neg = pack("!B", SOCKS5_VER) + pack("!B", len(self.authMechs))
        for m in self.authMechs:
            neg += pack("!B", m)
        self.transport.write(neg)

    def _parseNegotiation(self):
        """Handle the server's method choice and send the next message."""
        try:
            ver, method = unpack("!BB", self.buf[:2])
            self.buf = self.buf[2:]
            if method == AUTHMECH_ANON:
                self._sendRequest()
                self.state = STATE_REQUEST
            elif method == AUTHMECH_USERPASS:
                self._sendUserPass()
                # Must be the state dataReceived tests for.
                self.state = STATE_AUTH_USERPASS
            else:
                self.transport.loseConnection()
        except struct.error:
            pass  # we did not read enough data

    def _parseUserPass(self):
        """Handle the server's answer to username/password authentication.

        Not implemented: this client only offers anonymous access.
        """
        raise NotImplementedError("username/password authentication")

    def _parseRequest(self):
        """Handle the server's reply to the CONNECT request.

        On success the echoed address is consumed, the state becomes
        STATE_READY and the factory's ``d_connect_made`` deferred is called
        back with the address. On any other reply code the connection is
        closed.

        NOTE(review): the address is always decoded as length-prefixed
        (domain-name form), whatever address type the reply announces.
        """
        try:
            ver, rep, rsv, atype = unpack("!BBBB", self.buf[0:4])
            if rep != REPLY_SUCCESS:
                self.transport.loseConnection()
                return
            # Read the length through a slice so that a reply cut right
            # after the header raises struct.error, not IndexError.
            alen, = unpack("!B", self.buf[4:5])
            addr, port = unpack('!%dsH' % (alen,), self.buf[5:alen + 7])
            self.buf = self.buf[7 + alen:]  # 4+1+alen+2
            self.state = STATE_READY
            self.failed = 0
            self.factory.d_connect_made.callback(addr)
        except struct.error:
            pass  # we did not read enough bytes

    def _sendRequest(self):
        """Send a CONNECT request for the hash of sid + initiator + target."""
        f = self.factory
        req = pack("!BBBB", SOCKS5_VER, CMD_CONNECT, 0x0, ADDR_DOMAINNAME)
        addr = sha.new(f.sid + f.initiating_jid + f.target_jid).hexdigest()
        req += pack("!B", len(addr)) + addr + pack("!H", 0x0)
        self.transport.write(req)

    def _sendUserPass(self):
        """Send username/password credentials (not implemented)."""
        # NotImplemented is a comparison sentinel, not an exception class.
        raise NotImplementedError("username/password authentication")

    def dataReceived(self, buf):
        """Accumulate data, advance the handshake, then store the payload.

        The state checks are sequential so that a chunk containing both the
        end of the handshake and the beginning of the file is fully handled.
        """
        self.buf = self.buf + buf
        if self.state == STATE_INITIAL:
            self._parseNegotiation()
        if self.buf and self.state == STATE_AUTH_USERPASS:
            self._parseUserPass()
        if self.buf and self.state == STATE_REQUEST:
            self._parseRequest()
        if self.buf and self.state == STATE_READY:
            self.factory.ofile.write(self.buf)
            self.buf = ""

    def connectionLost(self, reason):
        """Fire ``d_file_received`` according to how far the handshake got.

        If the request was never accepted the output file is closed and the
        deferred errbacks; otherwise it is called back with the output file.
        """
        f = self.factory
        if self.failed:
            f.ofile.close()
            # XXX which error to pass
            f.d_file_received.errback(failure.Failure(error.UserError()))
        else:
            f.d_file_received.callback(f.ofile)
    
class S5ClientFactory(protocol.ClientFactory):
    """Factory for S5Client connections downloading one file.

    Holds the session id and the two JIDs the client hashes into its CONNECT
    address, the file-like object receiving the data, and two deferreds:
    ``d_connect_made`` and ``d_file_received``.
    """

    protocol = S5Client

    def __init__(self, sid, initiating_jid, target_jid, ofile=sys.stdout):
        """Store the transfer parameters and create the two deferreds."""
        self.d_connect_made = defer.Deferred()
        self.d_file_received = defer.Deferred()
        self.ofile = ofile
        self.sid = sid
        self.initiating_jid = initiating_jid
        self.target_jid = target_jid

    def clientConnectionFailed(self, connector, reason):
        """Forward a connection failure to ``d_file_received``."""
        self.d_file_received.errback(reason)
 
#XXX the EventDispatcher may not be necessary
class S5ServerFactory(protocol.ServerFactory, dispatch.EventDispatcher):
    """Factory for S5Server connections, owning the pool of files to serve.

    Publishes EVT_CONNECT_DONE / EVT_CONNECT_ERROR through the
    EventDispatcher interface.
    """

    protocol = S5Server

    def __init__(self, autostart=1):
        """Create the factory.

        autostart -- when true, S5Server starts streaming the file as soon
                     as the CONNECT request has been accepted.
        """
        dispatch.EventDispatcher.__init__(self)
        self.autostart = autostart
        self.files = FilePool()
        # this fixes a bug of twisted.python.dispatch which requires at least
        # a handler for each event
        self.registerHandler(EVT_CONNECT_DONE, spam)
        self.registerHandler(EVT_CONNECT_ERROR, spam)

    # at this level we should use only the sid
    def queueFile(self, filepath, sid, jid_from, jid_to):
        """Queue the file at *filepath* for the given session.

        Returns the Deferred created by the pool for this file.
        """
        fp = self.files
        # Binary mode: the bytes are sent verbatim over the socket, so no
        # newline translation may be applied while reading.
        d = fp.insert(fp.makeAddress(sid, jid_from, jid_to),
                      open(filepath, 'rb'))
        return d
        
def downloadFile(host, port, sid, initiating_jid, target_jid, ofile=sys.stdout):
    """Connect to host:port and download the file for the given session.

    The received data is written to *ofile*. Returns the factory's
    ``d_file_received`` deferred.
    """
    factory = S5ClientFactory(sid, initiating_jid, target_jid, ofile)
    reactor.connectTCP(host, port, factory)
    return factory.d_file_received

# ---------------------------------------------------------
# Tests
def _test_server():
    """Manual test: serve this source file on TCP port 12345."""
    f = S5ServerFactory()
    # Queue the file in the factory's own pool: the factory's constructor
    # takes the autostart flag, not a pool, and S5Server only ever looks in
    # factory.files.
    pool = f.files
    pool.insert(pool.makeAddress('0', '0', '0'), open(__file__, 'rb'))
    reactor.listenTCP(12345, f)
    reactor.run()

def _test_client():
    """Manual test: connect to localhost:12345 and dump the file to stdout."""
    factory = S5ClientFactory('0', '0', '0')
    reactor.connectTCP('localhost', 12345, factory)
    reactor.run()

def _test_download():
    """Manual test: download from localhost:12345 into an in-memory buffer."""
    from StringIO import StringIO
    sink = StringIO()
    # The errback is chained after the callback, so it also sees errors
    # raised while printing the result.
    downloadFile('localhost', 12345, '0', '0', '0', sink) \
        .addCallback(_test_download_result) \
        .addErrback(_test_download_error)
    reactor.run()
    
def _test_download_result(file):
    """Callback for _test_download: print the received data, stop the reactor."""
    print(file.getvalue())
    reactor.stop()
    
def _test_download_error(failure):
    """Errback for _test_download: print the failure, stop the reactor."""
    print(failure)
    reactor.stop()

if __name__ == '__main__':
    # Usage: <script> client|server|download
    import sys
    _tests = {
        'client': _test_client,
        'server': _test_server,
        'download': _test_download,
    }
    _mode = None
    if len(sys.argv) > 1:
        _mode = sys.argv[1]
    if _mode in _tests:
        _tests[_mode]()
    else:
        # Missing or unknown mode: explain instead of failing with IndexError
        # or silently doing nothing.
        sys.stderr.write("usage: %s client|server|download\n" % sys.argv[0])
        sys.exit(1)
