#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2013 Red Hat, Inc.
#
# Authors:
# Thomas Woerner <twoerner@redhat.com>
# Jiri Popelka <jpopelka@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

from gi.repository import GObject
import sys
sys.modules['gobject'] = GObject

import argparse
import dbus
import os

from firewall.client import FirewallClient
from firewall.errors import *

def __fail(msg=None):
    """Abort the command with exit status 2.

    msg: optional text (or exception object) written to stdout first;
         nothing is written when it is empty or None.
    This function never returns.
    """
    if not msg:
        sys.exit(2)
    print(msg)
    sys.exit(2)

def __parse_port(value):
    """Split a "<port>/<protocol>" argument such as "22/tcp".

    value: the raw command-line string.
    Returns a (port, protocol) tuple of strings; the parts themselves are not
    validated here, only the overall shape.
    On malformed input (missing "/", more than one "/", or an empty part) an
    error naming the offending value is printed and the process exits with
    status 2 via __fail().
    """
    parts = value.split("/")
    # Check the shape explicitly instead of relying on tuple unpacking, whose
    # error text ("too many values to unpack") does not tell the user what
    # was wrong, and which silently accepted empty parts like "22/".
    if len(parts) != 2 or not all(parts):
        __fail("invalid port specification '%s', expected <port>/<protocol>"
               % value)
    return (parts[0], parts[1])

def __parse_forward_port(value):
    """Parse a forward-port specification.

    value: colon-separated "key=value" elements, for example
           "port=22:proto=tcp:toport=2222:toaddr=10.0.0.1". The recognised
           keys are port, proto, toport and toaddr; a repeated key keeps its
           last value.
    Returns a (port, protocol, toport, toaddr) tuple of strings; toport and
    toaddr are None when not given.
    Prints an error and exits with status 2 via __fail() when an element is
    not exactly one "key=value" pair, uses an unknown key, or when port,
    proto, or both toport and toaddr are missing/empty.
    """
    opts = dict.fromkeys(("port", "proto", "toport", "toaddr"))
    for arg in value.split(":"):
        pair = arg.split("=")
        # Unknown keys used to be ignored silently, so a typo such as
        # "toprot=8080" next to a valid toaddr dropped the intended
        # destination port without any warning. Reject them instead.
        if len(pair) != 2 or pair[0] not in opts:
            __fail("invalid forward port arg '%s'" % (arg))
        opts[pair[0]] = pair[1]
    if not opts["port"]:
        __fail("missing port")
    if not opts["proto"]:
        __fail("missing protocol")
    if not (opts["toport"] or opts["toaddr"]):
        __fail("missing destination")
    return (opts["port"], opts["proto"], opts["toport"], opts["toaddr"])

def __list_all(fw, zone):
    """Print a summary of one zone's runtime settings.

    fw:   connected firewalld client (the FirewallClient created by the main
          script); its get*() query methods are called with *zone*.
    zone: zone name; "" is passed through unchanged to the queries, and the
          header line then shows fw.getDefaultZone() instead.
    """
    # Fetch every setting before printing anything, in a fixed order.
    ifaces = fw.getInterfaces(zone)
    svcs = fw.getServices(zone)
    port_list = fw.getPorts(zone)
    fwd_list = fw.getForwardPorts(zone)
    blocks = fw.getIcmpBlocks(zone)

    print(fw.getDefaultZone() if zone == "" else zone)
    print("  interfaces: %s" % " ".join(ifaces))
    print("  services: %s" % " ".join(svcs))
    print("  ports: %s" % " ".join("%s/%s" % (p[0], p[1]) for p in port_list))
    # Forward ports get one per line, continuation lines indented by a tab.
    print("  forward-ports: %s" % "\n\t".join(
        "port=%s:proto=%s:toport=%s:toaddr=%s" % (pt, pr, tp, ta)
        for (pt, pr, tp, ta) in fwd_list))
    print("  icmp-blocks: %s" % " ".join(blocks))

# Command-line parser. argparse's generated help is disabled (add_help=False)
# because -h/--help is handled separately by showing the man page.
parser = argparse.ArgumentParser(usage="see firewall-cmd man page",
                                 add_help=False)

