diff --git a/bemani/format/afp/decompile.py b/bemani/format/afp/decompile.py index 04f95d3..0f1f035 100644 --- a/bemani/format/afp/decompile.py +++ b/bemani/format/afp/decompile.py @@ -515,6 +515,17 @@ class SetMemberStatement(Statement): self.name = name self.valueref = valueref + def code_equiv(self) -> str: + try: + ref = object_ref(self.objectref, "") + name = name_ref(self.name, "") + return f"{ref}.{name}" + except Exception: + # This is not a simple string object reference. + ref = object_ref(self.objectref, "") + name = value_ref(self.name, "") + return f"{ref}[{name}]" + def __repr__(self) -> str: try: ref = object_ref(self.objectref, "") @@ -591,6 +602,9 @@ class StoreRegisterStatement(Statement): self.register = register self.valueref = valueref + def code_equiv(self) -> str: + return self.register.render('') + def __repr__(self) -> str: val = value_ref(self.valueref, "") return f"{self.register.render('')} = {val}" @@ -606,6 +620,9 @@ class SetVariableStatement(Statement): self.name = name self.valueref = valueref + def code_equiv(self) -> str: + return name_ref(self.name, "") + def __repr__(self) -> str: name = name_ref(self.name, "") val = value_ref(self.valueref, "") @@ -623,6 +640,9 @@ class SetLocalStatement(Statement): self.name = name self.valueref = valueref + def code_equiv(self) -> str: + return name_ref(self.name, "") + def __repr__(self) -> str: name = name_ref(self.name, "") val = value_ref(self.valueref, "") @@ -639,6 +659,9 @@ class IfExpr(ConvertedAction): def invert(self) -> "IfExpr": raise NotImplementedError("Not implemented!") + def swap(self) -> "IfExpr": + raise NotImplementedError("Not implemented!") + class IsUndefinedIf(IfExpr): def __init__(self, conditional: Any, negate: bool) -> None: @@ -648,12 +671,15 @@ class IsUndefinedIf(IfExpr): def invert(self) -> "IsUndefinedIf": return IsUndefinedIf(self.conditional, not self.negate) + def swap(self) -> "IsUndefinedIf": + return IsUndefinedIf(self.conditional, self.negate) + def __repr__(self) -> str: val = value_ref(self.conditional, "", parens=True) if self.negate: - return f"if ({val} is not UNDEFINED)" + return f"{val} is not UNDEFINED" else: - return f"if ({val} is UNDEFINED)" + return f"{val} is UNDEFINED" class IsBooleanIf(IfExpr): @@ -664,12 +690,15 @@ class IsBooleanIf(IfExpr): def invert(self) -> "IsBooleanIf": return IsBooleanIf(self.conditional, not self.negate) + def swap(self) -> "IsBooleanIf": + return IsBooleanIf(self.conditional, self.negate) + def __repr__(self) -> str: val = value_ref(self.conditional, "", parens=True) if self.negate: - return f"if (not {val})" + return f"not {val}" else: - return f"if ({val})" + return f"{val}" class TwoParameterIf(IfExpr): @@ -718,10 +747,29 @@ class TwoParameterIf(IfExpr): return TwoParameterIf(self.conditional1, self.STRICT_EQUALS, self.conditional2) raise Exception(f"Cannot invert {self.comp}!") + def swap(self) -> "TwoParameterIf": + if self.comp == self.EQUALS: + return TwoParameterIf(self.conditional2, self.EQUALS, self.conditional1) + if self.comp == self.NOT_EQUALS: + return TwoParameterIf(self.conditional2, self.NOT_EQUALS, self.conditional1) + if self.comp == self.LT: + return TwoParameterIf(self.conditional2, self.GT, self.conditional1) + if self.comp == self.GT: + return TwoParameterIf(self.conditional2, self.LT, self.conditional1) + if self.comp == self.LT_EQUALS: + return TwoParameterIf(self.conditional2, self.GT_EQUALS, self.conditional1) + if self.comp == self.GT_EQUALS: + return TwoParameterIf(self.conditional2, self.LT_EQUALS, self.conditional1) + if self.comp == self.STRICT_EQUALS: + return TwoParameterIf(self.conditional2, self.STRICT_EQUALS, self.conditional1) + if self.comp == self.STRICT_NOT_EQUALS: + return TwoParameterIf(self.conditional2, self.STRICT_NOT_EQUALS, self.conditional1) + raise Exception(f"Cannot swap {self.comp}!") + def __repr__(self) -> str: val1 = value_ref(self.conditional1, "", parens=True) val2 = value_ref(self.conditional2, "", parens=True) - return f"if ({val1} {self.comp} {val2})" + return f"{val1} {self.comp} {val2}" class IfStatement(Statement): @@ -741,7 +789,7 @@ class IfStatement(Statement): if false_entries: return os.linesep.join([ - f"{self.cond} {{", + f"if ({self.cond}) {{", os.linesep.join(true_entries), "} else {", os.linesep.join(false_entries), @@ -749,7 +797,7 @@ class IfStatement(Statement): ]) else: return os.linesep.join([ - f"{self.cond} {{", + f"if ({self.cond}) {{", os.linesep.join(true_entries), "}" ]) @@ -765,7 +813,7 @@ class IfStatement(Statement): if false_entries: return [ - f"{prefix}{self.cond}", + f"{prefix}if ({self.cond})", f"{prefix}{{", *true_entries, f"{prefix}}}", @@ -776,7 +824,7 @@ class IfStatement(Statement): ] else: return [ - f"{prefix}{self.cond}", + f"{prefix}if ({self.cond})", f"{prefix}{{", *true_entries, f"{prefix}}}" @@ -812,6 +860,55 @@ class DoWhileStatement(Statement): ] +class ForStatement(DoWhileStatement): + # Special case of a DoWhileStatement that tracks its own exit condition and increment. + def __init__(self, inc_variable: str, inc_init: Any, cond: IfExpr, inc_assign: Any, body: Sequence[Statement], local: bool = False) -> None: + super().__init__(body) + self.inc_variable = inc_variable + self.inc_init = inc_init + self.cond = cond + self.inc_assign = inc_assign + self.local = local + + def __repr__(self) -> str: + entries: List[str] = [] + for statement in self.body: + entries.extend([f" {s}" for s in str(statement).split(os.linesep)]) + + inc_init = value_ref(self.inc_init, "") + inc_assign = value_ref(self.inc_assign, "") + if self.local: + local = "local " + else: + local = "" + + return os.linesep.join([ + f"for ({local}{self.inc_variable} = {inc_init}; {self.cond}; {self.inc_variable} = {inc_assign}) {{", + os.linesep.join(entries), + "}" + ]) + + def render(self, prefix: str) -> List[str]: + entries: List[str] = [] + for statement in self.body: + entries.extend(statement.render(prefix=prefix + " ")) + + inc_init = value_ref(self.inc_init, "") + inc_assign = value_ref(self.inc_assign, "") + + return [ + f"{prefix}for ({self.inc_variable} = {inc_init}; {self.cond}; {self.inc_variable} = {inc_assign}) {{", + f"{prefix}{{", + *entries, + f"{prefix}}}", + ] + + +class WhileStatement(DoWhileStatement): + # Special case of a DoWhileStatement that tracks its own exit condition. + pass + + class IntermediateIf(ConvertedAction): def __init__(self, parent_action: IfAction, true_statements: Sequence[Statement], false_statements: Sequence[Statement]) -> None: self.parent_action = parent_action @@ -2995,7 +3092,7 @@ class ByteCodeDecompiler(VerboseOutput): return new_statements - def __collapse_identical_labels(self, statements: Sequence[Statement]) -> List[Statement]: + def __collapse_identical_labels(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: # Go through and find labels that point at gotos, remove them and point the # gotos to those labels at the second gotos. statements = list(statements) @@ -3033,6 +3130,7 @@ class ByteCodeDecompiler(VerboseOutput): return pairs + changed: bool = False while True: redundant_pairs = reduce_labels_and_gotos(find_labels_and_gotos(statements)) if not redundant_pairs: @@ -3051,20 +3149,21 @@ class ByteCodeDecompiler(VerboseOutput): return statement statements = self.__walk(statements, update_gotos) + changed = changed or updated if not updated: break - return statements + return statements, changed def __remove_goto_return(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: # Go through and find labels that point at returns, convert any gotos pointing # at them to returns. - def find_labels(statements: Sequence[Statement], next_statement: Optional[Statement]) -> Set[int]: + def find_labels(statements: Sequence[Statement], parent_next_statement: Optional[Statement]) -> Set[int]: labels: Set[int] = set() for i in range(len(statements)): cur_statement = statements[i] - next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement + next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement if ( isinstance(cur_statement, DefineLabelStatement) and isinstance(next_statement, NullReturnStatement) @@ -3100,12 +3199,12 @@ class ByteCodeDecompiler(VerboseOutput): # Go through and find returns that are on the "last" line. Basically, any # return statement where the next statement is another return statement # or the end of a function. - def find_returns(statements: Sequence[Statement], next_statement: Statement) -> Set[NullReturnStatement]: + def find_returns(statements: Sequence[Statement], parent_next_statement: Statement) -> Set[NullReturnStatement]: returns: Set[NullReturnStatement] = set() for i in range(len(statements)): cur_statement = statements[i] - next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement + next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement if ( isinstance(cur_statement, NullReturnStatement) and isinstance(next_statement, NullReturnStatement) @@ -3220,12 +3319,12 @@ class ByteCodeDecompiler(VerboseOutput): # Go through and find continues that are on the "last" line of a while. Basically, any # continue statement where the next statement is another continue statement or the end # of a loop. - def find_continues(statements: Sequence[Statement], next_statement: Statement) -> Set[ContinueStatement]: + def find_continues(statements: Sequence[Statement], parent_next_statement: Statement) -> Set[ContinueStatement]: continues: Set[ContinueStatement] = set() for i in range(len(statements)): cur_statement = statements[i] - next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement + next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement if ( isinstance(cur_statement, ContinueStatement) and isinstance(next_statement, ContinueStatement) @@ -3261,17 +3360,193 @@ class ByteCodeDecompiler(VerboseOutput): statements = self.__walk(statements, remove_continues) return statements, updated - def __swap_empty_ifs(self, statements: Sequence[Statement]) -> List[Statement]: - # Go through and find continue statements on the last line of a do-while. + def __is_math(self, expression: Expression, variable: str) -> bool: + if isinstance(expression, ArithmeticExpression): + # Okay, let's see if it is any sort of math. + if expression.op in {"+", "-", "*", "/"}: + # It is, let's see if one of the two sides contains the + # variable we care about. + try: + left = object_ref(expression.left, "") + except Exception: + left = None + try: + right = object_ref(expression.right, "") + except Exception: + right = None + + return left == variable or right == variable + return False + + def __get_increment_variable(self, statement: Statement) -> Optional[str]: + if isinstance(statement, SetMemberStatement): + if isinstance(statement.valueref, Expression): + if self.__is_math(statement.valueref, statement.code_equiv()): + return statement.code_equiv() + if isinstance(statement, StoreRegisterStatement): + if isinstance(statement.valueref, Expression): + if self.__is_math(statement.valueref, statement.code_equiv()): + return statement.code_equiv() + if isinstance(statement, SetVariableStatement): + if isinstance(statement.valueref, Expression): + if self.__is_math(statement.valueref, statement.code_equiv()): + return statement.code_equiv() + if isinstance(statement, SetLocalStatement): + if isinstance(statement.valueref, Expression): + if self.__is_math(statement.valueref, statement.code_equiv()): + return statement.code_equiv() + return None + + def __get_assignment(self, statement: Statement) -> Any: + if isinstance(statement, SetMemberStatement): + return statement.valueref + if isinstance(statement, StoreRegisterStatement): + return statement.valueref + if isinstance(statement, SetVariableStatement): + return statement.valueref + if isinstance(statement, SetLocalStatement): + return statement.valueref + return None + + def __convert_for_loops(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: + # Convert any do {} while loops that resemble for statements into actual for statements. + # First, we need to hoist any increment to the actual end of the loop in case its in the + # last statement of some if/else condition. This isn't going to be perfectly accurate because + # there can be all sorts of bizarre for statements, but it should be good enough for most + # cases to make better code. + def convert_fors(statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: + new_statements: List[Statement] = [] + updated_statements: Dict[DoWhileStatement, DoWhileStatement] = {} + changed: bool = False + + for i in range(len(statements)): + cur_statement = statements[i] + next_statement = statements[i + 1] if (i < len(statements) - 1) else None + + if isinstance(cur_statement, IfStatement): + # Don't care about this, but we need to recursively walk its children. + cur_statement.true_statements, new_changed = convert_fors(cur_statement.true_statements) + changed = changed or new_changed + + cur_statement.false_statements, new_changed = convert_fors(cur_statement.false_statements) + changed = changed or new_changed + + new_statements.append(cur_statement) + elif isinstance(cur_statement, DoWhileStatement): + # If we addressed this statement, we should use the new statement instead. + for old, new in updated_statements.items(): + if old is cur_statement: + cur_statement = new + break + + # Need to recursively walk through and perform stuff on the body of this. + cur_statement.body, new_changed = convert_fors(cur_statement.body) + changed = changed or new_changed + + new_statements.append(cur_statement) + elif ( + isinstance(cur_statement, (SetMemberStatement, StoreRegisterStatement, SetVariableStatement, SetLocalStatement)) and + isinstance(next_statement, DoWhileStatement) and + not isinstance(next_statement, (ForStatement, WhileStatement)) + ): + # This is a possible conversion that hasn't been converted yet. Let's try to grab + # the increment variable. + if next_statement.body: + inc_variable = self.__get_increment_variable(next_statement.body[-1]) + else: + inc_variable = None + + # Now that we know what's being incremented, let's see if it matches our + # initializer. + inc_assignment = None + if inc_variable is not None and inc_variable != cur_statement.code_equiv(): + # This doesn't match, so let's kill our reference. + inc_variable = None + else: + inc_assignment = self.__get_assignment(next_statement.body[-1]) + + possible_if = None + swap = False + if inc_variable is not None: + # Let's see if the first statement is an if statement with a break. + possible_if = next_statement.body[0] + if isinstance(possible_if, IfStatement): + if len(possible_if.true_statements) == 1 and isinstance(possible_if.true_statements[0], BreakStatement): + # This is possibly a candidate, check the condition's variable usage. + if isinstance(possible_if.cond, IsUndefinedIf): + try: + if_variable = object_ref(possible_if.cond.conditional, "") + except Exception: + if_variable = None + if if_variable != inc_variable: + inc_variable = None + elif isinstance(possible_if.cond, IsBooleanIf): + try: + if_variable = object_ref(possible_if.cond.conditional, "") + except Exception: + if_variable = None + if if_variable != inc_variable: + inc_variable = None + elif isinstance(possible_if.cond, TwoParameterIf): + try: + if_variable1 = object_ref(possible_if.cond.conditional1, "") + except Exception: + if_variable1 = None + try: + if_variable2 = object_ref(possible_if.cond.conditional2, "") + except Exception: + if_variable2 = None + if if_variable1 != inc_variable and if_variable2 != inc_variable: + inc_variable = None + elif if_variable2 == inc_variable: + swap = True + else: + # Shouldn't happen, but let's handle it by bailing. + inc_variable = None + else: + # This isn't a candidate, the conditional isn't a break by itself. + inc_variable = None + else: + # Can't be a for loop, it doesn't have a conditional for breakinng. + inc_variable = None + + if inc_variable is not None: + # This is a for statement. Let's convert it. + cond = cast(IfStatement, possible_if).cond.invert() + updated_statements[next_statement] = ForStatement( + inc_variable, + self.__get_assignment(cur_statement), + cond.swap() if swap else cond, + inc_assignment, + # Drop the increment and the if statement, since we are incorporating them. + next_statement.body[1:-1], + local=isinstance(cur_statement, SetLocalStatement), + ) + else: + new_statements.append(cur_statement) + else: + # Don't care about this one, just append it. + new_statements.append(cur_statement) + + return new_statements, changed + + return convert_fors(statements) + + def __swap_empty_ifs(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: + # Get rid of empty if statements. If statements with empty if bodies and nonempty + # else bodies will also be swapped. changed: bool = False + updated: bool = False def swap_empty_ifs(statement: Statement) -> Optional[Statement]: nonlocal changed + nonlocal updated if isinstance(statement, IfStatement): if statement.false_statements and (not statement.true_statements): # Swap this, invert the conditional changed = True + updated = True return IfStatement( statement.cond.invert(), statement.false_statements, @@ -3280,6 +3555,7 @@ class ByteCodeDecompiler(VerboseOutput): elif (not statement.true_statements) and (not statement.false_statements): # Drop the if, it has no body. changed = True + updated = True return None return statement @@ -3287,7 +3563,63 @@ class ByteCodeDecompiler(VerboseOutput): changed = False statements = self.__walk(statements, swap_empty_ifs) if not changed: - return statements + return statements, updated + + def __drop_unneeded_else(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]: + # If an if has an else, but the last line of the if is a break/continue/return/throw/goto + # then the else body doesn't need to exist, so hoist it up into the parent. If the false + # statement also has an exit condition, don't drop it for asthetics. + def update_ifs(statements: Sequence[Statement], in_loop: bool) -> Tuple[List[Statement], bool]: + new_statements: List[Statement] = [] + changed: bool = False + + for statement in statements: + if isinstance(statement, IfStatement): + if ( + statement.true_statements and + statement.false_statements and + isinstance( + statement.true_statements[-1], + (BreakStatement, ContinueStatement, ReturnStatement, NullReturnStatement, ThrowStatement, GotoStatement), + ) and + not isinstance( + statement.false_statements[-1], + (BreakStatement, ContinueStatement, ReturnStatement, NullReturnStatement, ThrowStatement, GotoStatement), + ) + ): + # We need to walk both halves still, but once we're done, hoist the false + # statements up to our level. + statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop) + changed = changed or new_changed + + new_false_statements, new_changed = update_ifs(statement.false_statements, in_loop) + changed = changed or new_changed + statement.false_statements = [] + + # Now, append the if statement, and follow up with the body. + new_statements.append(statement) + new_statements.extend(new_false_statements) + else: + statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop) + changed = changed or new_changed + + statement.false_statements, new_changed = update_ifs(statement.false_statements, in_loop) + changed = changed or new_changed + + new_statements.append(statement) + elif isinstance(statement, DoWhileStatement): + # Need to recursively walk through and perform stuff on the body of this. + statement.body, new_changed = update_ifs(statement.body, in_loop=True) + changed = changed or new_changed + + new_statements.append(statement) + else: + # Don't care about this one, just append it. + new_statements.append(statement) + + return new_statements, changed + + return update_ifs(statements, in_loop=False) def __verify_balanced_labels(self, statements: Sequence[Statement]) -> None: gotos: Set[int] = set() @@ -3367,28 +3699,24 @@ class ByteCodeDecompiler(VerboseOutput): # Now, let's do some clean-up passes. if self.optimize: - statements = self.__collapse_identical_labels(statements) while True: any_changed = False - - statements, changed = self.__eliminate_useless_continues(statements) - any_changed = any_changed or changed - - statements, changed = self.__eliminate_unused_labels(statements) - any_changed = any_changed or changed - - statements, changed = self.__remove_useless_gotos(statements) - any_changed = any_changed or changed - - statements, changed = self.__remove_goto_return(statements) - any_changed = any_changed or changed - - statements, changed = self.__eliminate_useless_returns(statements) - any_changed = any_changed or changed + for func in [ + self.__collapse_identical_labels, + self.__eliminate_useless_continues, + self.__eliminate_unused_labels, + self.__remove_useless_gotos, + self.__remove_goto_return, + self.__eliminate_useless_returns, + self.__convert_for_loops, + self.__swap_empty_ifs, + self.__drop_unneeded_else, + ]: + statements, changed = func(statements) + any_changed = any_changed or changed if not any_changed: break - statements = self.__swap_empty_ifs(statements) # Let's sanity check the code for a few things that might trip us up. self.__verify_balanced_labels(statements) diff --git a/bemani/tests/test_afp_decompile.py b/bemani/tests/test_afp_decompile.py index fe834bf..2862745 100644 --- a/bemani/tests/test_afp_decompile.py +++ b/bemani/tests/test_afp_decompile.py @@ -846,9 +846,8 @@ class TestAFPDecompile(ExtendedTestCase): " if (not some_condition) {\n" " builtin_StopPlaying()\n" " break\n" - " } else {\n" - " builtin_GotoNextFrame()\n" " }\n" + " builtin_GotoNextFrame()\n" "} while (True)" ]) @@ -873,17 +872,10 @@ class TestAFPDecompile(ExtendedTestCase): AP2Action(109, AP2Action.END), ]) statements = self.__call_decompile(bytecode) - - # TODO: This should be optimized as a for statement. self.assertEqual(self.__equiv(statements), [ - "local i = 0", - "do {\n" - " if (10 <= i) {\n" - " break\n" - " }\n" + "for (local i = 0; i < 10; i = i + 1) {\n" " builtin_GotoNextFrame()\n" - " i = i + 1\n" - "} while (True)" + "}" ]) def test_advanced_for(self) -> None: @@ -916,20 +908,12 @@ class TestAFPDecompile(ExtendedTestCase): AP2Action(115, AP2Action.END), ]) statements = self.__call_decompile(bytecode) - - # TODO: This should be optimized as a for statement. self.assertEqual(self.__equiv(statements), [ - "local i = 0", - "do {\n" - " if (10 <= i) {\n" - " break\n" - " }\n" + "for (local i = 0; i < 10; i = i + 1) {\n" " if (not some_condition) {\n" " builtin_StopPlaying()\n" " break\n" - " } else {\n" - " builtin_GotoNextFrame()\n" - " i = i + 1\n" " }\n" - "} while (True)" + " builtin_GotoNextFrame()\n" + "}" ])