# -*- coding: utf-8 -*-
'''
Module for gathering and managing network information
'''

# Import python libs
import re
import logging

# Import salt libs
import salt.utils
from salt.exceptions import CommandExecutionError


log = logging.getLogger(__name__)


def __virtual__():
    '''
    Load this module everywhere except Windows.

    Windows is served by a separate, platform-specific module, so this
    POSIX implementation steps aside there.
    '''
    return not salt.utils.is_windows()


def ping(host):
    '''
    Run ``ping -c 4`` against ``host`` and return the command's output.

    The host string is passed through ``salt.utils.network.sanitize_host``
    before it is interpolated into the shell command.

    CLI Example:

    .. code-block:: bash

        salt '*' network.ping archlinux.org
    '''
    target = salt.utils.network.sanitize_host(host)
    return __salt__['cmd.run']('ping -c 4 {0}'.format(target))


# FIXME: Does not work with: netstat 1.42 (2001-04-15) from net-tools
# 1.6.0 (Ubuntu 10.10)
def _netstat_linux():
    '''
    Return netstat information for Linux distros

    Runs ``netstat -tulpnea`` and returns one dict per TCP/UDP row with the
    keys ``proto``, ``recv-q``, ``send-q``, ``local-address``,
    ``remote-address``, ``user``, ``inode`` and ``program``. TCP rows always
    carry ``state``; UDP rows carry it only when netstat prints one (connected
    UDP sockets). Rows too short to contain every column are skipped.
    '''
    ret = []
    cmd = 'netstat -tulpnea'
    out = __salt__['cmd.run'](cmd, output_loglevel='debug')
    for line in out.splitlines():
        comps = line.split()
        if line.startswith('tcp'):
            if len(comps) < 9:
                continue
            ret.append({
                'proto': comps[0],
                'recv-q': comps[1],
                'send-q': comps[2],
                'local-address': comps[3],
                'remote-address': comps[4],
                'state': comps[5],
                'user': comps[6],
                'inode': comps[7],
                # The "PID/Program name" column may itself contain spaces
                # (e.g. "1234/nginx: master"), so keep every remaining field.
                'program': ' '.join(comps[8:])})
        elif line.startswith('udp'):
            if len(comps) < 8:
                continue
            entry = {
                'proto': comps[0],
                'recv-q': comps[1],
                'send-q': comps[2],
                'local-address': comps[3],
                'remote-address': comps[4]}
            rest = comps[5:]
            # '-n' makes netstat print numeric user IDs, so a non-numeric
            # field here is the optional state column (e.g. ESTABLISHED for
            # a connected UDP socket) rather than the user.
            if not rest[0].isdigit():
                entry['state'] = rest.pop(0)
            if len(rest) < 3:
                continue
            entry['user'] = rest[0]
            entry['inode'] = rest[1]
            entry['program'] = ' '.join(rest[2:])
            ret.append(entry)
    return ret


def _netinfo_openbsd():
    '''
    Map network sockets to their owning processes using ``fstat`` output.

    The result is nested as
    ``{local_addr: {remote_addr: {proto: {pid: {'user': ..., 'cmd': ...}}}}}``.
    Addresses have their final ``:`` rewritten to ``.`` so they line up with
    netstat's notation, IPv6 brackets are removed from the local address, and
    a socket with no peer uses ``'*.*'`` as its remote address.
    '''
    socket_re = re.compile(
        r'internet(6)? (?:stream tcp 0x\S+ (\S+)|dgram udp (\S+))'
        r'(?: [<>=-]+ (\S+))?$'
    )
    fstat_out = __salt__['cmd.run']('fstat', output_loglevel='debug')
    owners = {}
    for row in fstat_out.splitlines():
        try:
            user, command, pid, _, details = row.split(None, 4)
            ipv6, tcp_addr, udp_addr, peer = socket_re.match(details).groups()
        except (ValueError, AttributeError):
            # Too few columns (ValueError) or not an internet socket line
            # (the regex did not match -> AttributeError on None): skip it.
            continue

        family = ipv6 or ''
        if tcp_addr:
            proto, local_addr = 'tcp' + family, tcp_addr
        else:
            proto, local_addr = 'udp' + family, udp_addr

        if ipv6:
            # fstat brackets non-wildcard IPv6 addresses to separate them from
            # the port; netstat does not, so drop the brackets.
            local_addr = local_addr.replace('[', '').replace(']', '')

        local_addr = '.'.join(local_addr.rsplit(':', 1))
        if peer is None:
            remote_addr = '*.*'
        else:
            remote_addr = '.'.join(peer.rsplit(':', 1))

        by_remote = owners.setdefault(local_addr, {})
        by_proto = by_remote.setdefault(remote_addr, {})
        by_pid = by_proto.setdefault(proto, {})
        proc = by_pid.setdefault(pid, {})
        proc['user'] = user
        proc['cmd'] = command
    return owners


