1
0
mirror of synced 2024-11-27 23:50:47 +01:00

Introduce switch statements and use those to make decompilation better.

This commit is contained in:
Jennifer Taylor 2021-07-15 00:13:02 +00:00
parent d5a8fbfb4c
commit 599825332a
4 changed files with 362 additions and 29 deletions

View File

@ -23,6 +23,7 @@ from .types import (
GetTimeFunctionCall,
GetPathFunctionCall,
Variable,
Register,
Member,
NewFunction,
NewObject,
@ -46,6 +47,8 @@ from .types import (
RemoveSpriteStatement,
DebugTraceStatement,
IfStatement,
SwitchStatement,
SwitchCase,
DoWhileStatement,
WhileStatement,
ForStatement,
@ -2514,6 +2517,14 @@ class ByteCodeDecompiler(VerboseOutput):
new_statement.true_statements = self.__walk(new_statement.true_statements, do)
new_statement.false_statements = self.__walk(new_statement.false_statements, do)
new_statements.append(new_statement)
elif isinstance(new_statement, SwitchStatement):
new_statement.cases = [
SwitchCase(
case.const,
self.__walk(case.statements, do),
) for case in new_statement.cases
]
new_statements.append(new_statement)
elif new_statement:
new_statements.append(new_statement)
@ -2543,6 +2554,10 @@ class ByteCodeDecompiler(VerboseOutput):
label_and_goto.update(find_labels_and_gotos(cur_statement.true_statements))
label_and_goto.update(find_labels_and_gotos(cur_statement.false_statements))
elif isinstance(cur_statement, SwitchStatement):
for case in cur_statement.cases:
label_and_goto.update(find_labels_and_gotos(case.statements))
return label_and_goto
def reduce_labels_and_gotos(pairs: Dict[int, int]) -> Dict[int, int]:
@ -2604,6 +2619,10 @@ class ByteCodeDecompiler(VerboseOutput):
labels.update(find_labels(cur_statement.true_statements, next_statement))
labels.update(find_labels(cur_statement.false_statements, next_statement))
elif isinstance(cur_statement, SwitchStatement):
for case in cur_statement.cases:
labels.update(find_labels(case.statements, next_statement))
return labels
labels = find_labels(statements, None)
@ -2645,6 +2664,10 @@ class ByteCodeDecompiler(VerboseOutput):
returns.update(find_returns(cur_statement.true_statements, next_statement))
returns.update(find_returns(cur_statement.false_statements, next_statement))
elif isinstance(cur_statement, SwitchStatement):
for case in cur_statement.cases:
returns.update(find_returns(case.statements, next_statement))
return returns
# Instead of an empty next statement, make up a return so we catch anything
@ -2696,6 +2719,12 @@ class ByteCodeDecompiler(VerboseOutput):
gotos.extend(find_goto_next_line(cur_statement.true_statements, next_statement))
gotos.extend(find_goto_next_line(cur_statement.false_statements, next_statement))
elif isinstance(cur_statement, SwitchStatement):
# Switch cases do not "flow" into the next line, they flow into the next switch
# case.
for case in cur_statement.cases:
gotos.extend(find_goto_next_line(case.statements, NopStatement()))
return gotos
# Whether we made at least one substitution.
@ -2767,6 +2796,10 @@ class ByteCodeDecompiler(VerboseOutput):
continues.update(find_continues(cur_statement.true_statements, next_statement))
continues.update(find_continues(cur_statement.false_statements, next_statement))
elif isinstance(cur_statement, SwitchStatement):
for case in cur_statement.cases:
continues.update(find_continues(case.statements, next_statement))
return continues
# Instead of an empty next statement, make up a return because that's what
@ -2906,6 +2939,13 @@ class ByteCodeDecompiler(VerboseOutput):
changed = changed or new_changed
new_statements.append(cur_statement)
elif isinstance(cur_statement, SwitchStatement):
for case in cur_statement.cases:
case.statements, new_changed = convert_loops(case.statements)
changed = changed or new_changed
new_statements.append(cur_statement)
elif isinstance(cur_statement, DoWhileStatement):
# If we addressed this statement, we should use the new statement instead.
for old, new in updated_statements.items():
@ -3060,7 +3100,7 @@ class ByteCodeDecompiler(VerboseOutput):
# If an if has an else, but the last line of the if is a break/continue/return/throw/goto
# then the else body doesn't need to exist, so hoist it up into the parent. If the false
# statement also has an exit condition, don't drop it for asthetics.
def update_ifs(statements: Sequence[Statement], in_loop: bool) -> Tuple[List[Statement], bool]:
def update_ifs(statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
new_statements: List[Statement] = []
changed: bool = False
@ -3080,10 +3120,10 @@ class ByteCodeDecompiler(VerboseOutput):
):
# We need to walk both halves still, but once we're done, hoist the false
# statements up to our level.
statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop)
statement.true_statements, new_changed = update_ifs(statement.true_statements)
changed = changed or new_changed
new_false_statements, new_changed = update_ifs(statement.false_statements, in_loop)
new_false_statements, new_changed = update_ifs(statement.false_statements)
changed = changed or new_changed
statement.false_statements = []
@ -3091,18 +3131,25 @@ class ByteCodeDecompiler(VerboseOutput):
new_statements.append(statement)
new_statements.extend(new_false_statements)
else:
statement.true_statements, new_changed = update_ifs(statement.true_statements, in_loop)
statement.true_statements, new_changed = update_ifs(statement.true_statements)
changed = changed or new_changed
statement.false_statements, new_changed = update_ifs(statement.false_statements, in_loop)
statement.false_statements, new_changed = update_ifs(statement.false_statements)
changed = changed or new_changed
new_statements.append(statement)
elif isinstance(statement, DoWhileStatement):
# Need to recursively walk through and perform stuff on the body of this.
statement.body, new_changed = update_ifs(statement.body, in_loop=True)
statement.body, new_changed = update_ifs(statement.body)
changed = changed or new_changed
new_statements.append(statement)
elif isinstance(statement, SwitchStatement):
# Need to recursively walk through and perform stuff on the cases of this.
for case in statement.cases:
case.statements, new_changed = update_ifs(case.statements)
changed = changed or new_changed
new_statements.append(statement)
else:
# Don't care about this one, just append it.
@ -3110,7 +3157,7 @@ class ByteCodeDecompiler(VerboseOutput):
return new_statements, changed
return update_ifs(statements, in_loop=False)
return update_ifs(statements)
def __gather_flow(self, parent_conditional: IfExpr, statements: Sequence[Statement]) -> Tuple[IfExpr, Dict[int, IfExpr], List[Tuple[IfExpr, Statement]]]:
flowed_statements: List[Tuple[IfExpr, Statement]] = []
@ -3179,9 +3226,9 @@ class ByteCodeDecompiler(VerboseOutput):
candidates: List[IfStatement] = []
for statement in statements:
# We intentionally ignore while loops here, because we don't collapse gotos across while loop
# boundaries. We will end up running this code again over the contents of any while loop when
# we call __rearrange_compound_ifs on its body.
# We intentionally ignore while loops and switches here, because we don't collapse gotos
# across while loop/switch boundaries. We will end up running this code again over the
# contents of any while loop or switch when we call __rearrange_compound_ifs on its body.
if isinstance(statement, IfStatement):
if statement.true_statements and not statement.false_statements:
candidates.append(statement)
@ -3395,9 +3442,14 @@ class ByteCodeDecompiler(VerboseOutput):
# We hoisted a compound if, so report a change.
changed = True
elif isinstance(statement, DoWhileStatement):
new_body, child_changed = self.__rearrange_compound_ifs(statement.body)
statement.body, child_changed = self.__rearrange_compound_ifs(statement.body)
changed = changed or child_changed
statement.body = new_body
new_statements.append(statement)
i += 1
elif isinstance(statement, SwitchStatement):
for case in statement.cases:
case.statements, new_changed = self.__rearrange_compound_ifs(case.statements)
changed = changed or new_changed
new_statements.append(statement)
i += 1
else:
@ -3415,6 +3467,216 @@ class ByteCodeDecompiler(VerboseOutput):
return self.__hoist_compound_ifs(IsBooleanIf(True), statements, candidates, flow)
def __convert_switches(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Convert any cascading if statements comparing the same variable/register against a series
# of constants into switch statements. This relies on a previous run of anohter optimizer
# that ensures that variables show up on the LHS of if statements.
batches: List[Tuple[IfStatement, ...]] = []
processed: Set[Statement] = set()
changed: bool = False
def get_lhs(statement: IfStatement) -> Optional[Expression]:
if not isinstance(statement.cond, TwoParameterIf):
return None
if statement.cond.comp not in {TwoParameterIf.EQUALS, TwoParameterIf.NOT_EQUALS}:
return None
if not isinstance(statement.cond.conditional1, (Variable, Register, Member)):
return None
# We intentionally widen this to allow variables, since there are a lot of places in
# various code that uses variables as constants. We made up the entire language we are
# decompiling into so we can allow this for readability.
if not isinstance(statement.cond.conditional2, (str, bool, float, int, StringConstant, Variable)):
return None
return statement.cond.conditional1
def get_next_candidate(statements: List[Statement], lhs: Expression) -> Optional[IfStatement]:
if len(statements) != 1:
return None
statement = statements[0]
if not isinstance(statement, IfStatement):
return None
new_lhs = get_lhs(statement)
if new_lhs is None:
return None
if str(lhs) != str(new_lhs):
return None
return statement
def find_if_batches(statement: Statement) -> Optional[Statement]:
nonlocal batches
nonlocal processed
if isinstance(statement, IfStatement):
# Don't sum up statements already grabbed by a previous invocation.
if statement in processed:
return statement
# First, detect if this is a two parameter if statement comparing a recognized LHS
# against some constant (number, integer, float, bool, string constant).
lhs = get_lhs(statement)
if lhs is None:
return statement
# Now, collect as many if statement children as we can.
batch: List[IfStatement] = [statement]
while True:
candidate = get_next_candidate(batch[-1].true_statements, lhs)
if candidate:
batch.append(candidate)
continue
candidate = get_next_candidate(batch[-1].false_statements, lhs)
if candidate:
batch.append(candidate)
continue
break
if len(batch) > 1:
# Found one!
batches.append(tuple(batch))
processed.update(batch)
return statement
def replace_if_with_switch(statement: Statement) -> Optional[Statement]:
nonlocal changed
nonlocal batches
nonlocal processed
if isinstance(statement, IfStatement):
if statement not in processed:
return statement
# This is an if statement we should try to handle.
our_batches = [b for b in batches if b[0] is statement]
batches = [b for b in batches if b[0] is not statement]
if len(our_batches) != 1:
raise Exception("Logic error!")
batch = set(our_batches[0])
new_batch: List[IfStatement] = []
processed.difference_update(batch)
# First, make sure all the if entries are an equals.
def swap_not_equals(statement: Statement) -> Optional[Statement]:
nonlocal batch
nonlocal new_batch
if isinstance(statement, IfStatement):
if statement not in batch:
return statement
# We know this is correct, it wouldn't be in our list otherwise.
cond = cast(TwoParameterIf, statement.cond)
# If its already switched, leave it alone.
if cond.comp == TwoParameterIf.EQUALS:
new_batch.append(statement)
return statement
new_statement = IfStatement(
statement.cond.invert(),
statement.false_statements,
statement.true_statements,
)
new_batch.append(new_statement)
return new_statement
return statement
changed_statements = self.__walk([statement], swap_not_equals)
if len(changed_statements) != 1:
raise Exception("Logic error!")
changed_statement = changed_statements[0]
if not isinstance(changed_statement, IfStatement):
raise Exception("Logic error!")
# Now, make sure this is actually a switch pattern.
cur_statement: Statement = changed_statement
cases: List[SwitchCase] = []
for i in range(len(new_batch)):
if new_batch[i] is not cur_statement:
raise Exception("Logic error!")
if i < (len(new_batch) - 1):
# We dont check for the final case, since this will include
# the 'default' case in the else body.
if not isinstance(cur_statement, IfStatement):
# This isn't even an if statement. We should never hit
# this but the type checker wants to be happy.
return statement
if len(cur_statement.false_statements) != 1:
# This can't be a switch, it needs to be an embedded if.
return statement
if cur_statement.false_statements[0] is not new_batch[i + 1]:
# This can't be a switch, it needs to contain the nested
# if we identified.
return statement
cond = new_batch[i].cond
if not isinstance(cond, TwoParameterIf):
# This is purely to make the type checker happy, it should
# never be hit.
return statement
cases.append(
SwitchCase(
cond.conditional2,
[
*new_batch[i].true_statements,
BreakStatement(),
],
)
)
# This is still in the running
cur_statement = cur_statement.false_statements[0]
else:
cond = new_batch[i].cond
if not isinstance(cond, TwoParameterIf):
# This is purely to make the type checker happy, it should
# never be hit.
return statement
cases.append(
SwitchCase(
cond.conditional2,
[
*new_batch[i].true_statements,
BreakStatement(),
],
)
)
if new_batch[i].false_statements:
# Only append a default case if there are any statements
# in the last false body.
cases.append(
SwitchCase(
None,
[
*new_batch[i].false_statements,
BreakStatement(),
],
)
)
cond = new_batch[0].cond
if not isinstance(cond, TwoParameterIf):
# This is purely to make the type checker happy, it should
# never be hit.
return statement
changed = True
return SwitchStatement(
cond.conditional1,
cases,
)
return statement
# First, identify any if statements that should be included in a batch.
self.__walk(statements, find_if_batches)
# Now, replace them with switch statements.
statements = self.__walk(statements, replace_if_with_switch)
return statements, changed
def _optimize_code(self, statements: Sequence[Statement]) -> List[Statement]:
statements = list(statements)
@ -3434,6 +3696,7 @@ class ByteCodeDecompiler(VerboseOutput):
self.__drop_unneeded_else,
self.__swap_ugly_ifexprs,
self.__rearrange_compound_ifs,
self.__convert_switches,
]:
statements, changed = func(statements)
any_changed = any_changed or changed
@ -3444,10 +3707,6 @@ class ByteCodeDecompiler(VerboseOutput):
else:
self.vprint("Optimizer changed code, running another pass.")
# TODO: We probably want to collapse some really long if chains to switch
# statements or if/elif/else blocks for readability but that is left as a
# future enhancement.
return statements
def __verify_balanced_labels(self, statements: Sequence[Statement]) -> None:

View File

@ -61,6 +61,8 @@ from .statement import (
DebugTraceStatement,
NopStatement,
IfStatement,
SwitchStatement,
SwitchCase,
DoWhileStatement,
WhileStatement,
ForStatement,
@ -137,6 +139,8 @@ __all__ = [
'DebugTraceStatement',
'NopStatement',
'IfStatement',
'SwitchStatement',
'SwitchCase',
'DoWhileStatement',
'WhileStatement',
'ForStatement',

View File

@ -1097,3 +1097,74 @@ class WhileStatement(DoWhileStatement):
*entries,
f"{prefix}}}",
]
class SwitchCase:
def __init__(self, const: Any, statements: Sequence[Statement]) -> None:
self.const = const
self.statements = list(statements)
def __repr__(self) -> str:
entries: List[str] = []
for statement in self.statements:
entries.extend([f" {s}" for s in str(statement).split(os.linesep)])
if self.const is not None:
const = value_ref(self.const, "")
return os.linesep.join([
f"case {const}:",
os.linesep.join(entries),
])
else:
return os.linesep.join([
"default:",
os.linesep.join(entries),
])
def render(self, prefix: str) -> List[str]:
entries: List[str] = []
for statement in self.statements:
entries.extend(statement.render(prefix=prefix + " "))
if self.const is not None:
const = value_ref(self.const, prefix)
return [
f"{prefix}case {const}:",
*entries,
]
else:
return [
f"{prefix}default:",
*entries,
]
class SwitchStatement(Statement):
def __init__(self, check_variable: Any, cases: Sequence[SwitchCase]) -> None:
self.check_variable = check_variable
self.cases = list(cases)
def __repr__(self) -> str:
cases: List[str] = []
for case in self.cases:
cases.extend([f" {s}" for s in str(case).split(os.linesep)])
check = object_ref(self.check_variable, "")
return os.linesep.join([
f"switch ({check}) {{",
os.linesep.join(cases),
"}"
])
def render(self, prefix: str) -> List[str]:
cases: List[str] = []
for case in self.cases:
cases.extend(case.render(prefix=prefix + " "))
check = object_ref(self.check_variable, prefix)
return [
f"{prefix}switch ({check})",
f"{prefix}{{",
*cases,
f"{prefix}}}"
]

View File

@ -747,21 +747,20 @@ class TestAFPDecompile(ExtendedTestCase):
AP2Action(113, AP2Action.RETURN),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as an if/elseif/else chunk without so much indentation.
self.assertEqual(self.__equiv(statements), [
f"if (registers[0] != 1) {OPEN_BRACKET}{os.linesep}"
f" if (registers[0] != 2) {OPEN_BRACKET}{os.linesep}"
f" if (registers[0] != 3) {OPEN_BRACKET}{os.linesep}"
f" tempvar_0 = 'd'{os.linesep}"
f" {CLOSE_BRACKET} else {OPEN_BRACKET}{os.linesep}"
f" tempvar_0 = 'c'{os.linesep}"
f" {CLOSE_BRACKET}{os.linesep}"
f" {CLOSE_BRACKET} else {OPEN_BRACKET}{os.linesep}"
f"switch (registers[0]) {OPEN_BRACKET}{os.linesep}"
f" case 1:{os.linesep}"
f" tempvar_0 = 'a'{os.linesep}"
f" break{os.linesep}"
f" case 2:{os.linesep}"
f" tempvar_0 = 'b'{os.linesep}"
f" {CLOSE_BRACKET}{os.linesep}"
f"{CLOSE_BRACKET} else {OPEN_BRACKET}{os.linesep}"
f" tempvar_0 = 'a'{os.linesep}"
f" break{os.linesep}"
f" case 3:{os.linesep}"
f" tempvar_0 = 'c'{os.linesep}"
f" break{os.linesep}"
f" default:{os.linesep}"
f" tempvar_0 = 'd'{os.linesep}"
f" break{os.linesep}"
"}",
"return tempvar_0"
])