#!/usr/bin/python 
"""
# disass.py v2.0 by atlas
#
# Syntax:  ./disass.py <binary-executable>
#
# disass.py uses a few simple objdump calls to gather GOT, PLT, and disassembly information.
# GOT addresses are tied to PLT calls, and the PLT lines are tagged with function names.
# Then, full disassembly is scanned for references to PLT calls, and those lines are labeled
# with the appropriate function call name.  Tested only with *nix.
#
# Yes, this may seem elementary, but I found it helpful so here it is.
# @145
#     Converted to Python from Perl in 2/2006

Changelog:
	2.0:
	 Changed design to OOP
	 Enhanced Sub tagging, including naming
	 Added Variable tagging
	 
	1.1:
	 Translated to Python
	 Added 
	 Bugfixes
	 
	1.0:
	 Initial version in Perl

"""

import os
import re
import sre
import subprocess
import sys

############  Optional psyco optimization.  ###############
# psyco is only used when it can be imported; without it the script runs
# unoptimized instead of dying with ImportError at startup.
try:
	import psyco
except ImportError:
	psyco = None
if psyco is not None:
	psyco.log()
	psyco.full()
############

# Command-line state, filled in by the argument loop at the bottom of the file.
binary = ""        # path of the executable handed to objdump
HELP = False       # when True the usage text is printed and the script exits
# Option flags; each is cleared by its -O / -S / -D / -B switch.
findoffset = findsyms = finddynsyms = printopcodes = True


from DisassGOT import *
from DisassArchiveHEADERS import *
from DisassFileHEADERS import *
from DisassProgramHEADERS import *
from DisassSectionHEADERS import *
from DisassSYMS import *
from DisassDynSYMS import *
from DisassSubRoutine import *
from DisassDATA import *

VER=2.0

## printSyntax() writes the usage message to stderr. ##
def printSyntax():
	"""Write the usage message, including version and program name, to stderr.

	The option list mirrors the switches accepted by the argument loop at the
	bottom of this file (the previously advertised --txt/--html were never
	parsed and would have been taken as the binary name).
	"""
	usage = ("\n"
		"disass v%0.2f Disassembling enhancer\n"
		"\n"
		"Syntax:  %s [-O|--no-offset] [-S|--no-syms] [-D|--no-dynsyms] [-B|--no-opcodes] <binary-to-disass>\n"
		"\n") % (VER, sys.argv[0])
	# The trailing newline reproduces the one the old print statement added.
	sys.stderr.write(usage + "\n")


# Sub Parser Constants
# Regex patterns are raw strings so "\s" reaches the regex engine as written
# rather than relying on Python passing an unknown string escape through.
SRCH_SUBTITLE = r"^[0-9a-fA-F]{7,8}\s*<.*>:"         # objdump symbol header, e.g. "08048394 <main>:"
SRCH_PUSHEBP = r"55.*push\s+%ebp"                    # opcode 55 + "push %ebp"
SRCH_MOVESPEBP = r"89.*[eE]5.*mov\s+%esp,\s*%ebp"    # opcodes 89 e5 + "mov %esp,%ebp"
SRCH_RET = r"[cC]3.*ret"                              # opcode c3 + "ret"
SRCH_CALL = "call "
SRCH_JMP = "j[abezlgnm][ezps]{0,1} "
CONST_ENDSUB = "\nRETURN\n"          # marker line inserted after every "ret"
CONST_UNNAMEDSUB = 'unnamed_sub_'    # prefix for generated subroutine names

def chopaddress(address):
	"""Normalise an address string so differently formatted spellings compare equal.

	Surrounding spaces, backslashes, '*' and '$' are stripped; then a two-char
	"0x"-style prefix (any char followed by 'x') is dropped, and finally ONE
	leading '0' is removed.  Strings shorter than two characters after
	stripping are returned unchanged, e.g. "0x08048942" and " 8048942" both
	become "8048942".
	"""
	addy = address.strip(' \\*$')
	try:
		if (addy[1] == 'x'):
			addy = addy[2:]
		if (addy[0] == '0'):
			addy = addy[1:]
	except IndexError:
		# Too short to carry a prefix/leading zero; keep what we have.
		pass
	return addy
			

