#
# m4aInfo.py
#
# This work is released under the GNU GPL, version 2 or later.
#
import struct

_mapping = {
	'\xa9day':'Year',
	'\xa9gen':'Genre Name',
	'\xa9nam':'Title',
	'\xa9ART':'Artist',
	'\xa9alb':'Album',
	'\xa9wrt':'Composer',
	'\xa9too':'Encoder',
	'\xa9cmt':'Comment',
	'\xa9grp':'Grouping',
	'trkn':'Track Number',
	'disk':'Disc Number',
	'gnre':'Genre',
	'cpil':'Compilation',
	'covr':'Cover Art',
	'tmpo':'Tempo',
	'apid':'Apple ID',
	'cprt':'Copyright',
	'rtng':'Rating'
	}

class MP4Info:
	"""Best-effort reader for MPEG-4 / iTunes (.m4a, .mp4) file metadata.

	The constructor walks the atom tree of an open binary file and collects
	movie-header values ('mvhd'), audio sample-description values ('stsd'),
	iTunes 'ilst' tags and the total 'mdat' payload size into a dictionary
	returned by info().  Parsing is best effort: the first malformed or
	truncated atom ends the scan, and whatever was collected so far is kept.
	"""

	# Decoders for 'ilst' items whose 'data' payload is binary.  Each takes
	# the raw payload string and returns a Python value; a decoder that
	# raises leaves the raw payload in place.  unpack_from tolerates
	# trailing bytes, so both 6- and 8-byte 'trkn'/'disk' payloads decode.
	_dataConverters = {
		'trkn': lambda s: struct.unpack_from("!IH", s),
		'disk': lambda s: struct.unpack_from("!IH", s),
		'cpil': lambda s: struct.unpack_from("!B", s)[0] != 0,
		'tmpo': lambda s: struct.unpack_from("!H", s)[0],
		'gnre': lambda s: struct.unpack_from("!H", s)[0],
		'rtng': lambda s: struct.unpack_from("!B", s)[0],
		'\xa9day': lambda s: int(s[:4]),
		}

	def __init__(self, fd, includeArt = False):
		"""Parse fd immediately.

		fd         -- file object opened in binary mode; scanning starts at
		              its current position (presumably the start of file).
		includeArt -- when true, the raw 'covr' payload is stored in
		              movieInfo (key taken from _mapping); otherwise cover
		              art is skipped without being read.
		"""
		self.fd = fd
		self.movieInfo = {'MPEG Version':4,'Layer':1}
		self.includeArt = includeArt
		# Seconds from the MP4/QuickTime epoch (1904-01-01) to the Unix epoch.
		self.macEpochOffset = 2082844800
		self.scan()
		# Average bit rate (bits/second) of the media payload.  This needs both
		# an 'mdat' size and an 'mvhd' duration; a file missing either has no
		# bit rate.
		try:
			self.movieInfo['Bit Rate'] = int(0.5+self.movieInfo['Size']/self.movieInfo['Duration']*8)
		except (KeyError, ZeroDivisionError):
			pass

	def atom(self):
		"""Read one atom header at the current file position.

		Returns (tag, nextatom): the 4-byte type code and the absolute offset
		just past the atom.  A size field of 1 means a 64-bit size follows the
		tag.  A size field of 0 means the atom runs to end of file.  Raises
		struct.error on a truncated header.  Raises ValueError when the size is
		smaller than the header itself, because seeking to such an "end" would
		move backwards and make the callers loop forever.
		"""
		start = self.fd.tell()
		(size,tag) = struct.unpack('!I4s',self.fd.read(8))
		headerSize = 8
		if size == 1:
			size = struct.unpack('!Q', self.fd.read(8))[0]
			headerSize = 16
		elif size == 0:
			here = self.fd.tell()
			self.fd.seek(0, 2)
			size = self.fd.tell() - start
			self.fd.seek(here)
		if size < headerSize:
			raise ValueError('invalid atom size %d for tag %r' % (size, tag))
		return (tag, start + size)

	def _children(self, containerEnds):
		"""Yield (tag, ends) for each child atom that starts before containerEnds.

		The consumer may read freely inside the yielded child.  The file is
		repositioned to the child's end before the next header is read.  An
		empty container yields nothing.
		"""
		while self.fd.tell() < containerEnds:
			(tag,ends) = self.atom()
			yield (tag, ends)
			self.fd.seek(ends)

	def scan(self):
		"""Walk the top-level atoms, descending into 'moov' and tallying 'mdat'."""
		try:
			while True:
				(tag,ends) = self.atom()
				if tag=='moov': self.scanMoov(ends)
				elif tag=='mdat': self.scanMdat(ends)
				self.fd.seek(ends)
		except Exception:
			# End of file surfaces here as struct.error from a short header
			# read.  Any other parse error also just ends the best-effort scan
			# with the information collected so far.
			pass

	def scanMoov(self,moovEnds):
		"""Dispatch the 'moov' children that carry header, tag or track data."""
		for (tag,ends) in self._children(moovEnds):
			if tag=='mvhd': self.scanMvhd(ends)
			elif tag=='udta': self.scanUdta(ends)
			elif tag=='trak': self.scanTrak(ends)

	def scanMdat(self,mdatEnds):
		"""Add this 'mdat' atom's payload length to movieInfo['Size']."""
		self.movieInfo['Size'] = self.movieInfo.get('Size', 0)+mdatEnds-self.fd.tell()

	def scanMvhd(self,mvhdEnds):
		"""Decode the movie header: version, flags, times and duration.

		A version-1 header stores creation/modification time and duration as
		64-bit values; any other version uses the 32-bit layout.
		"""
		(versionFlags,) = struct.unpack("!I", self.fd.read(4))
		version = versionFlags >> 24
		if version == 1:
			(creationTime,modTime,timeScale,duration) = struct.unpack("!QQIQ", self.fd.read(28))
		else:
			(creationTime,modTime,timeScale,duration) = struct.unpack("!4I", self.fd.read(16))
		self.movieInfo['Version'] = version
		self.movieInfo['Flags'] = versionFlags & 0xffffff
		self.movieInfo['Creation Time'] = creationTime - self.macEpochOffset
		self.movieInfo['Modification Time'] = modTime - self.macEpochOffset
		# A zero time scale cannot express a duration; leave it unset rather
		# than aborting the scan.
		if timeScale:
			self.movieInfo['Duration'] = duration/(1.0*timeScale)

	def scanUdta(self,udtaEnds):
		"""Descend into the 'meta' child of a 'udta' atom."""
		for (tag,ends) in self._children(udtaEnds):
			if tag=='meta': self.scanMeta(ends)

	def scanMeta(self,metaEnds):
		"""Descend into the 'ilst' child of a 'meta' atom."""
		# 'meta' has a 4-byte version/flags field before its children.
		self.fd.seek(4, 1)
		for (tag,ends) in self._children(metaEnds):
			if tag=='ilst': self.scanIlst(ends)

	def scanIlst(self,ilstEnds):
		"""Decode each 'ilst' item whose first child is a 'data' atom."""
		for (tag,ends) in self._children(ilstEnds):
			if self.fd.tell() >= ends:
				# Item with no children: do not read the next item's header.
				continue
			(payloadTag,payloadEnds) = self.atom()
			if payloadTag=='data':
				# Bound the read by the data atom itself, so any later
				# children of the same item are not glued onto the payload.
				self.scanData(tag, min(payloadEnds, ends))

	def scanData(self,tag,dataEnds):
		"""Store the payload of one 'data' atom in movieInfo.

		tag      -- four-character code of the enclosing 'ilst' item.
		dataEnds -- absolute offset where the 'data' atom ends.

		The first 8 payload bytes (two 32-bit header fields, presumably the
		type indicator and locale) are skipped.  The remainder is decoded by
		_dataConverters when a decoder exists for tag, and is otherwise stored
		as the raw string.
		"""
		size = dataEnds-self.fd.tell()
		if size < 8:
			# Too short to hold the two header fields; nothing to decode.
			return
		if tag=='covr' and not self.includeArt:
			# Skip reading a possibly large image that would be discarded.
			return
		self.fd.seek(8, 1)
		s = self.fd.read(size-8)
		convert = self._dataConverters.get(tag)
		if convert is not None:
			try:
				s = convert(s)
			except (struct.error, TypeError, ValueError, IndexError):
				# Unexpected payload layout: keep the raw bytes instead.
				pass
		self.movieInfo[_mapping.get(tag,tag)] = s

	def scanTrak(self,trakEnds):
		"""Descend into the 'mdia' child of a 'trak' atom."""
		for (tag,ends) in self._children(trakEnds):
			if tag=='mdia': self.scanMdia(ends)

	def scanMdia(self,mdiaEnds):
		"""Descend into the 'minf' child of an 'mdia' atom."""
		for (tag,ends) in self._children(mdiaEnds):
			if tag=='minf': self.scanMinf(ends)

	def scanMinf(self,minfEnds):
		"""Descend into the 'stbl' child of a 'minf' atom."""
		for (tag,ends) in self._children(minfEnds):
			if tag=='stbl': self.scanStbl(ends)

	def scanStbl(self,stblEnds):
		"""Descend into the 'stsd' child of an 'stbl' atom."""
		for (tag,ends) in self._children(stblEnds):
			if tag=='stsd': self.scanStsd(ends)

	def scanStsd(self,stsdEnds):
		"""Read audio parameters from the first sample description entry.

		Only entries whose format code (bytes 12-16 of the payload) is a
		known audio format are decoded.  The fixed offsets below follow the
		sound sample description layout.
		"""
		data = self.fd.read(stsdEnds-self.fd.tell())
		sampleFormat = data[12:16]
		if sampleFormat in ('mp4a', 'drms', 'samr', 'alac') and len(data) >= 44:
			# Sample rate is a 16.16 fixed-point value; keep the integer part.
			self.movieInfo['Sample Rate'] = struct.unpack("!I",data[40:44])[0]>>16
			self.movieInfo['Channels'] = struct.unpack("!H",data[32:34])[0]
			self.movieInfo['Bits Per Sample'] = struct.unpack('!H',data[34:36])[0]
			# The compression ID is a signed 16-bit field; -2 marks VBR.
			# Reading it unsigned ("!H") could never equal -2.
			self.movieInfo['VBR'] = struct.unpack("!h",data[36:38])[0]==-2

	def info(self):
		"""Return the dictionary of everything collected during the scan."""
		return self.movieInfo

if __name__=='__main__':
	# Command-line use: print the metadata dictionary of one file.
	import sys
	if len(sys.argv) != 2:
		sys.stderr.write('usage: %s FILE\n' % sys.argv[0])
		sys.exit(2)
	filePath = sys.argv[1]
	# MP4Info scans the whole file in its constructor, so the handle can be
	# closed as soon as the object exists.
	with open(filePath,"rb") as f:
		i = MP4Info(f)
	print(i.info())