def _netinfo_freebsd_netbsd():
    '''
    Get process information for network connections using sockstat

    Returns a nested dict keyed as
    ``{local_addr: {remote_addr: {proto: {pid: {'user': ..., 'cmd': ...}}}}}``
    where the final ``:`` of each address is rewritten to ``.`` to match
    netstat's notation. Lines that do not have exactly the seven expected
    columns (USER COMMAND PID FD PROTO LOCAL FOREIGN) are skipped instead of
    aborting the whole lookup.
    '''
    ret = {}
    # NetBSD requires '-n' to disable port-to-service resolution
    out = __salt__['cmd.run'](
        'sockstat -46 {0} | tail -n+2'.format(
            '-n' if __grains__['kernel'] == 'NetBSD' else ''
        ),
        output_loglevel='debug'
    )
    for line in out.splitlines():
        try:
            user, cmd, pid, _, proto, local_addr, remote_addr = line.split()
        except ValueError:
            # Unexpected column count; ignore the line rather than failing
            # every netstat call on this minion.
            log.debug('Skipping unparsable sockstat line: {0}'.format(line))
            continue
        local_addr = '.'.join(local_addr.rsplit(':', 1))
        remote_addr = '.'.join(remote_addr.rsplit(':', 1))
        proc = ret.setdefault(
            local_addr, {}).setdefault(
                remote_addr, {}).setdefault(
                    proto, {}).setdefault(
                        pid, {})
        proc['user'] = user
        proc['cmd'] = cmd
    return ret


def _ppid():
    '''
    Build a mapping of every process ID to its parent process ID.

    Both keys and values are the string fields printed by
    ``ps -ax -o pid,ppid`` (header line removed with ``tail``).
    '''
    ps_out = __salt__['cmd.run'](
        'ps -ax -o pid,ppid | tail -n+2', output_loglevel='debug')
    parents = {}
    for row in ps_out.splitlines():
        child, parent = row.split()
        parents[child] = parent
    return parents


def _netstat_bsd():
    '''
    Return netstat information for BSD flavors

    Each entry carries ``proto``, ``recv-q``, ``send-q``, ``local-address``
    and ``remote-address``; TCP entries also carry ``state``. When the owning
    process can be found in sockstat (FreeBSD/NetBSD) or fstat (OpenBSD)
    output, ``user`` and ``program`` (``"<pid>/<command>"``) are added too.
    '''
    kernel = __grains__['kernel']
    ret = []
    if kernel == 'NetBSD':
        for addr_family in ('inet', 'inet6'):
            cmd = 'netstat -f {0} -an | tail -n+3'.format(addr_family)
            out = __salt__['cmd.run'](cmd, output_loglevel='debug')
            for line in out.splitlines():
                comps = line.split()
                entry = {
                    'proto': comps[0],
                    'recv-q': comps[1],
                    'send-q': comps[2],
                    'local-address': comps[3],
                    'remote-address': comps[4]
                }
                if entry['proto'].startswith('tcp'):
                    entry['state'] = comps[5]
                ret.append(entry)
    else:
        # Lookup TCP connections
        cmd = 'netstat -p tcp -an | tail -n+3'
        out = __salt__['cmd.run'](cmd, output_loglevel='debug')
        for line in out.splitlines():
            comps = line.split()
            ret.append({
                'proto': comps[0],
                'recv-q': comps[1],
                'send-q': comps[2],
                'local-address': comps[3],
                'remote-address': comps[4],
                'state': comps[5]})
        # Lookup UDP connections
        cmd = 'netstat -p udp -an | tail -n+3'
        out = __salt__['cmd.run'](cmd, output_loglevel='debug')
        for line in out.splitlines():
            comps = line.split()
            ret.append({
                'proto': comps[0],
                'recv-q': comps[1],
                'send-q': comps[2],
                'local-address': comps[3],
                'remote-address': comps[4]})

    # Add in user and program info
    ppid = _ppid()
    if kernel == 'OpenBSD':
        netinfo = _netinfo_openbsd()
    elif kernel in ('FreeBSD', 'NetBSD'):
        netinfo = _netinfo_freebsd_netbsd()
    else:
        # No process-ownership source for other kernels; return entries as-is
        # instead of hitting an unbound name below.
        netinfo = {}
    for entry in ret:
        try:
            # Per-pid owner info for this exact connection
            ptr = netinfo[entry['local-address']][entry['remote-address']][
                entry['proto']]
        except KeyError:
            continue
        # pid-to-ppid mappings restricted to pids that hold this connection.
        # .items() (not .iteritems()) keeps this working on Python 3.
        conn_ppid = dict((pid, parent) for pid, parent in ppid.items()
                         if pid in ptr)
        # Master pid: the one whose parent does not also hold the connection,
        # i.e. the top of the process tree among its holders.
        master_pid = next(
            (pid for pid, parent in conn_ppid.items() if parent not in ptr),
            None)
        if master_pid is None:
            continue
        entry['user'] = ptr[master_pid]['user']
        entry['program'] = '/'.join((master_pid, ptr[master_pid]['cmd']))
    return ret