class disass:
	"""Gather objdump information for one binary and annotate its disassembly.

	Construction runs the GOT / header / symbol / data parsers (each is handed
	this object) and captures the ``objdump -S`` listing in ``self.asm``.
	``buildSubs()`` splits that listing into DisassSubRoutine objects and
	``printTXT()`` renders the annotated report.
	"""
	# Parser components; assigned per instance in __init__.
	GOT = None
	SYMBOLS = None
	DYNSYMBOLS = None
	ARCHIVEHEADERS = None
	PROGRAMHEADERS = None
	FILEHEADERS = None
	SECTIONHEADERS = None
	DATA = None

	# NOTE(review): not referenced anywhere in this file as shown; presumably
	# a palette intended for HTML output.
	colors = [ "#330099","#993399","#3333FF","#009900","#FFCC33","#990000","#00CCCC","#FFFF00","#336666" ]

	# gdb "display" commands appended to every report.  The final extra
	# newline reproduces the one the former print statement added.
	_GDB_DISPLAY = ("\nDISPLAY SETTINGS/Basic\n"
		"\tdisplay/i $pc\n"
		"\tdisplay/x $edx\n"
		"\tdisplay/x $ecx\n"
		"\tdisplay/x $ebx\n"
		"\tdisplay/x $eax\n"
		"\tdisplay/32wx $ebp-92\n"
		"\tdisplay/32xw $esp \n"
		"\n")

	###################################
	def __init__(self, binary):
		"""Run the helper parsers on *binary* and capture its objdump -S listing.

		binary -- path of the executable to analyse.
		Raises OSError if the objdump executable cannot be started.
		"""
		# Mutable state lives on the instance: as class attributes these
		# containers were shared by every disass object.  They are created
		# before the helper parsers run because each parser receives `self`.
		self.LOOKUPS = {}
		self.LOOKUPKEYS = []
		self.BREAK = []
		self.CALLS = {}
		self.JMPS = {}
		self.CNDJMPS = {}
		self.JMPTRGTS = {}
		self.SUBS = {}
		self.REPRCACHE = {}
		self.HEAD = []
		self.asm = []
		self.unregisteredJMPS = {}
		self.unregisteredCalls = {}
		self.cursection = ""

		self.GOT = DisassGOT(binary, self)
		self.ARCHIVEHEADERS = DisassArchiveHEADERS(binary, self)
		self.FILEHEADERS = DisassFileHEADERS(binary, self)
		self.PROGRAMHEADERS = DisassProgramHEADERS(binary, self)
		self.SECTIONHEADERS = DisassSectionHEADERS(binary, self)
		self.SYMBOLS = DisassSYMS(binary, self)
		self.DYNSYMBOLS = DisassDynSYMS(binary, self)
		self.DATA = DisassDATA(binary, self)

		self.asm = self._runObjdump(binary)

	def _runObjdump(self, binary):
		"""Return the newline-terminated lines printed by ``objdump -S binary``."""
		# An argument list (no shell) keeps paths containing spaces or shell
		# metacharacters intact.
		proc = subprocess.Popen(["objdump", "-S", binary],
			stdout=subprocess.PIPE, universal_newlines=True)
		try:
			return proc.stdout.readlines()
		finally:
			proc.stdout.close()
			proc.wait()

	def addBREAK(self, string):
		"""Append *string* to BREAK and return its index.

		The entry is expected to be in its final form already, i.e. either a
		recognised sub name or something like ``*0x08048942``.
		"""
		self.BREAK.append(string)
		return (len(self.BREAK) - 1)

	def getSubName(self, address):
		"""Return the name for the subroutine at *address* (used by a sub naming itself).

		Lookup order: GOT, dynamic symbols, symbols (each queried with a second
		argument of True), then the already-built SUBS.  If none yields a name,
		"unnamed_sub_<n>" is returned, n being the current number of SUBS.
		"""
		name = self.GOT.getSubName(address, True)
		if not name:
			name = self.DYNSYMBOLS.getSubName(address, True)
		if not name:
			name = self.SYMBOLS.getSubName(address, True)
		if not name:
			sub = self.SUBS.get(address)
			if sub:
				name = sub.name
		if not name:
			name = "%s%d" % (CONST_UNNAMEDSUB, len(self.SUBS))
		return name

	def registerLookup(self, address, object):
		"""Record *object* under the chopped form of *address* (first registration wins)."""
		self.LOOKUPS.setdefault(chopaddress(address), object)

	def getRepr(self, address):
		"""Return a display name for *address*, caching results in REPRCACHE.

		The chopped address is resolved, in order, through: REPRCACHE, the GOT,
		locally built SUBS (rendered as a tab-prefixed " local::<name>"), DATA,
		the dynamic symbol table and the symbol table.  Failing all of those,
		the result is "(<sub> +offset)" for the subroutine with the highest
		start address not above *address*, or "" when there is none.

		Soon to be deprecated in favor of DisassSubRoutine.getRepr(), which
		uses a lookup table and is the source of most getRepr() calls.
		"""
		address = chopaddress(address)

		name = self.REPRCACHE.get(address, "")
		if name:
			return name

		name = self._lookupRepr(address)
		self.REPRCACHE.setdefault(address, name)
		return name

	def _lookupRepr(self, address):
		"""Uncached resolution behind getRepr(); *address* is already chopped."""
		name = self.GOT.getSubName(address)
		if name:
			return name

		# Perhaps we're calling another locally defined sub.
		sub = self.SUBS.get(address)
		if sub:
			return "\t local::%s" % sub.name

		# .rodata (strings), then the dynamic and static symbol tables.
		for lookup in (self.DATA.getRepr, self.DYNSYMBOLS.getSubName, self.SYMBOLS.getSubName):
			name = lookup(address)
			if name:
				return name

		name = name or ""
		try:
			name += self._subRange(address)
		except Exception:
			# Best effort (e.g. *address* is not hex): report and keep the bare name.
			sys.stderr.write("getRepr(): %s\n" % sys.exc_info()[1])
		return name

	def _subRange(self, address):
		"""Return "(<sub name> +offset)" for the subroutine that owns hex *address*.

		The owner is the sub with the numerically highest start address not
		above *address*; the offset is printed as signed hex with at least four
		digits.  Returns "" when no sub starts at or below *address*.  Sub end
		addresses are not checked.  Raises ValueError for non-hex input.
		"""
		addrint = int(address, 16)
		bestStart = None
		bestKey = None
		# Compare numerically: sorting the hex *strings* misorders keys of
		# different lengths (e.g. "10000" < "9000").
		for key in self.SUBS.keys():
			start = int(key, 16)
			if start <= addrint and (bestStart is None or start > bestStart):
				bestStart = start
				bestKey = key
		if bestKey is None:
			return ""
		return "(%s %+.4x)" % (self.SUBS[bestKey].name, addrint - bestStart)

	###### Stage 2: Find all subroutines using "ret" and "push %ebp"/"mov %esp,%ebp" ######
	def buildSubs(self):
		"""Split ``self.asm`` into DisassSubRoutine objects stored in SUBS.

		A new sub starts at an objdump symbol header line, or heuristically at
		a "push %ebp" line directly followed by "mov %esp,%ebp" once the current
		sub already holds more than two lines.  Lines before the first sub go to
		HEAD.  Every "ret" line is followed by CONST_ENDSUB without closing the
		sub, so instructions falling between subs are not lost.  Afterwards
		LOOKUPKEYS is rebuilt (sorted) and every sub's process() is called.
		"""
		subtitle = re.compile(SRCH_SUBTITLE)
		pushebp = re.compile(SRCH_PUSHEBP)
		movespebp = re.compile(SRCH_MOVESPEBP)
		retline = re.compile(SRCH_RET)

		lines = self.asm
		lastidx = len(lines) - 1
		SUB = None
		for idx, thisline in enumerate(lines):
			starts = subtitle.search(thisline) is not None
			if not starts and SUB and len(SUB.lines) > 2 and idx < lastidx:
				# Prologue heuristic.  The look-ahead is bounded so the final
				# line cannot raise IndexError, and a ':' is required because the
				# address is taken from the text before it.
				starts = (pushebp.search(thisline) is not None
					and movespebp.search(lines[idx + 1]) is not None
					and ":" in thisline)

			if starts:
				address = thisline[:thisline.index(":")]
				bracket = address.find("<")
				if bracket >= 0:
					address = address[:bracket]
				address = chopaddress(address)

				if SUB:
					SUB.setName()
				SUB = DisassSubRoutine(self, address)
				self.SUBS.setdefault(address, SUB)

			if SUB is None:
				self.HEAD.append(thisline)
			else:
				SUB.append(thisline)

			# Mark "ret" without ending the sub (that would require a new one).
			if retline.search(thisline) is not None:
				if SUB:
					SUB.append(CONST_ENDSUB)
				else:
					self.HEAD.append(CONST_ENDSUB)

		# The sub still open at the end of the listing never reached the
		# setName() call above; name it as well.
		if SUB:
			SUB.setName()

		self.LOOKUPKEYS = sorted(self.LOOKUPS.keys())

		# Snapshot the values so process() may safely register new entries.
		for sub in list(self.SUBS.values()):
			sub.process()

	def tagMemAccess(self):
		"""Stage 2 placeholder: find and label arbitrary memory accesses (not implemented).

		Planned: use PEFILE and PyElf to determine strings.
		"""
		pass

	def tagCalls(self):
		"""Stage 2 placeholder: tag non-PLT calls with the called sub's name (not implemented)."""
		pass

	def tagJMPs(self):
		"""Stage 3 placeholder: tag jmp/je/jz/jl/jg/jge/jle/jnz/jns (not implemented)."""
		pass

	# TODO design notes for tagJMPs:
	#  - Check whether the target is inside the current sub; a backward jump
	#    within the same sub marks that section as a loop.
	#  - ASCII-art (or HTML) the jmps: for each jmp block, check
	#    self.asm[line][somechar] for [/\|]; if found, check [somechar+1] until
	#    an empty char is found, then insert a "|" on each line in between and
	#    a /- or \- as appropriate.
	#  - Store each "destination" and add a blank line after it for readability.

	def registerCall(self, src, dest, name):
		"""Centralized registration of a call instruction.

		src  -- the line doing the call
		dest -- the call target
		name -- label passed on to DisassSubRoutine.addCaller()
		Every sub is offered the call; if none accepts it, src -> dest is kept
		in unregisteredCalls (first registration per src wins).
		"""
		registered = False
		for sub in self.SUBS.values():
			if (sub.addCaller(src, dest, name)):
				registered = True

		if (not registered):
			self.unregisteredCalls.setdefault(src, dest)

	def registerJMP(self, src, dest, name):
		"""Centralized registration of a JMP instruction.

		src  -- the line doing the JMP
		dest -- the JMP target
		name -- label passed on to DisassSubRoutine.addJMPer()
		Every sub is offered the jump; if none accepts it, src -> dest is kept
		in unregisteredJMPS (first registration per src wins).
		"""
		registered = False
		for sub in self.SUBS.values():
			if (sub.addJMPer(src, dest, name)):
				registered = True

		if (not registered):
			self.unregisteredJMPS.setdefault(src, dest)

	def printTXT(self, outfile=None):
		"""Write the annotated text report to *outfile* (default: current sys.stdout)."""
		if outfile is None:
			outfile = sys.stdout
		outfile.write("DISASSEMBLY:\n")
		self.FILEHEADERS.printTXT(outfile)
		self.ARCHIVEHEADERS.printTXT(outfile)
		self.PROGRAMHEADERS.printTXT(outfile)
		self.SECTIONHEADERS.printTXT(outfile)
		self.GOT.printTXT(outfile)
		self.SYMBOLS.printTXT(outfile)

		# Disassembly preceding the first sub, then each sub in key order.
		outfile.write("".join(self.HEAD) + "\n")
		for key in sorted(self.SUBS.keys()):
			# NOTE(review): called without outfile, as before; if
			# DisassSubRoutine.printTXT accepts a stream it should receive
			# outfile here — confirm its signature.
			self.SUBS[key].printTXT()

		outfile.write("\n\nBreakpoints for each \"call\":\n\n")
		for brk in self.BREAK:
			outfile.write(" break %s\n" % brk)

		outfile.write(self._GDB_DISPLAY)

	def printHTML(self, outfile=None):
		"""Write a plain dump: raw objdump lines, GOT, file headers, symbols, breakpoints.

		NOTE: despite its name no HTML markup is produced.  The component
		sections are rendered through their own printTXT(outfile), replacing
		references to module globals (GOT, FILEHEADERS, SYMBOLS, BREAK) that do
		not exist in this file.
		"""
		if outfile is None:
			outfile = sys.stdout
		outfile.write("DISASSEMBLY:\n")
		for line in self.asm:
			outfile.write(line.strip() + "\n")
		outfile.write("\n\nGOT:\n")
		self.GOT.printTXT(outfile)
		outfile.write("\n\nHEADERS:\n")
		self.FILEHEADERS.printTXT(outfile)
		outfile.write("\n\nSYMBOLS:\n")
		self.SYMBOLS.printTXT(outfile)
		outfile.write("\n\nBreakpoints for each \"call\":\n\n")
		# BREAK entries are already in final form (see addBREAK), so no
		# "*0x" prefix is added here.
		for brk in self.BREAK:
			outfile.write(" break %s\n" % brk)

		outfile.write(self._GDB_DISPLAY)






