From 6aa04b0c1bbb404b06f2a0e2c592ffc976b1ed25 Mon Sep 17 00:00:00 2001 From: Jennifer Taylor Date: Sat, 24 Apr 2021 18:02:26 +0000 Subject: [PATCH] Implement recursive if detection. --- bemani/format/afp/decompile.py | 294 +++++++++++++++++++++++++++++++-- 1 file changed, 278 insertions(+), 16 deletions(-) diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 292611b..6519ceb 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -101,6 +101,13 @@ class GotoStatement(ConvertedAction): return f"goto label_{self.location};" +class NullReturnStatement(ConvertedAction): + # A statement which directs the control flow to the end of the code, but + # does not pop the stack to return + def __repr__(self) -> str: + return "return;" + + class IntermediateIfStatement(ConvertedAction): def __init__(self, parent_action: IfAction, true_actions: Sequence[ArbitraryOpcode], false_actions: Sequence[ArbitraryOpcode], negate: bool) -> None: self.parent_action = parent_action @@ -159,7 +166,7 @@ class ByteCodeChunk: ) -ArbitraryCodeChunk = Union[ByteCodeChunk, "Loop"] +ArbitraryCodeChunk = Union[ByteCodeChunk, "Loop", "IfBody"] class Loop: @@ -197,6 +204,42 @@ class Loop: ) +class IfBody: + def __init__(self, id: int, true_chunks: Sequence[ArbitraryCodeChunk], false_chunks: Sequence[ArbitraryCodeChunk], next_chunk: Optional[int], previous_chunk: int, negate: bool) -> None: + # The ID is usually the chunk that other chunks point into. + self.id = id + + # If bodies are a bit special compared to Loops, we know the previous and next chunks + # for all of them. + self.previous_chunks: List[int] = [previous_chunk] + self.next_chunks: List[int] = [next_chunk] if next_chunk is not None else [] + self.true_chunks = list(true_chunks) + self.false_chunks = list(false_chunks) + self.negate = negate + + def __repr__(self) -> str: + true_entries: List[str] = [] + for chunk in self.true_chunks: + true_entries.extend([f" {s}" for s in str(chunk).split(os.linesep)]) + + false_entries: List[str] = [] + for chunk in self.false_chunks: + false_entries.extend([f" {s}" for s in str(chunk).split(os.linesep)]) + + return ( + f"IfBody({os.linesep}" + + f" Negated: {self.negate}{os.linesep}" + + f" ID: {self.id}{os.linesep}" + + (f" Previous Chunks: {', '.join(str(n) for n in self.previous_chunks)}{os.linesep}" if self.previous_chunks else f" Start Chunk{os.linesep}") + + f" True Chunks:{os.linesep}" + + f"{os.linesep.join(true_entries)}{os.linesep}" + + f" False Chunks:{os.linesep}" + + f"{os.linesep.join(false_entries)}{os.linesep}" + + (f" Next Chunks: {', '.join(str(n) for n in self.next_chunks)}{os.linesep}" if self.next_chunks else f" End Chunk{os.linesep}") + + ")" + ) + + class BitVector: def __init__(self, length: int, init: bool = False) -> None: self.__bits: Dict[int, bool] = {i: init for i in range(length)} @@ -466,7 +509,7 @@ class ByteCodeDecompiler(VerboseOutput): return (sorted(chunks, key=lambda c: c.id), offset_to_id) - def __get_entry_block(self, chunks: Sequence[ByteCodeChunk]) -> int: + def __get_entry_block(self, chunks: Sequence[Union[ByteCodeChunk, Loop]]) -> int: start_id: int = -1 for chunk in chunks: if not chunk.previous_chunks: @@ -548,10 +591,6 @@ class ByteCodeDecompiler(VerboseOutput): last_action = chunk.actions[-1] if isinstance(last_action, AP2Action): - if last_action.opcode in [AP2Action.THROW, AP2Action.RETURN]: - # Ignore these for now, we'll fix these up in a later stage. - continue - if last_action.opcode == AP2Action.JUMP: # This is either an unconditional break/continue or an # internal jump. @@ -568,8 +607,12 @@ class ByteCodeDecompiler(VerboseOutput): chunk.actions[-1] = ContinueStatement() chunk.next_chunks = [] elif next_chunk not in internal_jump_points: - self.vprint("Converting jump to external point into goto statement.") - chunk.actions[-1] = GotoStatement(next_chunk) + if next_chunk == offset_map[self.bytecode.end_offset]: + self.vprint("Converting jump to external point into return statement.") + chunk.actions[-1] = NullReturnStatement() + else: + self.vprint("Converting jump to external point into goto statement.") + chunk.actions[-1] = GotoStatement(next_chunk) chunk.next_chunks = [] continue @@ -592,8 +635,12 @@ class ByteCodeDecompiler(VerboseOutput): true_action = ContinueStatement() chunk.next_chunks = [n for n in chunk.next_chunks if n != true_jump_point] elif true_jump_point not in internal_jump_points: - self.vprint("Converting jump if true to external point into goto statement.") - true_action = GotoStatement(true_jump_point) + if true_jump_point == offset_map[self.bytecode.end_offset]: + self.vprint("Converting jump if true to external point into return statement.") + true_action = NullReturnStatement() + else: + self.vprint("Converting jump if true to external point into goto statement.") + true_action = GotoStatement(true_jump_point) chunk.next_chunks = [n for n in chunk.next_chunks if n != true_jump_point] false_action: Optional[ConvertedAction] = None @@ -606,8 +653,12 @@ class ByteCodeDecompiler(VerboseOutput): false_action = ContinueStatement() chunk.next_chunks = [n for n in chunk.next_chunks if n != false_jump_point] elif false_jump_point not in internal_jump_points: - self.vprint("Converting jump if false to external point into goto statement.") - false_action = GotoStatement(false_jump_point) + if false_jump_point == offset_map[self.bytecode.end_offset]: + self.vprint("Converting jump if false to external point into return statement.") + false_action = NullReturnStatement() + else: + self.vprint("Converting jump if false to external point into goto statement.") + false_action = GotoStatement(false_jump_point) chunk.next_chunks = [n for n in chunk.next_chunks if n != false_jump_point] if true_action is None and false_action is not None: @@ -698,7 +749,9 @@ class ByteCodeDecompiler(VerboseOutput): new_loop = Loop(header, [chunks_by_id[i] for i in blocks]) # Eliminate jumps that are to the beginning/end of the loop to - # make if statement detection later on easier. + # make if statement detection later on easier. This also breaks + # the graph at any spot where we successfully converted a jump + # to a break/continue/goto. new_loop = self.__analyze_loop_jumps(new_loop, offset_map) chunks_by_id[header] = new_loop @@ -731,8 +784,211 @@ class ByteCodeDecompiler(VerboseOutput): return [chunks_by_id[i] for i in chunks_by_id] - def __separate_ifs(self, chunks: Sequence[Union[ByteCodeChunk, Loop]], offset_map: Dict[int, int]) -> List[ArbitraryCodeChunk]: - return [c for c in chunks] + def __break_graph(self, chunks: Sequence[Union[ByteCodeChunk, Loop]], offset_map: Dict[int, int]) -> None: + for chunk in chunks: + if chunk.id == offset_map[self.bytecode.end_offset]: + # Don't examine the sentinel we keep around as a jump point for returns. + continue + + if isinstance(chunk, Loop): + self.vprint(f"Entering into loop {chunk.id} to break graph...") + + # At this point, we know chunk.chunks is a Union[ByteCodeChunk, Loop] because we haven't run + # any if detection yet. + self.__break_graph(cast(List[Union[ByteCodeChunk, Loop]], chunk.chunks), offset_map) + else: + # Examine the last instruction. + last_action = chunk.actions[-1] + if isinstance(last_action, AP2Action): + if last_action.opcode in [AP2Action.THROW, AP2Action.RETURN]: + # The last action already dictates what we should do here. Break + # the chain at this point. + self.vprint(f"Breaking chain on {chunk.id} because it is a {last_action}.") + chunk.next_chunks = [] + elif len(chunk.next_chunks) == 1 and chunk.next_chunks[0] == offset_map[self.bytecode.end_offset]: + # The jump point for this is the end of the function. If it is a jump, + # then we should replace it with a return. If it is not a jump, we should + # add a return. + if last_action.opcode == AP2Action.JUMP: + self.vprint(f"Converting jump to end of code in {chunk.id} into a null return.") + chunk.actions[-1] = NullReturnStatement() + else: + self.vprint(f"Converting fall-through to end of code in {chunk.id} into a null return.") + chunk.actions.append(NullReturnStatement()) + chunk.next_chunks = [] + + def __find_shallowest_successor(self, start_chunk: int, chunks_by_id: Dict[int, ArbitraryCodeChunk]) -> Optional[int]: + if len(chunks_by_id[start_chunk].next_chunks) != 2: + # We don't care about this, the successor is the next chunk! + raise Exception("Logic error!") + + left, right = chunks_by_id[start_chunk].next_chunks + visited: Set[int] = set() + + # First, let's find all the successors to the left side. + candidates: List[int] = [left] + while candidates: + for candidate in candidates: + visited.add(candidate) + + new_candidates = [] + for candidate in candidates: + # We can avoid re-traversing what we've already traversed, as we only want to color + # in the part of the tree that we're interested in. + new_candidates.extend([c for c in chunks_by_id[candidate].next_chunks if c not in visited]) + candidates = new_candidates + + # Now, lets do the same with the right, and the first one we encounter that's visited is our guy. + candidates = [right] + while candidates: + for candidate in candidates: + if candidate in visited: + return candidate + + new_candidates = [] + for candidate in candidates: + # We can't take the same shortcut here as above, as we are trying to ask the question + # of what's the shallowest successor, not color them in. + new_candidates.extend(chunks_by_id[candidate].next_chunks) + candidates = new_candidates + + # If we didn't find a successor, that means one of the control paths leads to end of execution. + return None + + def __gather_chunks(self, start_chunk: int, end_chunk: Optional[int], chunks_by_id: Dict[int, ArbitraryCodeChunk]) -> List[ArbitraryCodeChunk]: + visited: Set[int] = set() + chunks: List[ArbitraryCodeChunk] = [] + candidates: List[int] = [start_chunk] + + while candidates: + first_candidate = candidates.pop() + if first_candidate in visited: + # We already visited this node. + continue + + if end_chunk is None or first_candidate != end_chunk: + chunks.append(chunks_by_id[first_candidate]) + visited.add(first_candidate) + candidates.extend(chunks_by_id[first_candidate].next_chunks) + + # The chunk list is all chunks that belong in this sequence. Now, kill any pointers to the end chunk. + if end_chunk is not None: + for chunk in chunks: + chunk.next_chunks = [n for n in chunk.next_chunks if n != end_chunk] + + return chunks + + def __separate_ifs(self, start_id: int, chunks: Sequence[ArbitraryCodeChunk], offset_map: Dict[int, int]) -> List[ArbitraryCodeChunk]: + chunks_by_id: Dict[int, ArbitraryCodeChunk] = {chunk.id: chunk for chunk in chunks} + cur_id = start_id + + self.vprint(f"Separating if statements out of graph starting at {start_id}") + + while chunks_by_id[cur_id].next_chunks: + cur_chunk = chunks_by_id[cur_id] + if isinstance(cur_chunk, Loop): + self.vprint(f"Examining loop {cur_chunk.id} body for if statements...") + cur_chunk.chunks = self.__separate_ifs(cur_chunk.id, cur_chunk.chunks, offset_map) + self.vprint(f"Finished examining loop {cur_chunk.id} body for if statements...") + + if len(chunks_by_id[cur_id].next_chunks) == 1: + # This is just a goto/chunk, move on to the next one. + cur_id = chunks_by_id[cur_id].next_chunks[0] + continue + + if not isinstance(cur_chunk, ByteCodeChunk): + # We should only be looking at bytecode chunks at this point, all other + # types should have a single next chunk. + raise Exception("Logic error!") + + last_action = cur_chunk.actions[-1] + if not isinstance(last_action, IfAction): + # This needs, again, to be an if statement. + raise Exception("Logic error!") + if len(chunks_by_id[cur_id].next_chunks) != 2: + # This needs to be an if statement. + raise Exception("Logic error!") + + # This should be an if statement. Figure out if it is an if-else or an + # if, and if both branches return. + if_end = self.__find_shallowest_successor(cur_id, chunks_by_id) + + # This is a normal if or if-else, let's compile the true and false + # statements. + true_jump_point = offset_map[last_action.jump_if_true_offset] + false_jump_points = [n for n in cur_chunk.next_chunks if n != true_jump_point] + if len(false_jump_points) != 1: + self.vprint(chunks) + self.vprint(f"Jump has {last_action.jump_if_true_offset} which maps to {offset_map[last_action.jump_if_true_offset]}") + self.vprint(f"Jump branches to {cur_chunk.next_chunks}") + raise Exception("Logic error!") + false_jump_point = false_jump_points[0] + + if true_jump_point == false_jump_point: + # This should never happen. + raise Exception("Logic error!") + + self.vprint(f"Chunk ID {cur_id} is an if statement with true node {true_jump_point} and false node {false_jump_point} and ending at {if_end}") + + true_chunks: List[ArbitraryCodeChunk] = [] + if true_jump_point != if_end: + self.vprint(f"Gathering true path starting with {true_jump_point} and ending with {if_end} and detecting if statements within it as well.") + + # First, grab all the chunks in this if statement body. + true_chunks = self.__gather_chunks(true_jump_point, if_end, chunks_by_id) + self.vprint(true_chunks) + + # Delete these chunks from our chunk mapping since we're putting them in an if body. + for chunk in true_chunks: + del chunks_by_id[chunk.id] + + # Now, recursively attempt to detect if statements within this chunk as well. + true_chunks = self.__separate_ifs(true_jump_point, true_chunks, offset_map) + + false_chunks: List[ArbitraryCodeChunk] = [] + if false_jump_point != if_end: + self.vprint(f"Gathering false path starting with {false_jump_point} and ending with {if_end} and detecting if statements within it as well.") + + # First, grab all the chunks in this if statement body. + false_chunks = self.__gather_chunks(false_jump_point, if_end, chunks_by_id) + self.vprint(false_chunks) + + # Delete these chunks from our chunk mapping since we're putting them in an if body. + for chunk in false_chunks: + del chunks_by_id[chunk.id] + + # Now, recursively attempt to detect if statements within this chunk as well. + false_chunks = self.__separate_ifs(false_jump_point, false_chunks, offset_map) + + if false_chunks and (not true_chunks): + negate = True + true_chunks = false_chunks + false_chunks = [] + if_id = false_jump_point + else: + negate = False + if_id = true_jump_point + + if (not true_chunks) and (not false_chunks): + # We should have at least one! + raise Exception("Logic error!") + + # Add a new if body that this current chunk points to. At this point, chunks_by_id contains + # none of the chunks in the true or false bodies of the if, so we add it back to the graph + # in the form of an IfBody. + self.vprint(f"Created new IfBody for chunk {cur_id} to point at, ending at {if_id}") + chunks_by_id[if_id] = IfBody(if_id, true_chunks, false_chunks, if_end, cur_id, negate) + chunks_by_id[cur_id].next_chunks = [if_id] + + if if_end is not None: + # Skip over the if, we already analyzed it. + cur_id = if_end + else: + # This if statement encompases all the rest of the statements, we're done. + break + + self.vprint(f"Finished separating if statements out of graph starting at {start_id}") + return [c for _, c in chunks_by_id.items()] def __decompile(self) -> str: # First, we need to construct a control flow graph. @@ -747,9 +1003,15 @@ class ByteCodeDecompiler(VerboseOutput): self.vprint("Identifying and separating loops...") chunks_and_loops = self.__separate_loops(chunks, dominators, offset_map) + # Now, break the graph anywhere where we have control + # flow that ends the execution (return, throw, goto end). + self.vprint("Breaking control flow graph on non-returnable statements...") + self.__break_graph(chunks_and_loops, offset_map) + # Now, identify any remaining control flow logic. self.vprint("Identifying and separating ifs...") - chunks_loops_and_ifs = self.__separate_ifs(chunks_and_loops, offset_map) + start_id = self.__get_entry_block(chunks_and_loops) + chunks_loops_and_ifs = self.__separate_ifs(start_id, chunks_and_loops, offset_map) # At this point, we *should* have a directed graph where there are no # backwards refs and every fork has been identified as an if. This means