# Stand-alone options; argparse allows at most one of them per invocation.
parser_group_standalone = parser.add_mutually_exclusive_group()
parser_group_standalone.add_argument("-h", "--help", action="store_true")
parser_group_standalone.add_argument("-V", "--version", action="store_true")
parser_group_standalone.add_argument("--state", action="store_true")
parser_group_standalone.add_argument("--reload", action="store_true")
parser_group_standalone.add_argument("--complete-reload", action="store_true")
parser_group_standalone.add_argument("--panic-on", action="store_true")
parser_group_standalone.add_argument("--panic-off", action="store_true")
parser_group_standalone.add_argument("--query-panic", action="store_true")
parser_group_standalone.add_argument("--get-default-zone", action="store_true")
parser_group_standalone.add_argument("--set-default-zone", metavar="<zone>")
parser_group_standalone.add_argument("--get-zones", action="store_true")
parser_group_standalone.add_argument("--get-services", action="store_true")
parser_group_standalone.add_argument("--get-icmptypes", action="store_true")
parser_group_standalone.add_argument("--get-active-zones", action="store_true")
parser_group_standalone.add_argument("--get-zone-of-interface",
                                     metavar="<iface>")
parser_group_standalone.add_argument("--list-all-zones", action="store_true")

# Modifiers that may accompany the zone options.
parser.add_argument("--permanent", action="store_true")
parser.add_argument("--zone", default="", metavar="<zone>")
parser.add_argument("--timeout", default=0, type=int, metavar="<seconds>")

# Zone options; at most one per invocation. The add/remove variants use
# action='append' so the same option may be repeated for several values.
parser_group_zone = parser.add_mutually_exclusive_group()
parser_group_zone.add_argument("--add-interface", metavar="<iface>")
parser_group_zone.add_argument("--remove-interface", metavar="<iface>")
parser_group_zone.add_argument("--query-interface", metavar="<iface>")
parser_group_zone.add_argument("--change-interface", "--change-zone",
                               metavar="<iface>")
parser_group_zone.add_argument("--list-interfaces", action="store_true")
parser_group_zone.add_argument("--add-service", metavar="<service>",
                               action='append')
# These two previously advertised the wrong metavar "<zone>".
parser_group_zone.add_argument("--remove-service", metavar="<service>",
                               action='append')
parser_group_zone.add_argument("--query-service", metavar="<service>")
parser_group_zone.add_argument("--add-port", metavar="<port>", action='append')
parser_group_zone.add_argument("--remove-port", metavar="<port>",
                               action='append')
parser_group_zone.add_argument("--query-port", metavar="<port>")
parser_group_zone.add_argument("--add-masquerade", action="store_true")
parser_group_zone.add_argument("--remove-masquerade", action="store_true")
parser_group_zone.add_argument("--query-masquerade", action="store_true")
parser_group_zone.add_argument("--add-icmp-block", metavar="<icmptype>",
                               action='append')
parser_group_zone.add_argument("--remove-icmp-block", metavar="<icmptype>",
                               action='append')
parser_group_zone.add_argument("--query-icmp-block", metavar="<icmptype>")
parser_group_zone.add_argument("--add-forward-port", metavar="<port>",
                               action='append')
parser_group_zone.add_argument("--remove-forward-port", metavar="<port>",
                               action='append')
parser_group_zone.add_argument("--query-forward-port", metavar="<port>")
parser_group_zone.add_argument("--list-services", action="store_true")
parser_group_zone.add_argument("--list-ports", action="store_true")
parser_group_zone.add_argument("--list-icmp-blocks", action="store_true")
parser_group_zone.add_argument("--list-forward-ports", action="store_true")
parser_group_zone.add_argument("--list-all", action="store_true")

parser.add_argument("--direct", action="store_true")

# Direct-interface options; at most one per invocation. The ones using
# nargs=argparse.REMAINDER gather all remaining command-line arguments into
# a list, so their argument counts are not checked by argparse itself.
parser_direct = parser.add_mutually_exclusive_group()
parser_direct.add_argument("--passthrough", nargs=argparse.REMAINDER,
                           metavar=("{ ipv4 | ipv6 | eb }", "<args>"))
