#! /usr/bin/env python
# Copyright (C) 2005 Red Hat 
# see file 'COPYING' for use and warranty information
#
# audit2allow is a rewrite of the prior Perl script.
#
# Based on the original audit2allow Perl script, which credits:
#    newrules.pl, Copyright (C) 2001 Justin R. Smith (jsmith@mcs.drexel.edu)
#    2003 Oct 11: Add -l option by Yuichi Nakamura(ynakam@users.sourceforge.jp)
#
#    This program is free software; you can redistribute it and/or
#    modify it under the terms of the GNU General Public License as
#    published by the Free Software Foundation; either version 2 of
#    the License, or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA     
#                                        02111-1307  USA
#
#  
import commands, sys, os, pwd, string, getopt, re, selinux
class serule:
	"""A single policy rule together with the evidence that produced it.

	A rule is identified by its statement type (for example "allow"), its
	source type, its target type and its object class.  Every permission
	recorded for that combination is kept in ``avcinfo`` along with the
	details it was recorded with.
	"""
	def __init__(self, type, source, target, seclass):
		self.type=type
		self.source=source
		self.target=target
		self.seclass=seclass
		# permission name -> list of detail sequences handed to add()
		self.avcinfo={}

	def add(self, avc):
		"""Record one observation.

		avc[0] is an iterable of permission names.  The remaining elements
		(msg, comm and name when called from seruleRecords.add_rule) are
		stored once under each of those permissions.
		"""
		for perm in avc[0]:
			self.avcinfo.setdefault(perm, []).append(avc[1:])

	def getAccess(self):
		"""Return the permission part of the rule.

		A single permission is returned bare; any other number is returned
		as a sorted, brace-enclosed list such as "{ read write }".
		"""
		# sorted() works whether keys() yields a list or a view.
		perms=sorted(self.avcinfo)
		if len(perms) == 1:
			return perms[0]
		return " ".join(["{"] + perms + ["}"])

	def out(self, verbose=0):
		"""Return the rule as policy text.

		With verbose set, one comment line per recorded observation is
		appended, grouped by permission in sorted order.
		"""
		pieces=["%s %s %s:%s %s;" % (self.type, self.source, self.gettarget(), self.seclass, self.getAccess())]
		if verbose:
			for perm in sorted(self.avcinfo):
				for x in self.avcinfo[perm]:
					pieces.append("\n\t#TYPE=AVC  MSG=%s  " % x[0])
					if len(x[1]):
						pieces.append("COMM=%s  " % x[1])
					if len(x[2]):
						pieces.append("NAME=%s  " % x[2])
					pieces.append(" : " + perm)
		return "".join(pieces)

	def gettarget(self):
		"""Return the target type, or "self" when it equals the source."""
		if self.source == self.target:
			return "self"
		return self.target

