1
0
mirror of synced 2025-01-18 22:24:04 +01:00

Add for loop detection and conversion.

This commit is contained in:
Jennifer Taylor 2021-05-05 00:53:38 +00:00
parent e17f7c2bdb
commit f09fd79a45
2 changed files with 371 additions and 59 deletions

View File

@ -515,6 +515,17 @@ class SetMemberStatement(Statement):
self.name = name
self.valueref = valueref
def code_equiv(self) -> str:
    """Render the member this statement assigns to, without the assigned value.

    Produces ``object.name`` when the member name renders through name_ref,
    and falls back to ``object[name]`` when that rendering raises.
    """
    try:
        dotted = f"{object_ref(self.objectref, '')}.{name_ref(self.name, '')}"
    except Exception:
        # The name is not a simple string reference, so render it as a
        # subscript using its value form instead.
        return f"{object_ref(self.objectref, '')}[{value_ref(self.name, '')}]"
    return dotted
def __repr__(self) -> str:
try:
ref = object_ref(self.objectref, "")
@ -591,6 +602,9 @@ class StoreRegisterStatement(Statement):
self.register = register
self.valueref = valueref
def code_equiv(self) -> str:
    """Render the register this statement stores into, without the stored value."""
    rendered_register = self.register.render('')
    return rendered_register
def __repr__(self) -> str:
val = value_ref(self.valueref, "")
return f"{self.register.render('')} = {val}"
@ -606,6 +620,9 @@ class SetVariableStatement(Statement):
self.name = name
self.valueref = valueref
def code_equiv(self) -> str:
    """Render the variable name this statement assigns to, without the assigned value."""
    rendered_name = name_ref(self.name, "")
    return rendered_name
def __repr__(self) -> str:
name = name_ref(self.name, "")
val = value_ref(self.valueref, "")
@ -623,6 +640,9 @@ class SetLocalStatement(Statement):
self.name = name
self.valueref = valueref
def code_equiv(self) -> str:
    """Render the local name this statement assigns to, without the assigned value."""
    rendered_name = name_ref(self.name, "")
    return rendered_name
def __repr__(self) -> str:
name = name_ref(self.name, "")
val = value_ref(self.valueref, "")
@ -639,6 +659,9 @@ class IfExpr(ConvertedAction):
def invert(self) -> "IfExpr":
raise NotImplementedError("Not implemented!")
def swap(self) -> "IfExpr":
    """Return this condition with its operands exchanged.

    The base class has no operands to exchange; subclasses provide the
    real behavior. Always raises NotImplementedError here.
    """
    raise NotImplementedError("Not implemented!")
class IsUndefinedIf(IfExpr):
def __init__(self, conditional: Any, negate: bool) -> None:
@ -648,12 +671,15 @@ class IsUndefinedIf(IfExpr):
def invert(self) -> "IsUndefinedIf":
return IsUndefinedIf(self.conditional, not self.negate)
def swap(self) -> "IsUndefinedIf":
    """Return a fresh copy of this check.

    There is only one operand, so exchanging operands leaves both the
    conditional and the negation flag as they are.
    """
    swapped = IsUndefinedIf(self.conditional, self.negate)
    return swapped
def __repr__(self) -> str:
val = value_ref(self.conditional, "", parens=True)
if self.negate:
return f"if ({val} is not UNDEFINED)"
return f"{val} is not UNDEFINED"
else:
return f"if ({val} is UNDEFINED)"
return f"{val} is UNDEFINED"
class IsBooleanIf(IfExpr):
@ -664,12 +690,15 @@ class IsBooleanIf(IfExpr):
def invert(self) -> "IsBooleanIf":
return IsBooleanIf(self.conditional, not self.negate)
def swap(self) -> "IsBooleanIf":
    """Return a fresh copy of this check.

    There is only one operand, so exchanging operands leaves both the
    conditional and the negation flag as they are.
    """
    swapped = IsBooleanIf(self.conditional, self.negate)
    return swapped