parser_direct.add_argument("--add-chain", nargs=3,
                           metavar=("{ ipv4 | ipv6 | eb }", "<table>",
                                    "<chain>"))
parser_direct.add_argument("--remove-chain", nargs=3,
                           metavar=("{ ipv4 | ipv6 | eb }", "<table>",
                                    "<chain>"))
parser_direct.add_argument("--query-chain", nargs=3,
                           metavar=("{ ipv4 | ipv6 | eb }", "<table>",
                                    "<chain>"))
parser_direct.add_argument("--get-chains", nargs=2,
                           metavar=("{ ipv4 | ipv6 | eb }", "<table>"))
parser_direct.add_argument("--add-rule", nargs=argparse.REMAINDER,
                           metavar=("{ ipv4 | ipv6 | eb }",
                                    "<table> <chain> <priority> <args>"))
parser_direct.add_argument("--remove-rule", nargs=argparse.REMAINDER,
                           metavar=("{ ipv4 | ipv6 | eb }",
                                    "<table> <chain> <args>"))
parser_direct.add_argument("--query-rule", nargs=argparse.REMAINDER,
                           metavar=("{ ipv4 | ipv6 | eb }",
                                    "<table> <chain> <args>"))
parser_direct.add_argument("--get-rules", nargs=3,
                           metavar=("{ ipv4 | ipv6 | eb }", "<table>",
                                    "<chain>"))

# Parse sys.argv; "a" holds every option value consulted below.
a = parser.parse_args()

# Each options_* name is a plain "or" chain over one family of options, so it
# holds the first truthy value in that family (or the last, falsy one). The
# checks that follow only test these names for truthiness.

# Options that must be used on their own.
options_standalone = (a.help or a.version
                      or a.state or a.reload or a.complete_reload
                      or a.panic_on or a.panic_off or a.query_panic
                      or a.get_default_zone or a.set_default_zone
                      or a.get_active_zones or a.get_zone_of_interface
                      or a.list_all_zones)

# Listings of the available zones / services / ICMP types.
options_config = (a.get_zones or a.get_services or a.get_icmptypes)

# Interface membership of a zone.
options_zone_interfaces = (a.list_interfaces or a.change_interface
                           or a.add_interface or a.remove_interface
                           or a.query_interface)

# Add/remove/query of services, ports, ICMP blocks and forward ports.
options_zone_action_action = (a.add_service or a.remove_service
                              or a.query_service
                              or a.add_port or a.remove_port or a.query_port
                              or a.add_icmp_block or a.remove_icmp_block
                              or a.query_icmp_block
                              or a.add_forward_port or a.remove_forward_port
                              or a.query_forward_port)

# Masquerade switches and the per-category list options.
options_zone_adapt_query = (a.add_masquerade or a.remove_masquerade
                            or a.query_masquerade
                            or a.list_services or a.list_ports
                            or a.list_icmp_blocks or a.list_forward_ports)

# Anything that targets a zone.
options_zone = (a.zone or options_zone_interfaces or a.list_all or a.timeout
                or options_zone_action_action or options_zone_adapt_query)

# Anything allowed in combination with --permanent.
options_permanent = (a.permanent or options_config or a.zone
                     or options_zone_action_action
                     or options_zone_adapt_query)

# The direct interface.
options_direct = (a.direct or a.passthrough
                  or a.add_chain or a.remove_chain or a.query_chain
                  or a.get_chains
                  or a.add_rule or a.remove_rule or a.query_rule
                  or a.get_rules)

# Reject impossible option combinations before talking to firewalld. Every
# rejection prints the usage line followed by a reason and exits with
# status 2 via __fail().

# At least one option family has to be present.
if not any((options_standalone, options_zone, options_permanent,
            options_direct)):
    __fail(parser.format_usage() + "No option specified.")

# Stand-alone options exclude all other families...
if options_standalone and any((options_zone, options_permanent,
                               options_direct)):
    __fail(parser.format_usage() +
           "Can't use stand-alone options with other options.")
# ...and direct options exclude the zone and permanent families.
if options_direct and any((options_zone, options_permanent)):
    __fail(parser.format_usage() +
           "Can't use 'direct' options with other options.")