class seruleRecords:
	"""Accumulate policy rules from AVC audit messages or an existing TE file.

	Rules are stored in ``seRules`` keyed by (rule type, source type, target
	type, object class).  The roles, types and class permissions they refer
	to are tracked separately so a "require" block can be generated.
	"""
	def __init__(self, input, last_reload=0, verbose=0, te_ind=0):
		"""Create the collection and immediately load rules from input.

		input       -- file-like object providing readline()
		last_reload -- when true, a granted load_policy message discards
		               the rules collected so far
		verbose     -- passed to serule.out() when the rules are printed
		te_ind      -- when true, input holds TE rules, not audit messages
		"""
		self.last_reload=last_reload
		# Kept on the instance so out() does not depend on a global.
		self.verbose=verbose
		self.seRules={}
		self.seclasses={}
		self.types=[]
		self.roles=[]
		self.load(input, te_ind)

	def warning(self, error):
		"""Write a message to stderr prefixed with the program name."""
		sys.stderr.write("%s: " % sys.argv[0])
		sys.stderr.write("%s\n" % error)
		sys.stderr.flush()

	def load(self, input, te_ind=0):
		"""Read input line by line and add every rule found.

		With te_ind set, each line starting with a known rule keyword is
		parsed as a TE rule.  Otherwise the input is treated as audit
		messages and each line carrying an AVC marker is handed to add().
		"""
		VALID_CMDS=("allow", "dontaudit", "auditallow", "role")

		avc=[]
		found=0
		line = input.readline()
		if te_ind:
			while line:
				rec=line.split()
				if len(rec) and rec[0] in VALID_CMDS:
					self.add_terule(line)
				line = input.readline()

		else:
			# Tokens are collected until a line with the AVC marker is seen,
			# so tokens of preceding non-AVC lines are passed to add() along
			# with that record.
			# NOTE(review): unclear whether carrying tokens across lines is
			# intended; behaviour kept as is.
			while line:
				rec=line.split()
				for i in rec:
					if i=="avc:" or i=="message=avc:":
						found=1
					else:
						avc.append(i)
				if found:
					self.add(avc)
					found=0
					avc=[]
				line = input.readline()

	def get_target(self, i, rule):
		"""Collect one name, or one brace-enclosed list of names, from rule.

		Parsing starts at token index i.  Returns a tuple of the index of
		the next unread token and the list of names collected.  Raises
		IndexError when i is past the end of rule.
		"""
		target=[]
		if rule[i][0] == "{":
			# The opening token may carry the first name, e.g. "{foo".
			for t in rule[i].split("{"):
				if len(t):
					target.append(t)
			i=i+1
			for s in rule[i:]:
				if s.find("}") >= 0:
					# Only the text before the closing brace is a name;
					# return as soon as that first piece has been handled.
					for s1 in s.split("}"):
						if len(s1):
							target.append(s1)
						i=i+1
						return (i, target)

				target.append(s)
				i=i+1
		else:
			if rule[i].find(";") >= 0:
				# Strip the statement terminator from a bare name.
				for s1 in rule[i].split(";"):
					if len(s1):
						target.append(s1)
			else:
				target.append(rule[i])

		i=i+1
		return (i, target)

	def rules_split(self, rules):
		"""Split a token list into its first and second name lists."""
		(idx, target ) = self.get_target(0, rules)
		(idx, subject) = self.get_target(idx, rules)
		return (target, subject)

	def add_terule(self, rule):
		"""Parse one TE rule line and add the rules it expands to.

		The line is expected to look like
		"<type> <sources> <targets>:<classes> <permissions>;".  Lines that
		do not have that shape are reported with warning() and skipped.
		"""
		rc = rule.split(":")
		if len(rc) < 2:
			# No class/permission part (for example a "role" statement);
			# indexing rc[1] would raise an uncaught IndexError.
			self.warning("Skipping unsupported rule: %s" % rule.strip())
			return
		try:
			head=rc[0].split()
			rule_type=head[0]
			(sources, targets) = self.rules_split(head[1:])
			(seclasses, access) = self.rules_split(rc[1].split())
		except IndexError:
			self.warning("Bad TE rule: %s" % rule.strip())
			return
		for scon in sources:
			for tcon in targets:
				# "self" is a keyword, not a type: resolve it to the source
				# so it is not listed in the require block.  serule prints
				# it as "self" again because source and target are equal.
				if tcon == "self":
					tcon=scon
				for seclass in seclasses:
					self.add_rule(rule_type, scon, tcon, seclass,access)

	def add_rule(self, rule_type, scon, tcon, seclass, access, msg="", comm="", name=""):
		"""Add access to the rule for (rule_type, scon, tcon, seclass).

		The rule is created on first use; later calls with the same key
		add their permissions and details to the existing rule.
		"""
		self.add_seclass(seclass, access)
		self.add_type(tcon)
		self.add_type(scon)
		key=(rule_type, scon, tcon, seclass)
		# The key must use rule_type.  Testing the builtin `type` never
		# matched, so every call replaced the rule and dropped what had
		# been collected for it before.
		if key not in self.seRules:
			self.seRules[key]=serule(rule_type, scon, tcon, seclass)

		self.seRules[key].add((access, msg, comm, name ))

	def add(self,avc):
		"""Turn the tokens of one AVC message into an "allow" rule.

		avc is the list of whitespace-separated tokens.  Messages that
		contain "security_compute_sid" or "granted" are ignored, except
		that a granted load_policy clears all rules when last_reload is
		set.  Messages lacking a source type, target type or class are
		ignored as well.
		"""
		scon=""
		tcon=""
		seclass=""
		comm=""
		name=""
		msg=""
		srole=""
		trole=""
		access=[]
		if "security_compute_sid" in avc:
			return

		if "granted" in avc:
			if "load_policy" in avc and self.last_reload:
				self.seRules={}
			return
		try:
			for i in range (0, len(avc)):
				if avc[i]=="{":
					# Permissions are the tokens up to the closing brace.
					j=i+1
					while j<len(avc) and avc[j] != "}":
						access.append(avc[j])
						j=j+1
					continue

				t=avc[i].split('=')
				if len(t) < 2:
					continue
				if t[0]=="scontext":
					# context is user:role:type[:...]
					context=t[1].split(":")
					scon=context[2]
					srole=context[1]
					continue
				if t[0]=="tcontext":
					context=t[1].split(":")
					tcon=context[2]
					trole=context[1]
					continue
				if t[0]=="tclass":
					seclass=t[1]
					continue
				if t[0]=="comm":
					comm=t[1]
					continue
				if t[0]=="name":
					name=t[1]
					continue
				if t[0]=="msg":
					msg=t[1]
					continue

			if scon=="" or tcon =="" or seclass=="":
				return
		except IndexError:
			self.warning("Bad AVC Line: %s" % avc)
			return

		self.add_role(srole)
		self.add_role(trole)
		self.add_rule("allow", scon, tcon, seclass, access, msg, comm, name)

	def add_seclass(self,seclass, access):
		"""Remember that seclass is used with each permission in access."""
		known=self.seclasses.setdefault(seclass, [])
		for a in access:
			if a not in known:
				known.append(a)

	def add_role(self,role):
		"""Remember role, ignoring duplicates."""
		if role not in self.roles:
			self.roles.append(role)

	def add_type(self,type):
		"""Remember type, ignoring duplicates."""
		if type not in self.types:
			self.types.append(type)

	def gen_module(self, module):
		"""Return the module declaration line for module."""
		return "module %s 1.0;" % module

	def gen_requires(self):
		"""Return the require block listing roles, classes and types."""
		self.roles.sort()
		self.types.sort()
		rec="\n\nrequire {\n"
		if len(self.roles) > 0:
			for i in self.roles:
				rec += "\trole %s; \n" % i
			rec += "\n"

		for i in sorted(self.seclasses):
			access=self.seclasses[i]
			if len(access) == 0:
				# Nothing to require; access[0] below would raise IndexError.
				continue
			if len(access) > 1:
				access.sort()
				rec += "\tclass %s {" % i
				for a in access:
					rec += " %s" % a
				rec += " }; \n"
			else:
				rec += "\tclass %s %s;\n" % (i, access[0])

		rec += "\n"

		for i in self.types:
			rec += "\ttype %s; \n" % i
		rec += " };\n\n\n"
		return rec

	def out(self, require=0, module=""):
		"""Return all collected rules as policy text.

		require -- when true, prefix the rules with a require block
		module  -- when not empty, prefix the rules with a module
		           declaration and a require block

		Raises ValueError when no rules have been collected.
		"""
		rec=""
		if len(self.seRules)==0:
			raise ValueError("No AVC messages found.")
		if module != "":
			rec += self.gen_module(module)
			rec += self.gen_requires()
		elif require:
			# Use the parameter, not a global that only the script defines.
			rec+=self.gen_requires()

		for i in sorted(self.seRules):
			rec += self.seRules[i].out(self.verbose)+"\n"
		return rec


