diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 92d274c..04f95d3 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -795,8 +795,7 @@ class DoWhileStatement(Statement): return os.linesep.join([ "do {", os.linesep.join(entries), - "}", - "while(True);" + "} while (True)" ]) def render(self, prefix: str) -> List[str]: @@ -809,7 +808,7 @@ class DoWhileStatement(Statement): f"{prefix}{{", *entries, f"{prefix}}}", - f"{prefix}while(True);", + f"{prefix}while (True);", ] @@ -3099,7 +3098,7 @@ class ByteCodeDecompiler(VerboseOutput): def __eliminate_useless_returns(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: # Go through and find returns that are on the "last" line. Basically, any - # return statement where the next return statement is another return statement + # return statement where the next statement is another return statement # or the end of a function. def find_returns(statements: Sequence[Statement], next_statement: Statement) -> Set[NullReturnStatement]: returns: Set[NullReturnStatement] = set() @@ -3218,20 +3217,49 @@ class ByteCodeDecompiler(VerboseOutput): return self.__walk(statements, remove_label), changed def __eliminate_useless_continues(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: - # Go through and find continue statements on the last line of a do-while. - changed: bool = False + # Go through and find continues that are on the "last" line of a while. Basically, any + # continue statement where the next statement is another continue statement or the end + # of a loop. + def find_continues(statements: Sequence[Statement], next_statement: Statement) -> Set[ContinueStatement]: + continues: Set[ContinueStatement] = set() - def remove_continue(statement: Statement) -> Optional[Statement]: - nonlocal changed + for i in range(len(statements)): + cur_statement = statements[i] + next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement + if ( + isinstance(cur_statement, ContinueStatement) and + isinstance(next_statement, ContinueStatement) + ): + continues.add(cur_statement) - if isinstance(statement, DoWhileStatement): - if statement.body and isinstance(statement.body[-1], ContinueStatement): - changed = True - statement.body.pop() + elif isinstance(cur_statement, DoWhileStatement): + # Clever hack, where we pretend the next value after the loop is a continue, + # because hitting the bottom of a loop is actually a continue. + continues.update(find_continues(cur_statement.body, ContinueStatement())) + + elif isinstance(cur_statement, IfStatement): + continues.update(find_continues(cur_statement.true_statements, next_statement)) + continues.update(find_continues(cur_statement.false_statements, next_statement)) + + return continues + + # Instead of an empty next statement, make up a return because that's what + # falling off the end of execution means. + continues = find_continues(statements, NullReturnStatement()) + + updated: bool = False + + def remove_continues(statement: Statement) -> Statement: + nonlocal updated + + for removable in continues: + if removable is statement: + updated = True + return None return statement - statements = self.__walk(statements, remove_continue) - return statements, changed + statements = self.__walk(statements, remove_continues) + return statements, updated def __swap_empty_ifs(self, statements: Sequence[Statement]) -> List[Statement]: # Go through and find continue statements on the last line of a do-while. diff --git a/bemani/tests/test_afp_decompile.py b/bemani/tests/test_afp_decompile.py index 2892ad2..fe834bf 100644 --- a/bemani/tests/test_afp_decompile.py +++ b/bemani/tests/test_afp_decompile.py @@ -3,7 +3,7 @@ import unittest from typing import Dict, List, Sequence, Tuple, Union from bemani.tests.helpers import ExtendedTestCase -from bemani.format.afp.types.ap2 import AP2Action, IfAction, JumpAction, PushAction, Register +from bemani.format.afp.types.ap2 import AP2Action, IfAction, JumpAction, PushAction, AddNumVariableAction, Register from bemani.format.afp.decompile import BitVector, ByteCode, ByteCodeChunk, ControlFlow, ByteCodeDecompiler, Statement @@ -779,3 +779,157 @@ class TestAFPDecompile(ExtendedTestCase): "builtin_StartPlaying()", "goto label_4", ]) + + def test_basic_while(self) -> None: + # A basic while statement. + bytecode = self.__make_bytecode([ + # Define exit condition variable. + PushAction(100, ["finished", False]), + AP2Action(101, AP2Action.DEFINE_LOCAL), + # Check exit condition. + PushAction(102, ["finished"]), + AP2Action(103, AP2Action.GET_VARIABLE), + IfAction(104, IfAction.IS_TRUE, 107), + # Loop code. + AP2Action(105, AP2Action.NEXT_FRAME), + # Loop finished jump back to beginning. + JumpAction(106, 102), + # End of loop. + AP2Action(107, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as a while statement. + self.assertEqual(self.__equiv(statements), [ + "local finished = False", + "do {\n" + " if (finished) {\n" + " break\n" + " }\n" + " builtin_GotoNextFrame()\n" + "} while (True)" + ]) + + def test_advanced_while(self) -> None: + # A basic while statement. + bytecode = self.__make_bytecode([ + # Define exit condition variable. + PushAction(100, ["finished", False]), + AP2Action(101, AP2Action.DEFINE_LOCAL), + # Check exit condition. + PushAction(102, ["finished"]), + AP2Action(103, AP2Action.GET_VARIABLE), + IfAction(104, IfAction.IS_TRUE, 112), + # Loop code with a continue statement. + PushAction(105, ["some_condition"]), + AP2Action(106, AP2Action.GET_VARIABLE), + IfAction(107, IfAction.IS_FALSE, 110), + AP2Action(108, AP2Action.NEXT_FRAME), + # Continue statement. + JumpAction(109, 102), + # Exit early. + AP2Action(110, AP2Action.STOP), + # Break statement. + JumpAction(111, 112), + # End of loop. + AP2Action(112, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as a while statement. + self.assertEqual(self.__equiv(statements), [ + "local finished = False", + "do {\n" + " if (finished) {\n" + " break\n" + " }\n" + " if (not some_condition) {\n" + " builtin_StopPlaying()\n" + " break\n" + " } else {\n" + " builtin_GotoNextFrame()\n" + " }\n" + "} while (True)" + ]) + + def test_basic_for(self) -> None: + # A basic for statement. + bytecode = self.__make_bytecode([ + # Define exit condition variable. + PushAction(100, ["i", 0]), + AP2Action(101, AP2Action.DEFINE_LOCAL), + # Check exit condition. + PushAction(102, [10, "i"]), + AP2Action(103, AP2Action.GET_VARIABLE), + IfAction(104, IfAction.LT_EQUALS, 109), + # Loop code. + AP2Action(105, AP2Action.NEXT_FRAME), + # Increment, also the continue point. + PushAction(106, ["i"]), + AddNumVariableAction(107, 1), + # Loop finished jump back to beginning. + JumpAction(108, 102), + # End of loop. + AP2Action(109, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as a for statement. + self.assertEqual(self.__equiv(statements), [ + "local i = 0", + "do {\n" + " if (10 <= i) {\n" + " break\n" + " }\n" + " builtin_GotoNextFrame()\n" + " i = i + 1\n" + "} while (True)" + ]) + + def test_advanced_for(self) -> None: + # A basic for statement. + bytecode = self.__make_bytecode([ + # Define exit condition variable. + PushAction(100, ["i", 0]), + AP2Action(101, AP2Action.DEFINE_LOCAL), + # Check exit condition. + PushAction(102, [10, "i"]), + AP2Action(103, AP2Action.GET_VARIABLE), + IfAction(104, IfAction.LT_EQUALS, 115), + # Loop code with a continue statement. + PushAction(105, ["some_condition"]), + AP2Action(106, AP2Action.GET_VARIABLE), + IfAction(107, IfAction.IS_FALSE, 110), + AP2Action(108, AP2Action.NEXT_FRAME), + # Continue statement. + JumpAction(109, 112), + # Exit early. + AP2Action(110, AP2Action.STOP), + # Break statement. + JumpAction(111, 115), + # Increment, also the continue point. + PushAction(112, ["i"]), + AddNumVariableAction(113, 1), + # Loop finished jump back to beginning. + JumpAction(114, 102), + # End of loop. + AP2Action(115, AP2Action.END), + ]) + statements = self.__call_decompile(bytecode) + + # TODO: This should be optimized as a for statement. + self.assertEqual(self.__equiv(statements), [ + "local i = 0", + "do {\n" + " if (10 <= i) {\n" + " break\n" + " }\n" + " if (not some_condition) {\n" + " builtin_StopPlaying()\n" + " break\n" + " } else {\n" + " builtin_GotoNextFrame()\n" + " i = i + 1\n" + " }\n" + "} while (True)" + ])