def netstat():
    '''
    Return information on open ports and states

    Dispatches on the ``kernel`` grain: Linux uses ``netstat -tulpnea``; the
    BSDs combine netstat output with owner information.

    .. note::
        On BSD minions, the output contains PID info (where available) for each
        netstat entry, fetched from sockstat/fstat output.

    .. versionchanged:: 2014.1.4
        Added support for OpenBSD, FreeBSD, and NetBSD

    CLI Example:

    .. code-block:: bash

        salt '*' network.netstat
    '''
    kernel = __grains__['kernel']
    if kernel == 'Linux':
        return _netstat_linux()
    if kernel in ('OpenBSD', 'FreeBSD', 'NetBSD'):
        return _netstat_bsd()
    raise CommandExecutionError('Not yet supported on this platform')


def active_tcp():
    '''
    Report the minion's active TCP connections.

    Thin wrapper that returns whatever ``salt.utils.network.active_tcp``
    produces (a dict, per the original description).

    CLI Example:

    .. code-block:: bash

        salt '*' network.active_tcp
    '''
    network_utils = salt.utils.network
    return network_utils.active_tcp()


def _parse_traceroute_version(version_output):
    '''
    Parse ``traceroute --version`` output into a three-item version list.

    Numeric components become ``int``; non-numeric ones (e.g. the
    ``4a12+FreeBSD`` part of BSD/Darwin versions) stay strings. A component
    that is absent (e.g. ``version 2.1``) is reported as ``0``, and output
    that does not contain a version at all yields ``[0, 0, 0]``.
    '''
    # Linux traceroute version looks like:
    #   Modern traceroute for Linux, version 2.0.19, Dec 10 2012
    # Darwin and FreeBSD traceroute version looks like:
    #   Version 1.4a12+[FreeBSD|Darwin]
    found = re.findall(r'.*[Vv]ersion (\d+)\.([\w\+]+)\.*(\w*)', version_output)
    if not found:
        return [0, 0, 0]
    log.debug('traceroute_version_raw: {0}'.format(found[0]))
    version = []
    for part in found[0]:
        if not part:
            # findall always yields all three groups, so a missing component
            # shows up as '' and must become 0 to be compared numerically.
            version.append(0)
            continue
        try:
            version.append(int(part))
        except ValueError:
            version.append(part)
    log.debug('traceroute_version: {0}'.format(version))
    return version