def __repr__(self) -> str:
val = value_ref(self.conditional, "", parens=True)
if self.negate:
return f"if (not {val})"
return f"not {val}"
else:
return f"if ({val})"
return f"{val}"
class TwoParameterIf(IfExpr):
@ -718,10 +747,29 @@ class TwoParameterIf(IfExpr):
return TwoParameterIf(self.conditional1, self.STRICT_EQUALS, self.conditional2)
raise Exception(f"Cannot invert {self.comp}!")
def swap(self) -> "TwoParameterIf":
    """Return an equivalent comparison with the two operands exchanged.

    Equality-style comparisons keep their operator. Ordering comparisons
    are mirrored (LT <-> GT, LT_EQUALS <-> GT_EQUALS) so the result has
    the same truth value as the original.

    Raises Exception when self.comp is not one of the known comparisons.
    """
    # (operator as written, operator to use once the operands trade places),
    # checked in the same order as a chain of equality tests would be.
    mirrored_operators = (
        (self.EQUALS, self.EQUALS),
        (self.NOT_EQUALS, self.NOT_EQUALS),
        (self.LT, self.GT),
        (self.GT, self.LT),
        (self.LT_EQUALS, self.GT_EQUALS),
        (self.GT_EQUALS, self.LT_EQUALS),
        (self.STRICT_EQUALS, self.STRICT_EQUALS),
        (self.STRICT_NOT_EQUALS, self.STRICT_NOT_EQUALS),
    )
    for written, mirrored in mirrored_operators:
        if self.comp == written:
            return TwoParameterIf(self.conditional2, mirrored, self.conditional1)
    raise Exception(f"Cannot swap {self.comp}!")
def __repr__(self) -> str:
val1 = value_ref(self.conditional1, "", parens=True)
val2 = value_ref(self.conditional2, "", parens=True)
return f"if ({val1} {self.comp} {val2})"
return f"{val1} {self.comp} {val2}"
class IfStatement(Statement):
@ -741,7 +789,7 @@ class IfStatement(Statement):
if false_entries:
return os.linesep.join([
f"{self.cond} {{",
f"if ({self.cond}) {{",
os.linesep.join(true_entries),
"} else {",
os.linesep.join(false_entries),
@ -749,7 +797,7 @@ class IfStatement(Statement):
])
else:
return os.linesep.join([
f"{self.cond} {{",
f"if ({self.cond}) {{",
os.linesep.join(true_entries),
"}"
])
@ -765,7 +813,7 @@ class IfStatement(Statement):
if false_entries:
return [
f"{prefix}{self.cond}",
f"{prefix}if ({self.cond})",
f"{prefix}{{",
*true_entries,
f"{prefix}}}",
@ -776,7 +824,7 @@ class IfStatement(Statement):
]
else:
return [
f"{prefix}{self.cond}",
f"{prefix}if ({self.cond})",
f"{prefix}{{",
*true_entries,
f"{prefix}}}"
@ -812,6 +860,55 @@ class DoWhileStatement(Statement):
]
class ForStatement(DoWhileStatement):
    # Special case of a DoWhileStatement that tracks its own exit condition and increment.
    #
    # inc_variable: rendered name of the loop variable.
    # inc_init:     value the loop variable is initialized to.
    # cond:         condition under which the loop keeps running.
    # inc_assign:   value assigned to the loop variable after each iteration.
    # body:         statements executed on each iteration.
    # local:        whether the initializer is rendered with a "local " prefix.

    def __init__(self, inc_variable: str, inc_init: Any, cond: IfExpr, inc_assign: Any, body: Sequence[Statement], local: bool = False) -> None:
        super().__init__(body)
        self.inc_variable = inc_variable
        self.inc_init = inc_init
        self.cond = cond
        self.inc_assign = inc_assign
        self.local = local

    def _for_header(self) -> str:
        # Build the "for (init; cond; increment)" text in one place so that
        # __repr__ and render cannot drift apart (render used to lose "local ").
        inc_init = value_ref(self.inc_init, "")
        inc_assign = value_ref(self.inc_assign, "")
        local = "local " if self.local else ""
        return f"for ({local}{self.inc_variable} = {inc_init}; {self.cond}; {self.inc_variable} = {inc_assign})"

    def __repr__(self) -> str:
        entries: List[str] = []
        for statement in self.body:
            entries.extend([f" {s}" for s in str(statement).split(os.linesep)])

        return os.linesep.join([
            f"{self._for_header()} {{",
            os.linesep.join(entries),
            "}"
        ])

    def render(self, prefix: str) -> List[str]:
        # Render this loop as a list of lines, each starting with prefix.
        entries: List[str] = []
        for statement in self.body:
            entries.extend(statement.render(prefix=prefix + " "))

        # The opening brace goes on its own line, matching how IfStatement
        # renders. Emitting one on the header line as well would leave the
        # rendered output with an unbalanced brace.
        return [
            f"{prefix}{self._for_header()}",
            f"{prefix}{{",
            *entries,
            f"{prefix}}}",
        ]
