1
0
mirror of synced 2025-02-20 20:50:59 +01:00

Fix if processing by getting rid of negation early optimization. There was a path that

was negating but not being checked when doing the final render, leaving some if statements
inverted. Moved the optimization to a final pass for more generalized correctness.
This commit is contained in:
Jennifer Taylor 2021-04-26 02:50:19 +00:00
parent 2d4f6cadfb
commit b7523b5521

View File

@ -592,7 +592,8 @@ class SetLocalStatement(Statement):
class IfExpr(ConvertedAction):
# This is just for typing.
pass
def invert(self) -> "IfExpr":
raise NotImplementedError("Not implemented!")
class IsUndefinedIf(IfExpr):
@ -600,6 +601,9 @@ class IsUndefinedIf(IfExpr):
self.conditional = conditional
self.negate = negate
def invert(self) -> "IsUndefinedIf":
return IsUndefinedIf(self.conditional, not self.negate)
def __repr__(self) -> str:
val = value_ref(self.conditional, "", parens=True)
if self.negate:
@ -613,60 +617,67 @@ class IsBooleanIf(IfExpr):
self.conditional = conditional
self.negate = negate
def invert(self) -> "IsBooleanIf":
return IsBooleanIf(self.conditional, not self.negate)
def __repr__(self) -> str:
val = value_ref(self.conditional, "", parens=True)
if self.negate:
return f"if ({val} is False)"
return f"if (not {val})"
else:
return f"if ({val} is True)"
return f"if ({val})"
class IsEqualIf(IfExpr):
def __init__(self, conditional1: Any, conditional2: Any, negate: bool) -> None:
class TwoParameterIf(IfExpr):
EQUALS = "=="
NOT_EQUALS = "!="
LT = "<"
GT = ">"
LT_EQUALS = "<="
GT_EQUALS = ">="
STRICT_EQUALS = "==="
STRICT_NOT_EQUALS = "!=="
def __init__(self, conditional1: Any, comp: str, conditional2: Any) -> None:
if comp not in {
self.EQUALS,
self.NOT_EQUALS,
self.LT,
self.GT,
self.LT_EQUALS,
self.GT_EQUALS,
self.STRICT_EQUALS,
self.STRICT_NOT_EQUALS,
}:
raise Exception(f"Invalid comparision {comp}!")
self.conditional1 = conditional1
self.comp = comp
self.conditional2 = conditional2
self.negate = negate
def invert(self) -> "TwoParameterIf":
if self.comp == self.EQUALS:
return TwoParameterIf(self.conditional1, self.NOT_EQUALS, self.conditional2)
if self.comp == self.NOT_EQUALS:
return TwoParameterIf(self.conditional1, self.EQUALS, self.conditional2)
if self.comp == self.LT:
return TwoParameterIf(self.conditional1, self.GT_EQUALS, self.conditional2)
if self.comp == self.GT:
return TwoParameterIf(self.conditional1, self.LT_EQUALS, self.conditional2)
if self.comp == self.LT_EQUALS:
return TwoParameterIf(self.conditional1, self.GT, self.conditional2)
if self.comp == self.GT_EQUALS:
return TwoParameterIf(self.conditional1, self.LT, self.conditional2)
if self.comp == self.STRICT_EQUALS:
return TwoParameterIf(self.conditional1, self.STRICT_NOT_EQUALS, self.conditional2)
if self.comp == self.STRICT_NOT_EQUALS:
return TwoParameterIf(self.conditional1, self.STRICT_EQUALS, self.conditional2)
raise Exception(f"Cannot invert {self.comp}!")
def __repr__(self) -> str:
val1 = value_ref(self.conditional1, "", parens=True)
val2 = value_ref(self.conditional2, "", parens=True)
return f"if ({val1} {'!=' if self.negate else '=='} {val2})"
class IsStrictEqualIf(IfExpr):
def __init__(self, conditional1: Any, conditional2: Any, negate: bool) -> None:
self.conditional1 = conditional1
self.conditional2 = conditional2
self.negate = negate
def __repr__(self) -> str:
val1 = value_ref(self.conditional1, "", parens=True)
val2 = value_ref(self.conditional2, "", parens=True)
return f"if ({val1} {'!==' if self.negate else '==='} {val2})"
class MagnitudeIf(IfExpr):
def __init__(self, conditional1: Any, conditional2: Any, negate: bool) -> None:
self.conditional1 = conditional1
self.conditional2 = conditional2
self.negate = negate
def __repr__(self) -> str:
val1 = value_ref(self.conditional1, "", parens=True)
val2 = value_ref(self.conditional2, "", parens=True)
return f"if ({val1} {'<' if self.negate else '>'} {val2})"
class MagnitudeEqualIf(IfExpr):
def __init__(self, conditional1: Any, conditional2: Any, negate: bool) -> None:
self.conditional1 = conditional1
self.conditional2 = conditional2
self.negate = negate
def __repr__(self) -> str:
val1 = value_ref(self.conditional1, "", parens=True)
val2 = value_ref(self.conditional2, "", parens=True)
return f"if ({val1} {'<=' if self.negate else '>='} {val2})"
return f"if ({val1} {self.comp} {val2})"
class IfStatement(Statement):
@ -759,11 +770,10 @@ class DoWhileStatement(Statement):
class IntermediateIf(ConvertedAction):
def __init__(self, parent_action: IfAction, true_statements: Sequence[Statement], false_statements: Sequence[Statement], negate: bool) -> None:
def __init__(self, parent_action: IfAction, true_statements: Sequence[Statement], false_statements: Sequence[Statement]) -> None:
self.parent_action = parent_action
self.true_statements = list(true_statements)
self.false_statements = list(false_statements)
self.negate = negate
def __repr__(self) -> str:
true_entries: List[str] = []
@ -776,7 +786,7 @@ class IntermediateIf(ConvertedAction):
if self.false_statements:
return os.linesep.join([
f"if <{'!' if self.negate else ''}{self.parent_action}> {{",
f"if <{self.parent_action}> {{",
os.linesep.join(true_entries),
"} else {",
os.linesep.join(false_entries),
@ -784,7 +794,7 @@ class IntermediateIf(ConvertedAction):
])
else:
return os.linesep.join([
f"if <{'!' if self.negate else ''}{self.parent_action}> {{",
f"if <{self.parent_action}> {{",
os.linesep.join(true_entries),
"}"
])
@ -855,7 +865,7 @@ class Loop:
class IfBody:
def __init__(self, id: int, true_chunks: Sequence[ArbitraryCodeChunk], false_chunks: Sequence[ArbitraryCodeChunk], next_chunk: Optional[int], previous_chunk: int, negate: bool) -> None:
def __init__(self, id: int, true_chunks: Sequence[ArbitraryCodeChunk], false_chunks: Sequence[ArbitraryCodeChunk], next_chunk: Optional[int], previous_chunk: int) -> None:
# The ID in this case is what the previous block points at. It does not
# have any bearing on the ID of the true and false chunks.
self.id = id
@ -866,7 +876,6 @@ class IfBody:
self.next_chunks: List[int] = [next_chunk] if next_chunk is not None else []
self.true_chunks = list(true_chunks)
self.false_chunks = list(false_chunks)
self.negate = negate
def __repr__(self) -> str:
true_entries: List[str] = []
@ -879,7 +888,6 @@ class IfBody:
return (
f"IfBody({os.linesep}" +
f" Negated: {self.negate}{os.linesep}" +
f" ID: {self.id}{os.linesep}" +
(f" Previous Chunks: {', '.join(str(n) for n in self.previous_chunks)}{os.linesep}" if self.previous_chunks else f" Start Chunk{os.linesep}") +
f" True Chunks:{os.linesep}" +
@ -1460,18 +1468,10 @@ class ByteCodeDecompiler(VerboseOutput):
# a later if logic step.
continue
if true_action is None and false_action is not None:
true_action = false_action
false_action = None
negate = True
else:
negate = False
chunk.actions[-1] = IntermediateIf(
cast(IfAction, last_action),
[true_action],
[false_action] if false_action else [],
negate=negate,
)
# At this point, all chunks in our list should point only to other chunks in our list.
@ -1668,18 +1668,10 @@ class ByteCodeDecompiler(VerboseOutput):
chunk.next_chunks = [c for c in chunk.next_chunks if c != false_jump_point]
if true_action or false_action:
if (not true_action) and false_action:
true_action = false_action
false_action = None
negate = True
else:
negate = False
chunk.actions[-1] = IntermediateIf(
cast(IfAction, last_action),
[true_action],
[false_action] if false_action else [],
negate=negate,
)
def __find_shallowest_successor(self, start_chunk: int, chunks_by_id: Dict[int, ArbitraryCodeChunk]) -> Optional[int]:
@ -1884,13 +1876,6 @@ class ByteCodeDecompiler(VerboseOutput):
# We should have at least one!
raise Exception("Logic error, if statement has no code for if or else!")
if false_chunks and (not true_chunks):
negate = True
true_chunks = false_chunks
false_chunks = []
else:
negate = False
# Lets use a brand new ID here for easier traversal and so we don't accidentally
# reuse the ID of one of our parents if a jump point is a goto.
if_id = self.__goto_body_id
@ -1900,7 +1885,7 @@ class ByteCodeDecompiler(VerboseOutput):
# none of the chunks in the true or false bodies of the if, so we add it back to the graph
# in the form of an IfBody.
self.vprint(f"Created new IfBody for chunk {cur_id} to point at, ending at {if_id}")
chunks_by_id[if_id] = IfBody(if_id, true_chunks, false_chunks, if_end, cur_id, negate)
chunks_by_id[if_id] = IfBody(if_id, true_chunks, false_chunks, if_end, cur_id)
chunks_by_id[cur_id].next_chunks = [if_id]
if if_end is not None:
@ -1974,31 +1959,39 @@ class ByteCodeDecompiler(VerboseOutput):
borrows.append(borrow)
return borrow
def make_if_expr(action: IfAction, negate: bool) -> IfExpr:
def make_if_expr(action: IfAction) -> IfExpr:
if action.comparison in [IfAction.IS_UNDEFINED, IfAction.IS_NOT_UNDEFINED]:
conditional = get_stack()
return IsUndefinedIf(conditional, negate=negate != (action.comparison == IfAction.IS_UNDEFINED))
if action.comparison in [IfAction.IS_TRUE, IfAction.IS_FALSE]:
return IsUndefinedIf(conditional, negate=(action.comparison != IfAction.IS_UNDEFINED))
elif action.comparison in [IfAction.IS_TRUE, IfAction.IS_FALSE]:
conditional = get_stack()
return IsBooleanIf(conditional, negate=negate != (action.comparison == IfAction.IS_FALSE))
if action.comparison in [IfAction.EQUALS, IfAction.NOT_EQUALS]:
return IsBooleanIf(conditional, negate=(action.comparison != IfAction.IS_TRUE))
elif action.comparison in [
IfAction.EQUALS,
IfAction.NOT_EQUALS,
IfAction.STRICT_EQUALS,
IfAction.STRICT_NOT_EQUALS,
IfAction.LT,
IfAction.GT,
IfAction.LT_EQUALS,
IfAction.GT_EQUALS
]:
conditional2 = get_stack()
conditional1 = get_stack()
return IsEqualIf(conditional1, conditional2, negate=negate != (action.comparison == IfAction.NOT_EQUALS))
if action.comparison in [IfAction.STRICT_EQUALS, IfAction.STRICT_NOT_EQUALS]:
conditional2 = get_stack()
conditional1 = get_stack()
return IsStrictEqualIf(conditional1, conditional2, negate=negate != (action.comparison == IfAction.STRICT_NOT_EQUALS))
if action.comparison in [IfAction.LT, IfAction.GT]:
conditional2 = get_stack()
conditional1 = get_stack()
return MagnitudeIf(conditional1, conditional2, negate=negate != (action.comparison == IfAction.LT))
if action.comparison in [IfAction.LT_EQUALS, IfAction.GT_EQUALS]:
conditional2 = get_stack()
conditional1 = get_stack()
return MagnitudeEqualIf(conditional1, conditional2, negate=negate != (action.comparison == IfAction.LT_EQUALS))
comp = {
IfAction.EQUALS: TwoParameterIf.EQUALS,
IfAction.NOT_EQUALS: TwoParameterIf.NOT_EQUALS,
IfAction.STRICT_EQUALS: TwoParameterIf.STRICT_EQUALS,
IfAction.STRICT_NOT_EQUALS: TwoParameterIf.STRICT_NOT_EQUALS,
IfAction.LT: TwoParameterIf.LT,
IfAction.GT: TwoParameterIf.GT,
IfAction.LT_EQUALS: TwoParameterIf.LT_EQUALS,
IfAction.GT_EQUALS: TwoParameterIf.GT_EQUALS,
}[action.comparison]
raise Exception(f"TODO: {action}")
return TwoParameterIf(conditional1, comp, conditional2)
else:
raise Exception(f"TODO: {action}")
# TODO: Everywhere that we assert on a type needs to be updated to check for a borrow, and if the borrow
# exists we should instead set the borrow to check for that type.
@ -2075,7 +2068,7 @@ class ByteCodeDecompiler(VerboseOutput):
continue
if isinstance(action, IfAction):
chunk.actions[i] = make_if_expr(action, False)
chunk.actions[i] = make_if_expr(action)
continue
if isinstance(action, AddNumVariableAction):
@ -2445,7 +2438,7 @@ class ByteCodeDecompiler(VerboseOutput):
if isinstance(action, IntermediateIf):
# A partially-converted if from loop detection. Let's hoist it out properly.
chunk.actions[i] = IfStatement(
make_if_expr(action.parent_action, action.negate),
make_if_expr(action.parent_action),
action.true_statements,
action.false_statements,
)
@ -2902,6 +2895,21 @@ class ByteCodeDecompiler(VerboseOutput):
return statements[:-1]
return statements
def __swap_empty_ifs(self, statements: Sequence[Statement]) -> List[Statement]:
# Go through and find continue statements on the last line of a do-while.
def swap_empty_ifs(statement: Statement) -> Optional[Statement]:
if isinstance(statement, IfStatement):
if statement.false_statements and (not statement.true_statements):
# Swap this, invert the conditional
return IfStatement(
statement.cond.invert(),
statement.false_statements,
statement.true_statements,
)
return statement
return self.__walk(statements, swap_empty_ifs)
def __pretty_print(self, statements: Sequence[Statement], prefix: str = "") -> str:
output: List[str] = []
@ -2951,6 +2959,7 @@ class ByteCodeDecompiler(VerboseOutput):
if not changed1 and not changed2:
break
statements = self.__swap_empty_ifs(statements)
# Finally, let's save the code!
self.__statements = statements