def traceroute(host):
    '''
    Performs a traceroute to a 3rd party host

    Returns a list with one dict per parsed hop; an empty list is returned if
    traceroute is not installed. The keys of each hop depend on the output
    format, which is chosen from ``traceroute --version``: BSD/Darwin builds,
    modern Linux traceroute (2.x), or the legacy format.

    CLI Example:

    .. code-block:: bash

        salt '*' network.traceroute archlinux.org
    '''
    ret = []
    if not salt.utils.which('traceroute'):
        log.info('This minion does not have traceroute installed')
        return ret

    cmd = 'traceroute {0}'.format(salt.utils.network.sanitize_host(host))

    out = __salt__['cmd.run'](cmd)

    # Parse version of traceroute
    traceroute_version = _parse_traceroute_version(
        __salt__['cmd.run']('traceroute --version'))

    # BSD-derived builds embed the OS name in the minor version component.
    bsd_style = ('Darwin' in str(traceroute_version[1]) or
                 'FreeBSD' in str(traceroute_version[1]))

    for line in out.splitlines():
        if ' ' not in line:
            continue
        if line.startswith('traceroute'):
            continue

        if bsd_style:
            hop = re.findall(r'\s*(\d*)\s+(.*)\s+\((.*)\)\s+(.*)$', line)
            if not hop:
                hop = re.findall(r'\s*(\d*)\s+(\*\s+\*\s+\*)', line)
            if not hop:
                # Neither a regular hop nor an all-timeout hop; skip the line
                # instead of aborting the whole traceroute with IndexError.
                continue
            traceline = hop[0]

            log.debug('traceline: {0}'.format(traceline))
            delays = re.findall(r'(\d+\.\d+)\s*ms', str(traceline))

            try:
                if traceline[1] == '* * *':
                    result = {
                        'count': traceline[0],
                        'hostname': '*'
                    }
                else:
                    result = {
                        'count': traceline[0],
                        'hostname': traceline[1],
                        'ip': traceline[2],
                    }
                    for x in range(0, len(delays)):
                        result['ms{0}'.format(x + 1)] = delays[x]
            except IndexError:
                result = {}

        elif (traceroute_version[0] >= 2 and traceroute_version[2] >= 14
                or traceroute_version[0] >= 2 and traceroute_version[1] > 0):
            comps = line.split('  ')
            if comps[1] == '* * *':
                result = {
                    'count': int(comps[0]),
                    'hostname': '*'}
            else:
                result = {
                    'count': int(comps[0]),
                    'hostname': comps[1].split()[0],
                    'ip': comps[1].split()[1].strip('()'),
                    'ms1': float(comps[2].split()[0]),
                    'ms2': float(comps[3].split()[0]),
                    'ms3': float(comps[4].split()[0])}
        else:
            comps = line.split()
            result = {
                'count': comps[0],
                'hostname': comps[1],
                'ip': comps[2],
                'ms1': comps[4],
                'ms2': comps[6],
                'ms3': comps[8],
                'ping1': comps[3],
                'ping2': comps[5],
                'ping3': comps[7]}

        ret.append(result)

    return ret


def dig(host):
    '''
    Run ``dig`` for ``host`` and return the command's output.

    The host string is passed through ``salt.utils.network.sanitize_host``
    before it is interpolated into the shell command.

    CLI Example:

    .. code-block:: bash

        salt '*' network.dig archlinux.org
    '''
    target = salt.utils.network.sanitize_host(host)
    return __salt__['cmd.run']('dig {0}'.format(target))


def arp():
    '''
    Return the arp table from the minion

    Parses ``arp -an`` and returns a dict mapping each hardware address (the
    fourth column) to its IP address (the parenthesised second column).
    Neighbours whose hardware address is not resolved yet (printed as
    ``<incomplete>`` or ``(incomplete)``) are skipped. Otherwise they would
    all collapse onto one bogus key, each overwriting the previous one.

    CLI Example:

    .. code-block:: bash

        salt '*' network.arp
    '''
    ret = {}
    out = __salt__['cmd.run']('arp -an')
    for line in out.splitlines():
        comps = line.split()
        if len(comps) < 4:
            continue
        if comps[3] in ('<incomplete>', '(incomplete)'):
            continue
        ret[comps[3]] = comps[1].strip('()')
    return ret


def interfaces():
    '''
    Describe every network interface on the minion.

    Returns the dictionary built by ``salt.utils.network.interfaces``.

    CLI Example:

    .. code-block:: bash

        salt '*' network.interfaces
    '''
    network_utils = salt.utils.network
    return network_utils.interfaces()


def hw_addr(iface):
    '''
    Look up the hardware (MAC) address of interface ``iface``.

    Delegates to ``salt.utils.network.hw_addr``.

    CLI Example:

    .. code-block:: bash

        salt '*' network.hw_addr eth0
    '''
    network_utils = salt.utils.network
    return network_utils.hw_addr(iface)

# Keep the older ``hwaddr`` name working for backward compatibility.
hwaddr = hw_addr


