diff --git a/bemani/common/pe.py b/bemani/common/pe.py index 43ed3f5..6c04168 100644 --- a/bemani/common/pe.py +++ b/bemani/common/pe.py @@ -27,18 +27,16 @@ class Memory: class Registers: def __init__(self) -> None: - self.eax = 0 - self.ebx = 0 - self.ecx = 0 - self.edx = 0 - self.esi = 0 - self.edi = 0 - self.ebp = 0 - self.esp = 0xFFFFFFFF + self.rax = 0 + self.rbx = 0 + self.rcx = 0 + self.rdx = 0 + self.rsi = 0 + self.rdi = 0 + self.rbp = 0 + self.rsp = 0xFFFFFFFF self.zf = False - self.of = False - self.cf = False class PEFile: @@ -142,8 +140,20 @@ class PEFile: if size is None: raise Exception(f"Could not determine size of {mnemonic} operation!") result = fetch(registers, memory, size, src) - registers.esp -= 4 - assign(registers, memory, size, "[esp]", result) + registers.rsp -= size + assign(registers, memory, size, "[rsp]" if self.is_64bit() else "[esp]", result) + + elif mnemonic == "pop": + dest = formatter.format_operand(inst, 0) + + vprint(f"pop {dest}") + + size = get_size(src) + if size is None: + raise Exception(f"Could not determine size of {mnemonic} operation!") + result = fetch(registers, memory, size, "[rsp]" if self.is_64bit() else "[esp]") + assign(registers, memory, size, dest, result) + registers.rsp += size elif mnemonic == "test": op1 = formatter.format_operand(inst, 0) @@ -156,13 +166,11 @@ class PEFile: raise Exception(f"Could not determine size of {mnemonic} operation!") result = fetch(registers, memory, size, op1) & fetch(registers, memory, size, op2) registers.zf = result == 0 - registers.of = False - registers.cf = False elif mnemonic == "jne": dest = formatter.format_operand(inst, 0) - vprint(f"jne {dest}") + vprint(f"jnz {dest}") if not registers.zf: destination = get_value(dest) @@ -170,12 +178,14 @@ class PEFile: raise Exception(f"Jumping to unsupported destination {dest}") dest_off = self.virtual_to_physical(destination) - if dest_off < start or dest_off >= end: + if dest_off == end: + loc = len(insts) + elif dest_off < start or dest_off > end: raise Exception(f"Jumping to {hex(destination)} which is outside of our evaluation range!") - - decoder = Decoder(64 if self.is_64bit() else 32, self.data[dest_off:end], ip=self.physical_to_virtual(dest_off)) - insts = [i for i in decoder] - loc = 0 + else: + decoder = Decoder(64 if self.is_64bit() else 32, self.data[dest_off:end], ip=self.physical_to_virtual(dest_off)) + insts = [i for i in decoder] + loc = 0 elif mnemonic == "jmp": dest = formatter.format_operand(inst, 0) @@ -187,12 +197,14 @@ class PEFile: raise Exception(f"Jumping to unsupported destination {dest}") dest_off = self.virtual_to_physical(destination) - if dest_off < start or dest_off >= end: + if dest_off == end: + loc = len(insts) + elif dest_off < start or dest_off > end: raise Exception(f"Jumping to {hex(destination)} which is outside of our evaluation range!") - - decoder = Decoder(64 if self.is_64bit() else 32, self.data[dest_off:end], ip=self.physical_to_virtual(dest_off)) - insts = [i for i in decoder] - loc = 0 + else: + decoder = Decoder(64 if self.is_64bit() else 32, self.data[dest_off:end], ip=self.physical_to_virtual(dest_off)) + insts = [i for i in decoder] + loc = 0 elif mnemonic == "or": dest = formatter.format_operand(inst, 0) @@ -240,6 +252,11 @@ class PEFile: def sanitize(indirect: str) -> str: + """ + Given an indirect address or a value from iced-x86 as formatted by the + operand formatter, sanitize it by getting rid of size specifiers. + """ + if indirect[:5] == "near ": indirect = indirect[5:] @@ -255,10 +272,21 @@ def sanitize(indirect: str) -> str: if indirect[:6] == "dword ": indirect = indirect[6:] + if indirect[:6] == "qword ": + indirect = indirect[6:] + return indirect def get_address(registers: Registers, indirect: str) -> Optional[int]: + """ + Given an indirect reference as formatted by the iced-x86 operand formatter, + resolve it to an actual 32-bit address that we should load from or store to. + This optionally supports indirect register address format so that we can + conveniently specify fetches and stores from the stack. If the value we + receive is not actually an indirect reference, return None. + """ + indirect = sanitize(indirect) if indirect[0] == "[" and indirect[-1] == "]": @@ -267,14 +295,20 @@ def get_address(registers: Registers, indirect: str) -> Optional[int]: if val[-1] == 'h': return int(val[:-1], 16) - if val == "esp": - return registers.esp + if val in {"rsp", "esp", "sp", "spl"}: + return registers.rsp raise Exception(f"Unsupported indirect address {indirect}!") return None def get_value(immediate: str) -> Optional[int]: + """ + Given an immediate value as formatted by the iced-x86 operand formatter, + resolve it to an immediate integer. If the value we receive is not + actually an immediate value, return None. + """ + immediate = sanitize(immediate) if immediate[-1] == "h": @@ -289,27 +323,45 @@ def get_value(immediate: str) -> Optional[int]: return None -def get_size(reg: str) -> Optional[int]: - if reg in {'eax', 'ebx', 'ecx', 'edx', 'esp', 'ebp', 'esi', 'edi'}: +def get_size(operand: str) -> Optional[int]: + """ + Given an operand as formatted by the iced-x86 operand formatter, return + the size in bytes that that operand represents in a load or store. + Supports both registers and byte/word/dword/qword specifiers in front of + immediate values and indirect memory references. + """ + + if operand in {'rax', 'rbx', 'rcx', 'rdx', 'rsp', 'rbp', 'rsi', 'rdi'}: + return 8 + if operand in {'eax', 'ebx', 'ecx', 'edx', 'esp', 'ebp', 'esi', 'edi'}: return 4 - if reg in {'ax', 'bx', 'cx', 'dx', 'sp', 'bp', 'si', 'di'}: + if operand in {'ax', 'bx', 'cx', 'dx', 'sp', 'bp', 'si', 'di'}: return 2 - if reg in {'ah', 'al', 'bh', 'bl', 'ch', 'cl', 'dh', 'dl'}: + if operand in {'ah', 'al', 'bh', 'bl', 'ch', 'cl', 'dh', 'dl', 'spl', 'bpl', 'sil', 'dil'}: return 1 - if reg[:5] == "byte ": + if operand[:5] == "byte ": return 1 - if reg[:5] == "word ": + if operand[:5] == "word ": return 2 - if reg[:6] == "dword ": + if operand[:6] == "dword ": return 4 + if operand[:6] == "qword ": + return 8 + return None def assign(registers: Registers, memory: Memory, size: int, loc: str, value: int) -> None: + """ + Given the registers and memory of our emulator, the size of the operation + performed, the location to assign to and the value we should assign, + compute where the assignment should happen and then execute it. + """ + address = get_address(registers, loc) if address is not None: if size == 1: @@ -323,58 +375,160 @@ def assign(registers: Registers, memory: Memory, size: int, loc: str, value: int memory.store(address, data) return + if loc == "rax": + registers.rax = value + return + + if loc == "rbx": + registers.rbx = value + return + + if loc == "rcx": + registers.rcx = value + return + + if loc == "rdx": + registers.rdx = value + return + + if loc == "rsp": + registers.rsp = value + return + + if loc == "rbp": + registers.rbp = value + return + + if loc == "rsi": + registers.rsi = value + return + + if loc == "rdi": + registers.rdi = value + return + if loc == "eax": - registers.eax = value + registers.rax = (registers.rax & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "ebx": - registers.ebx = value + registers.rbx = (registers.rbx & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "ecx": - registers.ecx = value + registers.rcx = (registers.rcx & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "edx": - registers.edx = value + registers.rdx = (registers.rdx & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "esp": - registers.esp = value + registers.rsp = (registers.rsp & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "ebp": - registers.esp = value + registers.rbp = (registers.rbp & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "esi": - registers.esi = value + registers.rsi = (registers.rsi & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) return if loc == "edi": - registers.edi = value + registers.rdi = (registers.rdi & 0xFFFFFFFF00000000) | (value & 0xFFFFFFFF) + return + + if loc == "ax": + registers.rax = (registers.rax & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "bx": + registers.rbx = (registers.rbx & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "cx": + registers.rcx = (registers.rcx & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "dx": + registers.rdx = (registers.rdx & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "sp": + registers.rsp = (registers.rsp & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "bp": + registers.rbp = (registers.rbp & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "si": + registers.rsi = (registers.rsi & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "di": + registers.rdi = (registers.rdi & 0xFFFFFFFFFFFF0000) | (value & 0xFFFF) + return + + if loc == "ah": + registers.rax = (registers.rax & 0xFFFFFFFFFFFF00FF) | ((value & 0xFF) << 8) return if loc == "al": - registers.eax = (registers.eax & 0xFFFFFF00) | (value & 0xFF) + registers.rax = (registers.rax & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "bh": + registers.rbx = (registers.rbx & 0xFFFFFFFFFFFF00FF) | ((value & 0xFF) << 8) return if loc == "bl": - registers.ebx = (registers.ebx & 0xFFFFFF00) | (value & 0xFF) + registers.rbx = (registers.rbx & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "ch": + registers.rcx = (registers.rcx & 0xFFFFFFFFFFFF00FF) | ((value & 0xFF) << 8) return if loc == "cl": - registers.ecx = (registers.ecx & 0xFFFFFF00) | (value & 0xFF) + registers.rcx = (registers.rcx & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "dh": + registers.rdx = (registers.rdx & 0xFFFFFFFFFFFF00FF) | ((value & 0xFF) << 8) return if loc == "dl": - registers.edx = (registers.edx & 0xFFFFFF00) | (value & 0xFF) + registers.rdx = (registers.rdx & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "spl": + registers.rsp = (registers.rsp & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "bpl": + registers.rbp = (registers.rbp & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "sil": + registers.rsi = (registers.rsi & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) + return + + if loc == "dil": + registers.rdi = (registers.rdi & 0xFFFFFFFFFFFFFF00) | (value & 0xFF) return raise Exception(f"Unsupported destination {loc} for assign!") def fetch(registers: Registers, memory: Memory, size: int, loc: str) -> int: + """ + Given the registers and memory of our emulator, the size of the operation + performed and the location to fetch from, compute where the fetch should + happen and then execute it, returning the results of the fetch. + """ + address = get_address(registers, loc) if address is not None: if size == 1: @@ -396,58 +550,112 @@ def fetch(registers: Registers, memory: Memory, size: int, loc: str) -> int: return immediate & 0xFFFFFFFF raise Exception(f"Unsupported size {size} for immediate fetch!") + if loc == "rax": + return registers.rax + + if loc == "rbx": + return registers.rbx + + if loc == "rcx": + return registers.rcx + + if loc == "rdx": + return registers.rdx + + if loc == "rsi": + return registers.rsi + + if loc == "rdi": + return registers.rdi + + if loc == "rsp": + return registers.rsp + + if loc == "rbp": + return registers.rbp + if loc == "eax": - return registers.eax + return registers.rax & 0xFFFFFFFF if loc == "ebx": - return registers.ebx + return registers.rbx & 0xFFFFFFFF if loc == "ecx": - return registers.ecx + return registers.rcx & 0xFFFFFFFF if loc == "edx": - return registers.edx + return registers.rdx & 0xFFFFFFFF if loc == "esi": - return registers.esi + return registers.rsi & 0xFFFFFFFF if loc == "edi": - return registers.edi - - if loc == "esp": - return registers.esp + return registers.rdi & 0xFFFFFFFF if loc == "ebp": - return registers.esp + return registers.rbp & 0xFFFFFFFF + + if loc == "esp": + return registers.rsp & 0xFFFFFFFF if loc == "ax": - return registers.eax & 0xFFFF + return registers.rax & 0xFFFF if loc == "bx": - return registers.ebx & 0xFFFF + return registers.rbx & 0xFFFF if loc == "cx": - return registers.ecx & 0xFFFF + return registers.rcx & 0xFFFF if loc == "dx": - return registers.edx & 0xFFFF + return registers.rdx & 0xFFFF if loc == "si": - return registers.esi & 0xFFFF + return registers.rsi & 0xFFFF if loc == "di": - return registers.edi & 0xFFFF + return registers.rdi & 0xFFFF + + if loc == "bp": + return registers.rbp & 0xFFFF + + if loc == "sp": + return registers.rsp & 0xFFFF + + if loc == "ah": + return (registers.rax & 0xFF00) >> 8 if loc == "al": - return registers.eax & 0xFF + return registers.rax & 0xFF + + if loc == "bh": + return (registers.rbx & 0xFF00) >> 8 if loc == "bl": - return registers.ebx & 0xFF + return registers.rbx & 0xFF + + if loc == "ch": + return (registers.rcx & 0xFF00) >> 8 if loc == "cl": - return registers.ecx & 0xFF + return registers.rcx & 0xFF + + if loc == "dh": + return (registers.rdx & 0xFF00) >> 8 if loc == "dl": - return registers.edx & 0xFF + return registers.rdx & 0xFF + + if loc == "spl": + return registers.rsp & 0xFF + + if loc == "bpl": + return registers.rbp & 0xFF + + if loc == "sil": + return registers.rsi & 0xFF + + if loc == "dil": + return registers.rdi & 0xFF raise Exception(f"Unsupported source {loc} for fetch!")