1
0
mirror of synced 2025-02-20 20:50:59 +01:00

Implement recursive if detection.

This commit is contained in:
Jennifer Taylor 2021-04-24 18:02:26 +00:00
parent 6e34d2647e
commit 6aa04b0c1b

View File

@ -101,6 +101,13 @@ class GotoStatement(ConvertedAction):
return f"goto label_{self.location};"
class NullReturnStatement(ConvertedAction):
# A statement which directs the control flow to the end of the code, but
# does not pop the stack to return
def __repr__(self) -> str:
return "return;"
class IntermediateIfStatement(ConvertedAction):
def __init__(self, parent_action: IfAction, true_actions: Sequence[ArbitraryOpcode], false_actions: Sequence[ArbitraryOpcode], negate: bool) -> None:
self.parent_action = parent_action
@ -159,7 +166,7 @@ class ByteCodeChunk:
)
ArbitraryCodeChunk = Union[ByteCodeChunk, "Loop"]
ArbitraryCodeChunk = Union[ByteCodeChunk, "Loop", "IfBody"]
class Loop:
@ -197,6 +204,42 @@ class Loop:
)
class IfBody:
def __init__(self, id: int, true_chunks: Sequence[ArbitraryCodeChunk], false_chunks: Sequence[ArbitraryCodeChunk], next_chunk: Optional[int], previous_chunk: int, negate: bool) -> None:
# The ID is usually the chunk that other chunks point into.
self.id = id
# If bodies are a bit special compared to Loops, we know the previous and next chunks
# for all of them.
self.previous_chunks: List[int] = [previous_chunk]
self.next_chunks: List[int] = [next_chunk] if next_chunk is not None else []
self.true_chunks = list(true_chunks)
self.false_chunks = list(false_chunks)
self.negate = negate
def __repr__(self) -> str:
true_entries: List[str] = []
for chunk in self.true_chunks:
true_entries.extend([f" {s}" for s in str(chunk).split(os.linesep)])
false_entries: List[str] = []
for chunk in self.false_chunks:
false_entries.extend([f" {s}" for s in str(chunk).split(os.linesep)])
return (
f"IfBody({os.linesep}" +
f" Negated: {self.negate}{os.linesep}" +
f" ID: {self.id}{os.linesep}" +
(f" Previous Chunks: {', '.join(str(n) for n in self.previous_chunks)}{os.linesep}" if self.previous_chunks else f" Start Chunk{os.linesep}") +
f" True Chunks:{os.linesep}" +
f"{os.linesep.join(true_entries)}{os.linesep}" +
f" False Chunks:{os.linesep}" +
f"{os.linesep.join(false_entries)}{os.linesep}" +
(f" Next Chunks: {', '.join(str(n) for n in self.next_chunks)}{os.linesep}" if self.next_chunks else f" End Chunk{os.linesep}") +
")"
)
class BitVector:
def __init__(self, length: int, init: bool = False) -> None:
self.__bits: Dict[int, bool] = {i: init for i in range(length)}
@ -466,7 +509,7 @@ class ByteCodeDecompiler(VerboseOutput):
return (sorted(chunks, key=lambda c: c.id), offset_to_id)
def __get_entry_block(self, chunks: Sequence[ByteCodeChunk]) -> int:
def __get_entry_block(self, chunks: Sequence[Union[ByteCodeChunk, Loop]]) -> int:
start_id: int = -1
for chunk in chunks:
if not chunk.previous_chunks:
@ -548,10 +591,6 @@ class ByteCodeDecompiler(VerboseOutput):
last_action = chunk.actions[-1]
if isinstance(last_action, AP2Action):
if last_action.opcode in [AP2Action.THROW, AP2Action.RETURN]:
# Ignore these for now, we'll fix these up in a later stage.
continue
if last_action.opcode == AP2Action.JUMP:
# This is either an unconditional break/continue or an
# internal jump.
@ -568,8 +607,12 @@ class ByteCodeDecompiler(VerboseOutput):
chunk.actions[-1] = ContinueStatement()
chunk.next_chunks = []
elif next_chunk not in internal_jump_points:
self.vprint("Converting jump to external point into goto statement.")
chunk.actions[-1] = GotoStatement(next_chunk)
if next_chunk == offset_map[self.bytecode.end_offset]:
self.vprint("Converting jump to external point into return statement.")
chunk.actions[-1] = NullReturnStatement()
else:
self.vprint("Converting jump to external point into goto statement.")
chunk.actions[-1] = GotoStatement(next_chunk)
chunk.next_chunks = []
continue
@ -592,8 +635,12 @@ class ByteCodeDecompiler(VerboseOutput):
true_action = ContinueStatement()
chunk.next_chunks = [n for n in chunk.next_chunks if n != true_jump_point]
elif true_jump_point not in internal_jump_points:
self.vprint("Converting jump if true to external point into goto statement.")
true_action = GotoStatement(true_jump_point)
if true_jump_point == offset_map[self.bytecode.end_offset]:
self.vprint("Converting jump if true to external point into return statement.")
true_action = NullReturnStatement()
else:
self.vprint("Converting jump if true to external point into goto statement.")
true_action = GotoStatement(true_jump_point)
chunk.next_chunks = [n for n in chunk.next_chunks if n != true_jump_point]
false_action: Optional[ConvertedAction] = None
@ -606,8 +653,12 @@ class ByteCodeDecompiler(VerboseOutput):
false_action = ContinueStatement()
chunk.next_chunks = [n for n in chunk.next_chunks if n != false_jump_point]
elif false_jump_point not in internal_jump_points:
self.vprint("Converting jump if false to external point into goto statement.")
false_action = GotoStatement(false_jump_point)
if false_jump_point == offset_map[self.bytecode.end_offset]:
self.vprint("Converting jump if false to external point into return statement.")
false_action = NullReturnStatement()
else:
self.vprint("Converting jump if false to external point into goto statement.")
false_action = GotoStatement(false_jump_point)
chunk.next_chunks = [n for n in chunk.next_chunks if n != false_jump_point]
if true_action is None and false_action is not None:
@ -698,7 +749,9 @@ class ByteCodeDecompiler(VerboseOutput):
new_loop = Loop(header, [chunks_by_id[i] for i in blocks])
# Eliminate jumps that are to the beginning/end of the loop to
# make if statement detection later on easier.
# make if statement detection later on easier. This also breaks
# the graph at any spot where we successfully converted a jump
# to a break/continue/goto.
new_loop = self.__analyze_loop_jumps(new_loop, offset_map)
chunks_by_id[header] = new_loop
@ -731,8 +784,211 @@ class ByteCodeDecompiler(VerboseOutput):
return [chunks_by_id[i] for i in chunks_by_id]
def __separate_ifs(self, chunks: Sequence[Union[ByteCodeChunk, Loop]], offset_map: Dict[int, int]) -> List[ArbitraryCodeChunk]:
return [c for c in chunks]
def __break_graph(self, chunks: Sequence[Union[ByteCodeChunk, Loop]], offset_map: Dict[int, int]) -> None:
for chunk in chunks:
if chunk.id == offset_map[self.bytecode.end_offset]:
# Don't examine the sentinel we keep around as a jump point for returns.
continue
if isinstance(chunk, Loop):
self.vprint(f"Entering into loop {chunk.id} to break graph...")
# At this point, we know chunk.chunks is a Union[ByteCodeChunk, Loop] because we haven't run
# any if detection yet.
self.__break_graph(cast(List[Union[ByteCodeChunk, Loop]], chunk.chunks), offset_map)
else:
# Examine the last instruction.
last_action = chunk.actions[-1]
if isinstance(last_action, AP2Action):
if last_action.opcode in [AP2Action.THROW, AP2Action.RETURN]:
# The last action already dictates what we should do here. Break
# the chain at this point.
self.vprint(f"Breaking chain on {chunk.id} because it is a {last_action}.")
chunk.next_chunks = []
elif len(chunk.next_chunks) == 1 and chunk.next_chunks[0] == offset_map[self.bytecode.end_offset]:
# The jump point for this is the end of the function. If it is a jump,
# then we should replace it with a return. If it is not a jump, we should
# add a return.
if last_action.opcode == AP2Action.JUMP:
self.vprint(f"Converting jump to end of code in {chunk.id} into a null return.")
chunk.actions[-1] = NullReturnStatement()
else:
self.vprint(f"Converting fall-through to end of code in {chunk.id} into a null return.")
chunk.actions.append(NullReturnStatement())
chunk.next_chunks = []
def __find_shallowest_successor(self, start_chunk: int, chunks_by_id: Dict[int, ArbitraryCodeChunk]) -> Optional[int]:
if len(chunks_by_id[start_chunk].next_chunks) != 2:
# We don't care about this, the successor is the next chunk!
raise Exception("Logic error!")
left, right = chunks_by_id[start_chunk].next_chunks
visited: Set[int] = set()
# First, let's find all the successors to the left side.
candidates: List[int] = [left]
while candidates:
for candidate in candidates:
visited.add(candidate)
new_candidates = []
for candidate in candidates:
# We can avoid re-traversing what we've already traversed, as we only want to color
# in the part of the tree that we're interested in.
new_candidates.extend([c for c in chunks_by_id[candidate].next_chunks if c not in visited])
candidates = new_candidates
# Now, lets do the same with the right, and the first one we encounter that's visited is our guy.
candidates = [right]
while candidates:
for candidate in candidates:
if candidate in visited:
return candidate
new_candidates = []
for candidate in candidates:
# We can't take the same shortcut here as above, as we are trying to ask the question
# of what's the shallowest successor, not color them in.
new_candidates.extend(chunks_by_id[candidate].next_chunks)
candidates = new_candidates
# If we didn't find a successor, that means one of the control paths leads to end of execution.
return None
def __gather_chunks(self, start_chunk: int, end_chunk: Optional[int], chunks_by_id: Dict[int, ArbitraryCodeChunk]) -> List[ArbitraryCodeChunk]:
visited: Set[int] = set()
chunks: List[ArbitraryCodeChunk] = []
candidates: List[int] = [start_chunk]
while candidates:
first_candidate = candidates.pop()
if first_candidate in visited:
# We already visited this node.
continue
if end_chunk is None or first_candidate != end_chunk:
chunks.append(chunks_by_id[first_candidate])
visited.add(first_candidate)
candidates.extend(chunks_by_id[first_candidate].next_chunks)
# The chunk list is all chunks that belong in this sequence. Now, kill any pointers to the end chunk.
if end_chunk is not None:
for chunk in chunks:
chunk.next_chunks = [n for n in chunk.next_chunks if n != end_chunk]
return chunks
def __separate_ifs(self, start_id: int, chunks: Sequence[ArbitraryCodeChunk], offset_map: Dict[int, int]) -> List[ArbitraryCodeChunk]:
chunks_by_id: Dict[int, ArbitraryCodeChunk] = {chunk.id: chunk for chunk in chunks}
cur_id = start_id
self.vprint(f"Separating if statements out of graph starting at {start_id}")
while chunks_by_id[cur_id].next_chunks:
cur_chunk = chunks_by_id[cur_id]
if isinstance(cur_chunk, Loop):
self.vprint(f"Examining loop {cur_chunk.id} body for if statements...")
cur_chunk.chunks = self.__separate_ifs(cur_chunk.id, cur_chunk.chunks, offset_map)
self.vprint(f"Finished examining loop {cur_chunk.id} body for if statements...")
if len(chunks_by_id[cur_id].next_chunks) == 1:
# This is just a goto/chunk, move on to the next one.
cur_id = chunks_by_id[cur_id].next_chunks[0]
continue
if not isinstance(cur_chunk, ByteCodeChunk):
# We should only be looking at bytecode chunks at this point, all other
# types should have a single next chunk.
raise Exception("Logic error!")
last_action = cur_chunk.actions[-1]
if not isinstance(last_action, IfAction):
# This needs, again, to be an if statement.
raise Exception("Logic error!")
if len(chunks_by_id[cur_id].next_chunks) != 2:
# This needs to be an if statement.
raise Exception("Logic error!")
# This should be an if statement. Figure out if it is an if-else or an
# if, and if both branches return.
if_end = self.__find_shallowest_successor(cur_id, chunks_by_id)
# This is a normal if or if-else, let's compile the true and false
# statements.
true_jump_point = offset_map[last_action.jump_if_true_offset]
false_jump_points = [n for n in cur_chunk.next_chunks if n != true_jump_point]
if len(false_jump_points) != 1:
self.vprint(chunks)
self.vprint(f"Jump has {last_action.jump_if_true_offset} which maps to {offset_map[last_action.jump_if_true_offset]}")
self.vprint(f"Jump branches to {cur_chunk.next_chunks}")
raise Exception("Logic error!")
false_jump_point = false_jump_points[0]
if true_jump_point == false_jump_point:
# This should never happen.
raise Exception("Logic error!")
self.vprint(f"Chunk ID {cur_id} is an if statement with true node {true_jump_point} and false node {false_jump_point} and ending at {if_end}")
true_chunks: List[ArbitraryCodeChunk] = []
if true_jump_point != if_end:
self.vprint(f"Gathering true path starting with {true_jump_point} and ending with {if_end} and detecting if statements within it as well.")
# First, grab all the chunks in this if statement body.
true_chunks = self.__gather_chunks(true_jump_point, if_end, chunks_by_id)
self.vprint(true_chunks)
# Delete these chunks from our chunk mapping since we're putting them in an if body.
for chunk in true_chunks:
del chunks_by_id[chunk.id]
# Now, recursively attempt to detect if statements within this chunk as well.
true_chunks = self.__separate_ifs(true_jump_point, true_chunks, offset_map)
false_chunks: List[ArbitraryCodeChunk] = []
if false_jump_point != if_end:
self.vprint(f"Gathering false path starting with {false_jump_point} and ending with {if_end} and detecting if statements within it as well.")
# First, grab all the chunks in this if statement body.
false_chunks = self.__gather_chunks(false_jump_point, if_end, chunks_by_id)
self.vprint(false_chunks)
# Delete these chunks from our chunk mapping since we're putting them in an if body.
for chunk in false_chunks:
del chunks_by_id[chunk.id]
# Now, recursively attempt to detect if statements within this chunk as well.
false_chunks = self.__separate_ifs(false_jump_point, false_chunks, offset_map)
if false_chunks and (not true_chunks):
negate = True
true_chunks = false_chunks
false_chunks = []
if_id = false_jump_point
else:
negate = False
if_id = true_jump_point
if (not true_chunks) and (not false_chunks):
# We should have at least one!
raise Exception("Logic error!")
# Add a new if body that this current chunk points to. At this point, chunks_by_id contains
# none of the chunks in the true or false bodies of the if, so we add it back to the graph
# in the form of an IfBody.
self.vprint(f"Created new IfBody for chunk {cur_id} to point at, ending at {if_id}")
chunks_by_id[if_id] = IfBody(if_id, true_chunks, false_chunks, if_end, cur_id, negate)
chunks_by_id[cur_id].next_chunks = [if_id]
if if_end is not None:
# Skip over the if, we already analyzed it.
cur_id = if_end
else:
# This if statement encompases all the rest of the statements, we're done.
break
self.vprint(f"Finished separating if statements out of graph starting at {start_id}")
return [c for _, c in chunks_by_id.items()]
def __decompile(self) -> str:
# First, we need to construct a control flow graph.
@ -747,9 +1003,15 @@ class ByteCodeDecompiler(VerboseOutput):
self.vprint("Identifying and separating loops...")
chunks_and_loops = self.__separate_loops(chunks, dominators, offset_map)
# Now, break the graph anywhere where we have control
# flow that ends the execution (return, throw, goto end).
self.vprint("Breaking control flow graph on non-returnable statements...")
self.__break_graph(chunks_and_loops, offset_map)
# Now, identify any remaining control flow logic.
self.vprint("Identifying and separating ifs...")
chunks_loops_and_ifs = self.__separate_ifs(chunks_and_loops, offset_map)
start_id = self.__get_entry_block(chunks_and_loops)
chunks_loops_and_ifs = self.__separate_ifs(start_id, chunks_and_loops, offset_map)
# At this point, we *should* have a directed graph where there are no
# backwards refs and every fork has been identified as an if. This means