class WhileStatement(DoWhileStatement):
# Special case of a DoWhileStatement that tracks its own exit condition.
pass
class IntermediateIf(ConvertedAction):
def __init__(self, parent_action: IfAction, true_statements: Sequence[Statement], false_statements: Sequence[Statement]) -> None:
self.parent_action = parent_action
@ -2995,7 +3092,7 @@ class ByteCodeDecompiler(VerboseOutput):
return new_statements
def __collapse_identical_labels(self, statements: Sequence[Statement]) -> List[Statement]:
def __collapse_identical_labels(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Go through and find labels that point at gotos, remove them and point the
# gotos to those labels at the second gotos.
statements = list(statements)
@ -3033,6 +3130,7 @@ class ByteCodeDecompiler(VerboseOutput):
return pairs
changed: bool = False
while True:
redundant_pairs = reduce_labels_and_gotos(find_labels_and_gotos(statements))
if not redundant_pairs:
@ -3051,20 +3149,21 @@ class ByteCodeDecompiler(VerboseOutput):
return statement
statements = self.__walk(statements, update_gotos)
changed = changed or updated
if not updated:
break
return statements
return statements, changed
def __remove_goto_return(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Go through and find labels that point at returns, convert any gotos pointing
# at them to returns.
def find_labels(statements: Sequence[Statement], next_statement: Optional[Statement]) -> Set[int]:
def find_labels(statements: Sequence[Statement], parent_next_statement: Optional[Statement]) -> Set[int]:
labels: Set[int] = set()
for i in range(len(statements)):
cur_statement = statements[i]
next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement
next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement
if (
isinstance(cur_statement, DefineLabelStatement) and
isinstance(next_statement, NullReturnStatement)
@ -3100,12 +3199,12 @@ class ByteCodeDecompiler(VerboseOutput):
# Go through and find returns that are on the "last" line. Basically, any
# return statement where the next statement is another return statement
# or the end of a function.
def find_returns(statements: Sequence[Statement], next_statement: Statement) -> Set[NullReturnStatement]:
def find_returns(statements: Sequence[Statement], parent_next_statement: Statement) -> Set[NullReturnStatement]:
returns: Set[NullReturnStatement] = set()
for i in range(len(statements)):
cur_statement = statements[i]
next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement
next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement
if (
isinstance(cur_statement, NullReturnStatement) and
isinstance(next_statement, NullReturnStatement)
@ -3220,12 +3319,12 @@ class ByteCodeDecompiler(VerboseOutput):
# Go through and find continues that are on the "last" line of a while. Basically, any
# continue statement where the next statement is another continue statement or the end
# of a loop.
def find_continues(statements: Sequence[Statement], next_statement: Statement) -> Set[ContinueStatement]:
def find_continues(statements: Sequence[Statement], parent_next_statement: Statement) -> Set[ContinueStatement]:
continues: Set[ContinueStatement] = set()
for i in range(len(statements)):
cur_statement = statements[i]
next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement
next_statement = statements[i + 1] if (i < len(statements) - 1) else parent_next_statement
if (
isinstance(cur_statement, ContinueStatement) and
isinstance(next_statement, ContinueStatement)
@ -3261,17 +3360,193 @@ class ByteCodeDecompiler(VerboseOutput):
statements = self.__walk(statements, remove_continues)
return statements, updated
def __swap_empty_ifs(self, statements: Sequence[Statement]) -> List[Statement]:
# Go through and find continue statements on the last line of a do-while.
def __is_math(self, expression: Expression, variable: str) -> bool:
    """Report whether expression is basic arithmetic involving variable.

    True only when expression is an ArithmeticExpression using one of
    + - * / and at least one of its two operands renders (via object_ref)
    to exactly the given variable text.
    """
    if not isinstance(expression, ArithmeticExpression):
        return False
    if expression.op not in {"+", "-", "*", "/"}:
        return False

    def rendered_or_none(operand: Any) -> Optional[str]:
        # Operands that cannot be rendered as an object reference simply
        # never match the variable.
        try:
            return object_ref(operand, "")
        except Exception:
            return None

    left = rendered_or_none(expression.left)
    right = rendered_or_none(expression.right)
    return left == variable or right == variable
def __get_increment_variable(self, statement: Statement) -> Optional[str]:
    """Return the variable a statement increments, or None.

    A statement qualifies when it is one of the four assignment statement
    kinds, its assigned value is an Expression, and that expression is
    arithmetic on the very target being assigned (as judged by __is_math).
    The returned text is the statement's code_equiv() rendering.
    """
    assignment_kinds = (
        SetMemberStatement,
        StoreRegisterStatement,
        SetVariableStatement,
        SetLocalStatement,
    )
    if not isinstance(statement, assignment_kinds):
        return None
    if not isinstance(statement.valueref, Expression):
        return None

    target = statement.code_equiv()
    return target if self.__is_math(statement.valueref, target) else None
def __get_assignment(self, statement: Statement) -> Any:
    """Return the value being assigned by an assignment statement.

    Gives back statement.valueref for any of the four assignment statement
    kinds, and None for every other kind of statement.
    """
    assignment_kinds = (
        SetMemberStatement,
        StoreRegisterStatement,
        SetVariableStatement,
        SetLocalStatement,
    )
    return statement.valueref if isinstance(statement, assignment_kinds) else None
def __convert_for_loops(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
    """Convert do {} while loops that resemble for loops into ForStatements.

    A conversion happens when an assignment is immediately followed by a
    plain do-while loop whose last statement is arithmetic on the assigned
    variable and whose first statement is an else-less
    "if (<condition on that variable>) break". The initializer, guard and
    increment are folded into the ForStatement and removed from the output.
    This isn't going to be perfectly accurate because there can be all
    sorts of bizarre for statements, but it should be good enough for most
    cases to make better code.

    Returns the rewritten statement list and whether anything was changed.
    """

    def ref_or_none(value: Any) -> Optional[str]:
        # Values that can't be rendered as an object reference never match
        # the increment variable.
        try:
            return object_ref(value, "")
        except Exception:
            return None

    def convert_fors(statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
        new_statements: List[Statement] = []
        # Replacement loops keyed by the identity of the loop they replace.
        # The originals stay alive in `statements`, so the ids are stable.
        updated_statements: Dict[int, DoWhileStatement] = {}
        changed: bool = False

        for i, cur_statement in enumerate(statements):
            next_statement = statements[i + 1] if (i < len(statements) - 1) else None

            if isinstance(cur_statement, IfStatement):
                # Don't care about this, but we need to recursively walk its children.
                cur_statement.true_statements, new_changed = convert_fors(cur_statement.true_statements)
                changed = changed or new_changed
                cur_statement.false_statements, new_changed = convert_fors(cur_statement.false_statements)
                changed = changed or new_changed
                new_statements.append(cur_statement)
            elif isinstance(cur_statement, DoWhileStatement):
                # If we addressed this statement, we should use the new statement instead.
                cur_statement = updated_statements.get(id(cur_statement), cur_statement)

                # Need to recursively walk through and perform stuff on the body of this.
                cur_statement.body, new_changed = convert_fors(cur_statement.body)
                changed = changed or new_changed
                new_statements.append(cur_statement)
            elif (
                isinstance(cur_statement, (SetMemberStatement, StoreRegisterStatement, SetVariableStatement, SetLocalStatement)) and
                isinstance(next_statement, DoWhileStatement) and
                not isinstance(next_statement, (ForStatement, WhileStatement))
            ):
                # This is a possible conversion that hasn't been converted yet. Let's try to grab
                # the increment variable. An empty loop body has no increment at all, so it can
                # never be a candidate (and must not be indexed).
                last_statement = next_statement.body[-1] if next_statement.body else None
                inc_variable = None
                if last_statement is not None:
                    inc_variable = self.__get_increment_variable(last_statement)

                # Now that we know what's being incremented, let's see if it matches our
                # initializer. If it doesn't, kill our reference.
                if inc_variable is not None and inc_variable != cur_statement.code_equiv():
                    inc_variable = None
                inc_assignment = self.__get_assignment(last_statement) if inc_variable is not None else None

                possible_if = None
                swap = False
                if inc_variable is not None:
                    # Let's see if the first statement is an if statement with a break.
                    possible_if = next_statement.body[0]
                    if not isinstance(possible_if, IfStatement):
                        # Can't be a for loop, it doesn't have a conditional for breaking.
                        inc_variable = None
                    elif not (
                        len(possible_if.true_statements) == 1 and
                        isinstance(possible_if.true_statements[0], BreakStatement)
                    ):
                        # This isn't a candidate, the conditional isn't a break by itself.
                        inc_variable = None
                    elif possible_if.false_statements:
                        # The guard is dropped wholesale on conversion, so an else body
                        # would be silently lost. Leave such loops alone.
                        inc_variable = None
                    elif isinstance(possible_if.cond, (IsUndefinedIf, IsBooleanIf)):
                        # This is possibly a candidate, check the condition's variable usage.
                        if ref_or_none(possible_if.cond.conditional) != inc_variable:
                            inc_variable = None
                    elif isinstance(possible_if.cond, TwoParameterIf):
                        if_variable1 = ref_or_none(possible_if.cond.conditional1)
                        if_variable2 = ref_or_none(possible_if.cond.conditional2)
                        if if_variable1 != inc_variable and if_variable2 != inc_variable:
                            inc_variable = None
                        elif if_variable2 == inc_variable:
                            # Put the loop variable on the left-hand side of the comparison.
                            swap = True
                    else:
                        # Shouldn't happen, but let's handle it by bailing.
                        inc_variable = None

                if inc_variable is not None:
                    # This is a for statement. Let's convert it. The guard says when to
                    # leave the loop, so the for condition is its inverse.
                    cond = cast(IfStatement, possible_if).cond.invert()
                    updated_statements[id(next_statement)] = ForStatement(
                        inc_variable,
                        self.__get_assignment(cur_statement),
                        cond.swap() if swap else cond,
                        inc_assignment,
                        # Drop the increment and the if statement, since we are incorporating them.
                        next_statement.body[1:-1],
                        local=isinstance(cur_statement, SetLocalStatement),
                    )
                    # The initializer is swallowed by the for statement (it is deliberately
                    # not appended), so the tree changed and callers need to know.
                    changed = True
                else:
                    new_statements.append(cur_statement)
            else:
                # Don't care about this one, just append it.
                new_statements.append(cur_statement)
        return new_statements, changed

    return convert_fors(statements)
def __swap_empty_ifs(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Get rid of empty if statements. If statements with empty if bodies and nonempty
# else bodies will also be swapped.
changed: bool = False
updated: bool = False
def swap_empty_ifs(statement: Statement) -> Optional[Statement]:
nonlocal changed
nonlocal updated
if isinstance(statement, IfStatement):
if statement.false_statements and (not statement.true_statements):
# Swap this, invert the conditional
changed = True
updated = True
return IfStatement(
statement.cond.invert(),
statement.false_statements,
@ -3280,6 +3555,7 @@ class ByteCodeDecompiler(VerboseOutput):
elif (not statement.true_statements) and (not statement.false_statements):
# Drop the if, it has no body.
changed = True
updated = True
return None
return statement
@ -3287,7 +3563,63 @@ class ByteCodeDecompiler(VerboseOutput):
changed = False
statements = self.__walk(statements, swap_empty_ifs)
if not changed:
return statements
return statements, updated
def __drop_unneeded_else(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
    """Hoist else bodies that do not need to be inside an else.

    If an if has an else, but the last line of the if is a
    break/continue/return/throw/goto, then the else body doesn't need to
    exist, so it is hoisted up into the parent right after the if. If the
    else body also ends in such an exit, it is left alone for aesthetics.

    Returns the rewritten statement list and whether anything was changed.
    """

    # Statement kinds after which control never falls through to the next line.
    exit_statements = (
        BreakStatement,
        ContinueStatement,
        ReturnStatement,
        NullReturnStatement,
        ThrowStatement,
        GotoStatement,
    )

    def update_ifs(statements: Sequence[Statement], in_loop: bool) -> Tuple[List[Statement], bool]:
        # in_loop is threaded through the recursion but no decision below reads it.
        new_statements: List[Statement] = []
        changed: bool = False

        for statement in statements:
            if isinstance(statement, IfStatement):
                if (
                    statement.true_statements and
                    statement.false_statements and
                    isinstance(statement.true_statements[-1], exit_statements) and
                    not isinstance(statement.false_statements[-1], exit_statements)
                ):
                    # We need to walk both halves still, but once we're done, hoist the false
                    # statements up to our level.
                    statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop)
                    changed = changed or new_changed

                    new_false_statements, new_changed = update_ifs(statement.false_statements, in_loop)
                    changed = changed or new_changed
                    statement.false_statements = []

                    # Now, append the if statement, and follow up with the body.
                    new_statements.append(statement)
                    new_statements.extend(new_false_statements)

                    # Hoisting is itself a change. Report it so other passes get
                    # another look at the restructured code.
                    changed = True
                else:
                    statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop)
                    changed = changed or new_changed
                    statement.false_statements, new_changed = update_ifs(statement.false_statements, in_loop)
                    changed = changed or new_changed
                    new_statements.append(statement)
            elif isinstance(statement, DoWhileStatement):
                # Need to recursively walk through and perform stuff on the body of this.
                statement.body, new_changed = update_ifs(statement.body, in_loop=True)
                changed = changed or new_changed
                new_statements.append(statement)
            else:
                # Don't care about this one, just append it.
                new_statements.append(statement)
        return new_statements, changed

    return update_ifs(statements, in_loop=False)
def __verify_balanced_labels(self, statements: Sequence[Statement]) -> None:
gotos: Set[int] = set()
@ -3367,28 +3699,24 @@ class ByteCodeDecompiler(VerboseOutput):
# Now, let's do some clean-up passes.
if self.optimize:
statements = self.__collapse_identical_labels(statements)
while True:
any_changed = False
statements, changed = self.__eliminate_useless_continues(statements)
any_changed = any_changed or changed
statements, changed = self.__eliminate_unused_labels(statements)
any_changed = any_changed or changed
statements, changed = self.__remove_useless_gotos(statements)
any_changed = any_changed or changed
statements, changed = self.__remove_goto_return(statements)
any_changed = any_changed or changed
statements, changed = self.__eliminate_useless_returns(statements)
any_changed = any_changed or changed
for func in [
self.__collapse_identical_labels,
self.__eliminate_useless_continues,
self.__eliminate_unused_labels,
self.__remove_useless_gotos,
self.__remove_goto_return,
self.__eliminate_useless_returns,
self.__convert_for_loops,
self.__swap_empty_ifs,
self.__drop_unneeded_else,
]:
statements, changed = func(statements)
any_changed = any_changed or changed
if not any_changed:
break
statements = self.__swap_empty_ifs(statements)
# Let's sanity check the code for a few things that might trip us up.
self.__verify_balanced_labels(statements)

View File

@ -846,9 +846,8 @@ class TestAFPDecompile(ExtendedTestCase):
" if (not some_condition) {\n"
" builtin_StopPlaying()\n"
" break\n"
" } else {\n"
" builtin_GotoNextFrame()\n"
" }\n"
" builtin_GotoNextFrame()\n"
"} while (True)"
])
@ -873,17 +872,10 @@ class TestAFPDecompile(ExtendedTestCase):
AP2Action(109, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a for statement.
self.assertEqual(self.__equiv(statements), [
"local i = 0",
"do {\n"
" if (10 <= i) {\n"
" break\n"
" }\n"
"for (local i = 0; i < 10; i = i + 1) {\n"
" builtin_GotoNextFrame()\n"
" i = i + 1\n"
"} while (True)"
"}"
])
def test_advanced_for(self) -> None:
@ -916,20 +908,12 @@ class TestAFPDecompile(ExtendedTestCase):
AP2Action(115, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a for statement.
self.assertEqual(self.__equiv(statements), [
"local i = 0",
"do {\n"
" if (10 <= i) {\n"
" break\n"
" }\n"
"for (local i = 0; i < 10; i = i + 1) {\n"
" if (not some_condition) {\n"
" builtin_StopPlaying()\n"
" break\n"
" } else {\n"
" builtin_GotoNextFrame()\n"
" i = i + 1\n"
" }\n"
"} while (True)"
" builtin_GotoNextFrame()\n"
"}"
])