# The --get-* listings take no zone options.
if options_config and options_zone:
    __fail(parser.format_usage() +
           "Wrong usage of --get-zones | --get-services | --get-icmptypes.")

# A timeout is only accepted together with one of the add options.
if a.timeout and not any((a.add_service, a.add_port, a.add_icmp_block,
                          a.add_forward_port, a.add_masquerade)):
    __fail(parser.format_usage() + "Wrong --timeout usage")

# --permanent: never with a timeout; otherwise only with a --get-* listing
# (without --zone) or with one of the zone action/adapt/query options.
if a.permanent:
    if a.timeout != 0:
        __fail(parser.format_usage() +
               "Can't specify timeout for permanent action.")
    if not ((options_config and not a.zone) or options_zone_action_action
            or options_zone_adapt_query):
        __fail(parser.format_usage() + "Wrong --permanent usage.")

# -h/--help opens the man page instead of argparse's generated help.
if a.help:
    os.system("man firewall-cmd")
    sys.exit(0)

def __print_joined(items, sep=" "):
    """Print *items* joined by *sep* on one line; print nothing if empty."""
    if items:
        print(sep.join(items))


# Zone targeted by the zone options. "" is passed straight through to the
# runtime calls below, while the permanent branch resolves it to the default
# zone explicitly.
zone = a.zone
try:
    fw = FirewallClient(quiet=True)
    if not fw.connected:
        # --state reports the daemon state itself; any other request fails.
        if a.state:
            print("not running")
        else:
            print("FirewallD is not running")
        sys.exit(NOT_RUNNING)

    if a.permanent:
        # Permanent configuration is read and written through fw.config().
        fw_config = fw.config()
        if a.get_zones:
            __print_joined([fw_config.getZone(z).get_property("name")
                            for z in fw_config.listZones()])
        elif a.get_services:
            __print_joined([fw_config.getService(s).get_property("name")
                            for s in fw_config.listServices()])
        elif a.get_icmptypes:
            __print_joined([fw_config.getIcmpType(i).get_property("name")
                            for i in fw_config.listIcmpTypes()])
        else:
            if zone == "":
                zone = fw.getDefaultZone()
            fw_zone = fw_config.getZoneByName(zone)
            fw_settings = fw_zone.getSettings()

            # The list/query branches exit immediately; the add/remove
            # branches only modify fw_settings, which is written back by the
            # fw_zone.update() call after this chain.

            # service
            if a.list_services:
                __print_joined(fw_settings.getServices())
                sys.exit(0)
            elif a.add_service:
                for s in a.add_service:
                    fw_settings.addService(s)
            elif a.remove_service:
                for s in a.remove_service:
                    fw_settings.removeService(s)
            elif a.query_service:
                sys.exit(a.query_service not in fw_settings.getServices())

            # port
            elif a.list_ports:
                __print_joined(["%s/%s" % (port[0], port[1])
                                for port in fw_settings.getPorts()])
                sys.exit(0)
            elif a.add_port:
                for port_proto in a.add_port:
                    (port, proto) = __parse_port(port_proto)
                    fw_settings.addPort(port, proto)
            elif a.remove_port:
                for port_proto in a.remove_port:
                    (port, proto) = __parse_port(port_proto)
                    fw_settings.removePort(port, proto)
            elif a.query_port:
                (port, proto) = __parse_port(a.query_port)
                sys.exit((port, proto) not in fw_settings.getPorts())

            # masquerade
            elif a.add_masquerade:
                fw_settings.setMasquerade(True)
            elif a.remove_masquerade:
                fw_settings.setMasquerade(False)
            elif a.query_masquerade:
                sys.exit(not fw_settings.getMasquerade())

            # forward port
            elif a.list_forward_ports:
                __print_joined(["port=%s:proto=%s:toport=%s:toaddr=%s"
                                % (port, protocol, toport, toaddr)
                                for (port, protocol, toport, toaddr)
                                in fw_settings.getForwardPorts()], "\n")
                sys.exit(0)
            elif a.add_forward_port:
                for fp in a.add_forward_port:
                    (port, protocol, toport, toaddr) = __parse_forward_port(fp)
                    fw_settings.addForwardPort(port, protocol, toport, toaddr)
            elif a.remove_forward_port:
                for fp in a.remove_forward_port:
                    (port, protocol, toport, toaddr) = __parse_forward_port(fp)
                    fw_settings.removeForwardPort(port, protocol, toport,
                                                  toaddr)
            elif a.query_forward_port:
                (port, protocol, toport, toaddr) = \
                    __parse_forward_port(a.query_forward_port)
                sys.exit(not fw_settings.queryForwardPort(port, protocol,
                                                          toport, toaddr))

            # block icmp
            elif a.list_icmp_blocks:
                __print_joined(fw_settings.getIcmpBlocks())
                sys.exit(0)
            elif a.add_icmp_block:
                for ib in a.add_icmp_block:
                    fw_settings.addIcmpBlock(ib)
            elif a.remove_icmp_block:
                for ib in a.remove_icmp_block:
                    fw_settings.removeIcmpBlock(ib)
            elif a.query_icmp_block:
                sys.exit(a.query_icmp_block not in fw_settings.getIcmpBlocks())

            fw_zone.update(fw_settings)

    elif a.version:
        print(fw.get_property("version"))
        sys.exit(0)
    elif a.state:
        if fw.get_property("state") == "RUNNING":
            print("running")
        else:
            print("not running")
            sys.exit(NOT_RUNNING)
    elif a.reload:
        fw.reload()
    elif a.complete_reload:
        fw.complete_reload()

    # direct interface
    elif a.passthrough:
        print(fw.passthrough(a.passthrough[0], a.passthrough[1:]))
    elif a.add_chain:
        fw.addChain(*a.add_chain)
    elif a.remove_chain:
        fw.removeChain(*a.remove_chain)
    elif a.query_chain:
        sys.exit(not fw.queryChain(*a.query_chain))
    elif a.get_chains:
        print(" ".join(fw.getChains(*a.get_chains)))
    elif a.add_rule:
        # argparse does not count REMAINDER arguments, so check them here
        # rather than crashing with an IndexError/ValueError traceback.
        if len(a.add_rule) < 4:
            __fail(parser.format_usage() + "Wrong --add-rule usage.")
        try:
            priority = int(a.add_rule[3])
        except ValueError:
            __fail(parser.format_usage() +
                   "Wrong --add-rule usage: priority must be an integer.")
        fw.addRule(a.add_rule[0], a.add_rule[1], a.add_rule[2], priority,
                   a.add_rule[4:])
    elif a.remove_rule:
        if len(a.remove_rule) < 3:
            __fail(parser.format_usage() + "Wrong --remove-rule usage.")
        fw.removeRule(a.remove_rule[0], a.remove_rule[1], a.remove_rule[2],
                      a.remove_rule[3:])
    elif a.query_rule:
        if len(a.query_rule) < 3:
            __fail(parser.format_usage() + "Wrong --query-rule usage.")
        sys.exit(not fw.queryRule(a.query_rule[0], a.query_rule[1],
                                  a.query_rule[2], a.query_rule[3:]))
    elif a.get_rules:
        for rule in fw.getRules(*a.get_rules):
            print(" ".join(rule))

    # zones, services and icmp types
    elif a.get_default_zone:
        print(fw.getDefaultZone())
    elif a.set_default_zone:
        fw.setDefaultZone(a.set_default_zone)
    elif a.get_zones:
        print(" ".join(fw.getZones()))
    elif a.get_active_zones:
        active_zones = fw.getActiveZones()
        for active_zone in active_zones:
            print("%s: %s" % (active_zone, " ".join(active_zones[active_zone])))
    elif a.get_zone_of_interface:
        # Best effort: print nothing when the lookup fails.
        try:
            print(fw.getZoneOfInterface(a.get_zone_of_interface))
        except Exception:
            pass
    elif a.get_services:
        __print_joined(fw.listServices())
    elif a.get_icmptypes:
        __print_joined(fw.listIcmpTypes())

    # panic
    elif a.panic_on:
        fw.enablePanicMode()
    elif a.panic_off:
        fw.disablePanicMode()
    elif a.query_panic:
        sys.exit(not fw.queryPanicMode())

    # interface
    elif a.list_interfaces:
        __print_joined(fw.getInterfaces(zone))
    elif a.add_interface:
        fw.addInterface(zone, a.add_interface)
    elif a.change_interface:
        fw.changeZone(zone, a.change_interface)
    elif a.remove_interface:
        fw.removeInterface(zone, a.remove_interface)
    elif a.query_interface:
        sys.exit(not fw.queryInterface(zone, a.query_interface))

    # service
    elif a.list_services:
        __print_joined(fw.getServices(zone))
    elif a.add_service:
        for s in a.add_service:
            fw.addService(zone, s, a.timeout)
    elif a.remove_service:
        for s in a.remove_service:
            fw.removeService(zone, s)
    elif a.query_service:
        sys.exit(not fw.queryService(zone, a.query_service))

    # port
    elif a.list_ports:
        __print_joined(["%s/%s" % (port[0], port[1])
                        for port in fw.getPorts(zone)])
    elif a.add_port:
        for port_proto in a.add_port:
            (port, proto) = __parse_port(port_proto)
            fw.addPort(zone, port, proto, a.timeout)
    elif a.remove_port:
        for port_proto in a.remove_port:
            (port, proto) = __parse_port(port_proto)
            fw.removePort(zone, port, proto)
    elif a.query_port:
        (port, proto) = __parse_port(a.query_port)
        sys.exit(not fw.queryPort(zone, port, proto))

    # masquerade
    elif a.add_masquerade:
        fw.addMasquerade(zone, a.timeout)
    elif a.remove_masquerade:
        fw.removeMasquerade(zone)
    elif a.query_masquerade:
        sys.exit(not fw.queryMasquerade(zone))

    # forward port
    elif a.list_forward_ports:
        __print_joined(["port=%s:proto=%s:toport=%s:toaddr=%s"
                        % (port, protocol, toport, toaddr)
                        for (port, protocol, toport, toaddr)
                        in fw.getForwardPorts(zone)], "\n")
    elif a.add_forward_port:
        for fp in a.add_forward_port:
            (port, protocol, toport, toaddr) = __parse_forward_port(fp)
            fw.addForwardPort(zone, port, protocol, toport, toaddr, a.timeout)
    elif a.remove_forward_port:
        for fp in a.remove_forward_port:
            (port, protocol, toport, toaddr) = __parse_forward_port(fp)
            fw.removeForwardPort(zone, port, protocol, toport, toaddr)
    elif a.query_forward_port:
        (port, protocol, toport, toaddr) = \
            __parse_forward_port(a.query_forward_port)
        sys.exit(not fw.queryForwardPort(zone, port, protocol,
                                         toport, toaddr))

    # block icmp
    elif a.list_icmp_blocks:
        __print_joined(fw.getIcmpBlocks(zone))
    elif a.add_icmp_block:
        for ib in a.add_icmp_block:
            fw.addIcmpBlock(zone, ib, a.timeout)
    elif a.remove_icmp_block:
        for ib in a.remove_icmp_block:
            fw.removeIcmpBlock(zone, ib)
    elif a.query_icmp_block:
        sys.exit(not fw.queryIcmpBlock(zone, a.query_icmp_block))

    # list all
    elif a.list_all:
        __list_all(fw, zone)

    # list everything
    elif a.list_all_zones:
        for z in fw.getZones():
            __list_all(fw, z)
            print("")

except dbus.DBusException as e:
    if "NotAuthorizedException" in str(e):
        print("Authorization failed.")
        print("Make sure polkit agent is running or run firewall-cmd as superuser.")
        sys.exit(NOT_AUTHORIZED)
    try:
        # get_dbus_message() is dbus-python's accessor for the error text;
        # BaseException.message is deprecated in Python 2.6+ and gone in 3.
        dbus_msg = e.get_dbus_message()
        code = FirewallError.get_code(dbus_msg)
    except Exception:
        code = UNKNOWN_ERROR
        print("Error: %s" % e)
    else:
        # "Already enabled"/"not enabled" style errors are only warnings.
        if code in [ALREADY_ENABLED, NOT_ENABLED, ZONE_ALREADY_SET]:
            print("Warning: %s" % dbus_msg)
            sys.exit(0)
        print("Error: %s" % dbus_msg)
    sys.exit(code)

sys.exit(0)