#################################################
# Command-line entry point.  Runs only when this file is executed as a
# script, so the disass class can also be imported.
#################################################

if __name__ == "__main__":
	# Consume the arguments; the last non-option argument is the binary.
	while (len(sys.argv) > 1):
		item = sys.argv.pop(1)
		option = item.lower()
		if (option == '--no-offset' or item == "-O"):   findoffset = False
		elif (option == '--no-syms' or item == "-S"):   findsyms = False
		elif (option == '--no-dynsyms' or item == "-D"):   finddynsyms = False
		elif (option == '--no-opcodes' or item == "-B"):   printopcodes = False
		elif (item == "-h" or option == "--help"):   HELP = True
		elif (item.startswith("-") and len(item) > 1):
			# Unknown switch: show usage instead of treating it as the binary.
			HELP = True
		else:
			binary = item

	# NOTE(review): findoffset/findsyms/finddynsyms/printopcodes are not read
	# anywhere in this file as shown; confirm whether they should be wired in.

	if (HELP or len(binary) == 0):
		printSyntax()
		sys.exit(1)

	me = disass(binary)
	# tagJMPs()/tagCalls()/tagMemAccess() are still no-op stubs, so only the
	# sub-building stage runs before printing.
	me.buildSubs()
	me.printTXT(sys.stdout)

#################################################


