diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 6519ceb..c020456 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -171,7 +171,7 @@ ArbitraryCodeChunk = Union[ByteCodeChunk, "Loop", "IfBody"] class Loop: def __init__(self, id: int, chunks: Sequence[ArbitraryCodeChunk]) -> None: - # The ID is usually the chunk that other chunks point into. + # The ID is the chunk that other chunks point into, aka the loop header. self.id = id # Calculate predecessors (who points into it) and successors (who we point out of). @@ -206,7 +206,8 @@ class Loop: class IfBody: def __init__(self, id: int, true_chunks: Sequence[ArbitraryCodeChunk], false_chunks: Sequence[ArbitraryCodeChunk], next_chunk: Optional[int], previous_chunk: int, negate: bool) -> None: - # The ID is usually the chunk that other chunks point into. + # The ID in this case is what the previous block points at. It does not + # have any bearing on the ID of the true and false chunks. self.id = id # If bodies are a bit special compared to Loops, we know the previous and next chunks @@ -509,7 +510,7 @@ class ByteCodeDecompiler(VerboseOutput): return (sorted(chunks, key=lambda c: c.id), offset_to_id) - def __get_entry_block(self, chunks: Sequence[Union[ByteCodeChunk, Loop]]) -> int: + def __get_entry_block(self, chunks: Sequence[ArbitraryCodeChunk]) -> int: start_id: int = -1 for chunk in chunks: if not chunk.previous_chunks: @@ -872,9 +873,11 @@ class ByteCodeDecompiler(VerboseOutput): candidates.extend(chunks_by_id[first_candidate].next_chunks) # The chunk list is all chunks that belong in this sequence. Now, kill any pointers to the end chunk. - if end_chunk is not None: - for chunk in chunks: + for chunk in chunks: + if end_chunk is not None: chunk.next_chunks = [n for n in chunk.next_chunks if n != end_chunk] + if chunk.id == start_chunk: + chunk.previous_chunks = [] return chunks @@ -884,13 +887,17 @@ class ByteCodeDecompiler(VerboseOutput): self.vprint(f"Separating if statements out of graph starting at {start_id}") - while chunks_by_id[cur_id].next_chunks: + while True: cur_chunk = chunks_by_id[cur_id] if isinstance(cur_chunk, Loop): self.vprint(f"Examining loop {cur_chunk.id} body for if statements...") cur_chunk.chunks = self.__separate_ifs(cur_chunk.id, cur_chunk.chunks, offset_map) self.vprint(f"Finished examining loop {cur_chunk.id} body for if statements...") + if not chunks_by_id[cur_id].next_chunks: + # We're done! + break + if len(chunks_by_id[cur_id].next_chunks) == 1: # This is just a goto/chunk, move on to the next one. cur_id = chunks_by_id[cur_id].next_chunks[0] @@ -918,9 +925,6 @@ class ByteCodeDecompiler(VerboseOutput): true_jump_point = offset_map[last_action.jump_if_true_offset] false_jump_points = [n for n in cur_chunk.next_chunks if n != true_jump_point] if len(false_jump_points) != 1: - self.vprint(chunks) - self.vprint(f"Jump has {last_action.jump_if_true_offset} which maps to {offset_map[last_action.jump_if_true_offset]}") - self.vprint(f"Jump branches to {cur_chunk.next_chunks}") raise Exception("Logic error!") false_jump_point = false_jump_points[0] @@ -936,7 +940,6 @@ class ByteCodeDecompiler(VerboseOutput): # First, grab all the chunks in this if statement body. true_chunks = self.__gather_chunks(true_jump_point, if_end, chunks_by_id) - self.vprint(true_chunks) # Delete these chunks from our chunk mapping since we're putting them in an if body. for chunk in true_chunks: @@ -951,7 +954,6 @@ class ByteCodeDecompiler(VerboseOutput): # First, grab all the chunks in this if statement body. false_chunks = self.__gather_chunks(false_jump_point, if_end, chunks_by_id) - self.vprint(false_chunks) # Delete these chunks from our chunk mapping since we're putting them in an if body. for chunk in false_chunks: @@ -990,6 +992,50 @@ class ByteCodeDecompiler(VerboseOutput): self.vprint(f"Finished separating if statements out of graph starting at {start_id}") return [c for _, c in chunks_by_id.items()] + def __check_graph(self, start_id: int, chunks: Sequence[ArbitraryCodeChunk]) -> List[ArbitraryCodeChunk]: + # Recursively go through and verify that all entries to the graph have only one link. + # Also, clean up the graph. + chunks_by_id: Dict[int, ArbitraryCodeChunk] = {chunk.id: chunk for chunk in chunks} + new_chunks: List[ArbitraryCodeChunk] = [] + + while True: + cur_chunk = chunks_by_id[start_id] + + # First, clean up any code in chunks that contain other chunks. + if isinstance(cur_chunk, Loop): + # Clean up the loop's chunks + self.vprint(f"Cleaning up graph of Loop {cur_chunk.id}") + cur_chunk.chunks = self.__check_graph(cur_chunk.id, cur_chunk.chunks) + elif isinstance(cur_chunk, IfBody): + # Clean up the if's chunks + if cur_chunk.true_chunks: + self.vprint(f"Cleaning up graph of IfBody {cur_chunk.id} true case") + true_start = self.__get_entry_block(cur_chunk.true_chunks) + cur_chunk.true_chunks = self.__check_graph(true_start, cur_chunk.true_chunks) + if cur_chunk.false_chunks: + self.vprint(f"Cleaning up graph of IfBody {cur_chunk.id} false case") + false_start = self.__get_entry_block(cur_chunk.false_chunks) + cur_chunk.false_chunks = self.__check_graph(false_start, cur_chunk.false_chunks) + + # Now, check to make sure that we have only one exit pointer. + num_exits = len(cur_chunk.next_chunks) + if num_exits > 1: + self.vprint(chunks) + raise Exception("Logic error!") + + # Now, we know this chunk is visited, so we can keep it. + new_chunks.append(cur_chunk) + + # Finally, bail if we've hit the end of the list. + if num_exits == 0: + break + + # Go to the next one! + start_id = cur_chunk.next_chunks[0] + + # Return the tree, stripped of all dead code (most likely just the return sentinel). + return new_chunks + def __decompile(self) -> str: # First, we need to construct a control flow graph. self.vprint("Generating control flow graph...") @@ -1016,6 +1062,10 @@ class ByteCodeDecompiler(VerboseOutput): # At this point, we *should* have a directed graph where there are no # backwards refs and every fork has been identified as an if. This means # we can now walk and recursively generate pseudocode in one pass. + self.vprint("Cleaning up and checking graph...") + chunks_loops_and_ifs = self.__check_graph(start_id, chunks_loops_and_ifs) + + # TODO: Need to go through and start actually converting statements now. self.vprint(chunks_loops_and_ifs) return "TODO"