1
0
mirror of synced 2024-11-24 06:20:12 +01:00

Add some tests for for/while loops in anticipation of optimizing output.

This commit is contained in:
Jennifer Taylor 2021-05-05 00:53:02 +00:00
parent 83f4f43e9e
commit e17f7c2bdb
2 changed files with 197 additions and 15 deletions

View File

@ -795,8 +795,7 @@ class DoWhileStatement(Statement):
return os.linesep.join([
"do {",
os.linesep.join(entries),
"}",
"while(True);"
"} while (True)"
])
def render(self, prefix: str) -> List[str]:
@ -809,7 +808,7 @@ class DoWhileStatement(Statement):
f"{prefix}{{",
*entries,
f"{prefix}}}",
f"{prefix}while(True);",
f"{prefix}while (True);",
]
@ -3099,7 +3098,7 @@ class ByteCodeDecompiler(VerboseOutput):
def __eliminate_useless_returns(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Go through and find returns that are on the "last" line. Basically, any
# return statement where the next return statement is another return statement
# return statement where the next statement is another return statement
# or the end of a function.
def find_returns(statements: Sequence[Statement], next_statement: Statement) -> Set[NullReturnStatement]:
returns: Set[NullReturnStatement] = set()
@ -3218,20 +3217,49 @@ class ByteCodeDecompiler(VerboseOutput):
return self.__walk(statements, remove_label), changed
def __eliminate_useless_continues(self, statements: Sequence[Statement]) -> Tuple[List[Statement], bool]:
# Go through and find continue statements on the last line of a do-while.
changed: bool = False
# Go through and find continues that are on the "last" line of a while. Basically, any
# continue statement where the next statement is another continue statement or the end
# of a loop.
def find_continues(statements: Sequence[Statement], next_statement: Statement) -> Set[ContinueStatement]:
continues: Set[ContinueStatement] = set()
def remove_continue(statement: Statement) -> Optional[Statement]:
nonlocal changed
for i in range(len(statements)):
cur_statement = statements[i]
next_statement = statements[i + 1] if (i < len(statements) - 1) else next_statement
if (
isinstance(cur_statement, ContinueStatement) and
isinstance(next_statement, ContinueStatement)
):
continues.add(cur_statement)
if isinstance(statement, DoWhileStatement):
if statement.body and isinstance(statement.body[-1], ContinueStatement):
changed = True
statement.body.pop()
elif isinstance(cur_statement, DoWhileStatement):
# Clever hack, where we pretend the next value after the loop is a continue,
# because hitting the bottom of a loop is actually a continue.
continues.update(find_continues(cur_statement.body, ContinueStatement()))
elif isinstance(cur_statement, IfStatement):
continues.update(find_continues(cur_statement.true_statements, next_statement))
continues.update(find_continues(cur_statement.false_statements, next_statement))
return continues
# Instead of an empty next statement, make up a return because that's what
# falling off the end of execution means.
continues = find_continues(statements, NullReturnStatement())
updated: bool = False
def remove_continues(statement: Statement) -> Statement:
nonlocal updated
for removable in continues:
if removable is statement:
updated = True
return None
return statement
statements = self.__walk(statements, remove_continue)
return statements, changed
statements = self.__walk(statements, remove_continues)
return statements, updated
def __swap_empty_ifs(self, statements: Sequence[Statement]) -> List[Statement]:
# Go through and find continue statements on the last line of a do-while.

View File

@ -3,7 +3,7 @@ import unittest
from typing import Dict, List, Sequence, Tuple, Union
from bemani.tests.helpers import ExtendedTestCase
from bemani.format.afp.types.ap2 import AP2Action, IfAction, JumpAction, PushAction, Register
from bemani.format.afp.types.ap2 import AP2Action, IfAction, JumpAction, PushAction, AddNumVariableAction, Register
from bemani.format.afp.decompile import BitVector, ByteCode, ByteCodeChunk, ControlFlow, ByteCodeDecompiler, Statement
@ -779,3 +779,157 @@ class TestAFPDecompile(ExtendedTestCase):
"builtin_StartPlaying()",
"goto label_4",
])
def test_basic_while(self) -> None:
# A basic while statement.
bytecode = self.__make_bytecode([
# Define exit condition variable.
PushAction(100, ["finished", False]),
AP2Action(101, AP2Action.DEFINE_LOCAL),
# Check exit condition.
PushAction(102, ["finished"]),
AP2Action(103, AP2Action.GET_VARIABLE),
IfAction(104, IfAction.IS_TRUE, 107),
# Loop code.
AP2Action(105, AP2Action.NEXT_FRAME),
# Loop finished jump back to beginning.
JumpAction(106, 102),
# End of loop.
AP2Action(107, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a while statement.
self.assertEqual(self.__equiv(statements), [
"local finished = False",
"do {\n"
" if (finished) {\n"
" break\n"
" }\n"
" builtin_GotoNextFrame()\n"
"} while (True)"
])
def test_advanced_while(self) -> None:
# A basic while statement.
bytecode = self.__make_bytecode([
# Define exit condition variable.
PushAction(100, ["finished", False]),
AP2Action(101, AP2Action.DEFINE_LOCAL),
# Check exit condition.
PushAction(102, ["finished"]),
AP2Action(103, AP2Action.GET_VARIABLE),
IfAction(104, IfAction.IS_TRUE, 112),
# Loop code with a continue statement.
PushAction(105, ["some_condition"]),
AP2Action(106, AP2Action.GET_VARIABLE),
IfAction(107, IfAction.IS_FALSE, 110),
AP2Action(108, AP2Action.NEXT_FRAME),
# Continue statement.
JumpAction(109, 102),
# Exit early.
AP2Action(110, AP2Action.STOP),
# Break statement.
JumpAction(111, 112),
# End of loop.
AP2Action(112, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a while statement.
self.assertEqual(self.__equiv(statements), [
"local finished = False",
"do {\n"
" if (finished) {\n"
" break\n"
" }\n"
" if (not some_condition) {\n"
" builtin_StopPlaying()\n"
" break\n"
" } else {\n"
" builtin_GotoNextFrame()\n"
" }\n"
"} while (True)"
])
def test_basic_for(self) -> None:
# A basic for statement.
bytecode = self.__make_bytecode([
# Define exit condition variable.
PushAction(100, ["i", 0]),
AP2Action(101, AP2Action.DEFINE_LOCAL),
# Check exit condition.
PushAction(102, [10, "i"]),
AP2Action(103, AP2Action.GET_VARIABLE),
IfAction(104, IfAction.LT_EQUALS, 109),
# Loop code.
AP2Action(105, AP2Action.NEXT_FRAME),
# Increment, also the continue point.
PushAction(106, ["i"]),
AddNumVariableAction(107, 1),
# Loop finished jump back to beginning.
JumpAction(108, 102),
# End of loop.
AP2Action(109, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a for statement.
self.assertEqual(self.__equiv(statements), [
"local i = 0",
"do {\n"
" if (10 <= i) {\n"
" break\n"
" }\n"
" builtin_GotoNextFrame()\n"
" i = i + 1\n"
"} while (True)"
])
def test_advanced_for(self) -> None:
# A basic for statement.
bytecode = self.__make_bytecode([
# Define exit condition variable.
PushAction(100, ["i", 0]),
AP2Action(101, AP2Action.DEFINE_LOCAL),
# Check exit condition.
PushAction(102, [10, "i"]),
AP2Action(103, AP2Action.GET_VARIABLE),
IfAction(104, IfAction.LT_EQUALS, 115),
# Loop code with a continue statement.
PushAction(105, ["some_condition"]),
AP2Action(106, AP2Action.GET_VARIABLE),
IfAction(107, IfAction.IS_FALSE, 110),
AP2Action(108, AP2Action.NEXT_FRAME),
# Continue statement.
JumpAction(109, 112),
# Exit early.
AP2Action(110, AP2Action.STOP),
# Break statement.
JumpAction(111, 115),
# Increment, also the continue point.
PushAction(112, ["i"]),
AddNumVariableAction(113, 1),
# Loop finished jump back to beginning.
JumpAction(114, 102),
# End of loop.
AP2Action(115, AP2Action.END),
])
statements = self.__call_decompile(bytecode)
# TODO: This should be optimized as a for statement.
self.assertEqual(self.__equiv(statements), [
"local i = 0",
"do {\n"
" if (10 <= i) {\n"
" break\n"
" }\n"
" if (not some_condition) {\n"
" builtin_StopPlaying()\n"
" break\n"
" } else {\n"
" builtin_GotoNextFrame()\n"
" i = i + 1\n"
" }\n"
"} while (True)"
])