def subnets():
    '''
    List the subnets this host is attached to.

    Delegates to ``salt.utils.network.subnets``.

    CLI Example:

    .. code-block:: bash

        salt '*' network.subnets
    '''
    network_utils = salt.utils.network
    return network_utils.subnets()


def in_subnet(cidr):
    '''
    Report whether this host has an address inside ``cidr``.

    Delegates to ``salt.utils.network.in_subnet`` and returns its result
    (True or False).

    CLI Example:

    .. code-block:: bash

        salt '*' network.in_subnet 10.0.0.0/16
    '''
    network_utils = salt.utils.network
    return network_utils.in_subnet(cidr)


def ip_addrs(interface=None, include_loopback=False, cidr=None):
    '''
    List the IPv4 addresses assigned to the host.

    interface
        Only return addresses bound to this interface.

    include_loopback
        Include 127.0.0.1 (it is omitted by default).

    cidr
        When given (e.g. ``cidr="10.0.0.0/8"``), only addresses inside this
        subnet are returned.

    CLI Example:

    .. code-block:: bash

        salt '*' network.ip_addrs
    '''
    found = salt.utils.network.ip_addrs(interface=interface,
                                        include_loopback=include_loopback)
    if not cidr:
        return found
    return [addr for addr in found
            if salt.utils.network.in_subnet(cidr, [addr])]

ipaddrs = ip_addrs


def ip_addrs6(interface=None, include_loopback=False):
    '''
    List the IPv6 addresses assigned to the host.

    interface
        Only return addresses bound to this interface.

    include_loopback
        Include ::1 (it is omitted by default).

    CLI Example:

    .. code-block:: bash

        salt '*' network.ip_addrs6
    '''
    network_utils = salt.utils.network
    return network_utils.ip_addrs6(include_loopback=include_loopback,
                                   interface=interface)

ipaddrs6 = ip_addrs6


def get_hostname():
    '''
    Return the minion's hostname as reported by ``socket.gethostname()``.

    CLI Example:

    .. code-block:: bash

        salt '*' network.get_hostname
    '''
    import socket
    return socket.gethostname()


def _replace_lines_starting_with(path, prefix, replacement):
    '''
    Rewrite ``path`` in place, replacing every line that starts with
    ``prefix`` by ``replacement``. All other lines are kept, and a trailing
    newline is preserved if the file had one.
    '''
    with open(path, 'r') as fh_:
        content = fh_.read()
    lines = [replacement if line.startswith(prefix) else line
             for line in content.splitlines()]
    new_content = '\n'.join(lines)
    if content.endswith('\n'):
        new_content += '\n'
    with open(path, 'w') as fh_:
        fh_.write(new_content)


def mod_hostname(hostname):
    '''
    Modify hostname

    Sets the running hostname with the ``hostname`` command, and rewrites
    every ``127.0.0.1`` line of ``/etc/hosts`` to point at the new name. When
    ``/etc/sysconfig/network`` exists, it also rewrites the ``HOSTNAME`` line
    there.

    hostname
        The new hostname. It is passed through
        ``salt.utils.network.sanitize_host`` before use because it ends up in
        a shell command and in system files.

    Returns ``False`` if no usable hostname is given, otherwise ``True``.

    CLI Example:

    .. code-block:: bash

        salt '*' network.mod_hostname master.saltstack.com
    '''
    import os

    if hostname is None:
        return False

    # Interpolated into a shell command below, so sanitize it the same way
    # ping/dig/traceroute sanitize their host argument.
    hostname = salt.utils.network.sanitize_host(str(hostname))
    if not hostname:
        return False

    # 1. Set the running hostname
    __salt__['cmd.run']('hostname {0}'.format(hostname))

    # 2. Point every 127.0.0.1 entry in /etc/hosts at the new name
    _replace_lines_starting_with(
        '/etc/hosts',
        '127.0.0.1',
        '127.0.0.1\t\tlocalhost.localdomain\t\tlocalhost\t\t{0}'.format(
            hostname))

    # 3. Persist the name in /etc/sysconfig/network where that file exists;
    #    on systems without it, previously this raised after step 1 and 2 had
    #    already been applied.
    network_file = '/etc/sysconfig/network'
    if os.path.isfile(network_file):
        _replace_lines_starting_with(
            network_file, 'HOSTNAME', 'HOSTNAME={0}'.format(hostname))
    else:
        log.debug('{0} does not exist, not updating it'.format(network_file))
    return True