if __name__ == '__main__':

	def get_mls_flag():
		"""Return the extra checkmodule flag needed when MLS is enabled."""
		if selinux.is_selinux_mls_enabled():
			return "-M"
		else:
			return ""

	def usage(msg=""):
		"""Print the help text, then msg if given, and exit with status 1."""
		print('audit2allow [-adhilrv] [-t file ] [ -f fcfile ] [-i <inputfile> ] [[-m|-M] <modulename> ] [-o <outputfile>]\n\
		-a, --all        read input from audit and message log, conflicts with -i\n\
		-d, --dmesg      read input from output of /bin/dmesg\n\
		-h, --help       display this message\n\
		-i, --input      read input from <inputfile> conflicts with -a\n\
		-l, --lastreload read input only after last \"load_policy\"\n\
		-m, --module     generate module/require output <modulename> \n\
		-M               generate loadable module package, conflicts with -o\n\
		-o, --output     append output to <outputfile>, conflicts with -M\n\
		-r, --requires   generate require output \n\
		-t, --tefile     Indicates input is Existing Type Enforcement file\n\
		-f, --fcfile     Existing File Context file, requires -M\n\
		-v, --verbose    verbose output\n\
		')
		if msg != "":
			print(msg)
		sys.exit(1)

	def errorExit(error):
		"""Write error to stderr prefixed with the program name and exit 1."""
		sys.stderr.write("%s: " % sys.argv[0])
		sys.stderr.write("%s\n" % error)
		sys.stderr.flush()
		sys.exit(1)

	#
	# Parse the command line, collect the rules and write the policy.
	#
	try:
		last_reload=0
		input=sys.stdin
		output=sys.stdout
		module=""
		requires=0
		verbose=0
		auditlogs=0
		buildPP=0
		input_ind=0
		output_ind=0
		te_ind=0

		fc_file=""
		# Each long option needs its own list entry: without the comma
		# after 'requires' the two adjacent literals were joined into the
		# single option "requirestefile".
		gopts, cmds = getopt.getopt(sys.argv[1:],
					    'adf:hi:lm:M:o:rtv',
					    ['all',
					     'dmesg',
					     'fcfile=',
					     'help',
					     'input=',
					     'lastreload',
					     'module=',
					     'output=',
					     'requires',
					     'tefile',
					     'verbose'
					     ])
		# Option arguments are checked with startswith() so an empty
		# argument does not raise IndexError.
		for o,a in gopts:
			if o == "-a" or o == "--all":
				if input_ind or te_ind:
					usage()
				input=open("/var/log/messages", "r")
				auditlogs=1
			if o == "-d"  or o == "--dmesg":
				input=os.popen("/bin/dmesg", "r")
			if o == "-f" or o == "--fcfile":
				if a.startswith("-"):
					usage()
				fc_file=a
			if o == "-h" or o == "--help":
				usage()
			if o == "-i"or o == "--input":
				if auditlogs  or a.startswith("-"):
					usage()
				input_ind=1
				input=open(a, "r")
			if o == '--lastreload' or o == "-l":
				last_reload=1
			if o == "-m" or o == "--module":
				if module != "" or a.startswith("-"):
					usage()
				module=a
			if o == "-M":
				if module != "" or output_ind  or a.startswith("-"):
					usage()
				module=a
				outfile=a+".te"
				buildPP=1
				output=open(outfile, "w")
			if o == "-r" or o == "--requires":
				requires=1
			if o == "-t" or o == "--tefile":
				if auditlogs:
					usage()
				te_ind=1
			if o == "-o" or o == "--output":
				if module != ""  or a.startswith("-"):
					usage()
				output=open(a, "a")
				output_ind=1
			if o == "-v" or o == "--verbose":
				verbose=1

		if len(cmds) != 0:
			usage()

		if fc_file != "" and not buildPP:
			usage("Error %s: Option -f requires -M" % sys.argv[0])

		out=seruleRecords(input, last_reload, verbose, te_ind)

		if auditlogs:
			# -a reads the message log first, then the audit log.
			input=os.popen("ausearch -m avc")
			out.load(input)

		if buildPP:
			print ("Generating type enforcment file: %s.te" % module)
		output.write(out.out(requires, module))
		# The .te file must be complete on disk before checkmodule reads it.
		output.flush()
		if buildPP:
			cmd="checkmodule %s -m -o %s.mod %s.te" % (get_mls_flag(), module, module)
			print("Compiling policy: %s" % cmd)
			rc=commands.getstatusoutput(cmd)
			if rc[0]==0:
				cmd="semodule_package -o %s.pp -m %s.mod" % (module, module)
				print(cmd)
				if fc_file != "":
					cmd = "%s -f %s" % (cmd, fc_file)

				print("Building package: %s" % cmd)
				rc=commands.getstatusoutput(cmd)
				if rc[0]==0:
					print ("\n******************** IMPORTANT ***********************\n")
					print ("In order to load this newly created policy package into the kernel,\nyou are required to execute \n\nsemodule -i %s.pp\n\n" % module)
				else:
					errorExit(rc[1])
			else:
				errorExit(rc[1])

	# The exception is fetched with sys.exc_info() rather than bound in
	# the except clause so these handlers parse on old and new interpreters.
	except getopt.error:
		errorExit("Options Error " + sys.exc_info()[1].msg)
	except ValueError:
		errorExit(sys.exc_info()[1].args[0])
	except IOError:
		err=sys.exc_info()[1]
		# strerror is unset when the error was raised with one argument.
		errorExit(err.strerror or str(err))
	except KeyboardInterrupt:
		sys.exit(0)
