dewolf
dewolf copied to clipboard
Revive TypePropagation
Proposal
As of now, the TypePropagation
stage hardly does anything. For example, we do not propagate the type of a pointer to the type it points to.
We should fix this and make this pipeline stage more functional.
Approach
Code from previous approaches:
"""Module implementing horizontal type propagation as a pipeline stage."""
from __future__ import annotations

import logging
from collections import Counter
from typing import Iterator, List, Tuple, Optional

from dewolf.pipeline.stage import PipelineStage
from dewolf.structures.graphs.nxgraph import NetworkXGraph
from dewolf.structures.graphs.basic import BasicEdge, BasicNode
from dewolf.structures.graphs.cfg import ControlFlowGraph
from dewolf.structures.pseudo.expressions import Expression, DataflowObject
from dewolf.structures.pseudo.operations import Condition, UnaryOperation, OperationType
from dewolf.structures.pseudo.instructions import Instruction
from dewolf.structures.pseudo.typing import CustomType, Float, Integer, Pointer, Type, UnknownType
from dewolf.task import DecompilerTask
class TypeNode(BasicNode):
    """Graph node wrapping a collection of expressions."""

    def __init__(self, references: List[Expression]):
        """
        Create a node for the given expressions.

        references -- the expressions held by this node; any iterable is accepted.
        """
        super().__init__(id(self))
        # Materialize the iterable: a generator would be exhausted by the first membership test or
        # iteration and does not support len().
        self.references = list(references)

    def __contains__(self, item: Expression) -> bool:
        """Check whether the given expression is one of the references of this node."""
        return item in self.references

    def __len__(self) -> int:
        """Return the number of expressions held by this node."""
        return len(self.references)
class TypeRelation(BasicEdge):
    """Edge between two TypeNodes annotated with an integer level."""

    def __init__(self, source: TypeNode, sink: TypeNode, level: int):
        """
        Connect source to sink.

        source -- the node the edge starts at
        sink -- the node the edge points to
        level -- the integer level attached to this relation
        """
        super().__init__(source, sink)
        self._level = level

    @property
    def level(self) -> int:
        """Return the level this relation was created with (read-only)."""
        return self._level
class SubExpressionGraph(NetworkXGraph):
    """Graph connecting each instruction and expression to its direct subexpressions."""

    @classmethod
    def from_cfg(cls, cfg: ControlFlowGraph) -> SubExpressionGraph:
        """Build a SubExpressionGraph from all instructions of the given ControlFlowGraph."""
        expression_graph = cls()
        for cfg_instruction in cfg.instructions:
            expression_graph.add_instruction(cfg_instruction)
        return expression_graph

    def add_instruction(self, instruction: Instruction):
        """Add an edge from the instruction, and from each nested expression, to every direct child."""
        # Depth-first traversal using an explicit stack of (parent, child) pairs.
        pending: List[Tuple[DataflowObject, DataflowObject]] = [(instruction, operand) for operand in instruction]
        while pending:
            parent, current = pending.pop()
            self.add_edge(BasicEdge(parent, current))
            pending.extend((current, child) for child in current)
class TypeGraph(NetworkXGraph):
    """Graph of TypeNodes connected by TypeRelations carrying pointer reference levels."""

    # Level offset attached to the edge from a referencing operation to its operand.
    # NOTE(review): with these signs the operand of a dereference ends up one pointer level *below*
    # the operation result (and the operand of an address operation one level above) - confirm that
    # this orientation is the intended one.
    REFERENCE_TYPE = {
        OperationType.dereference: -1,
        OperationType.address: 1
    }

    def propagate_type(self):
        """Assign the common type of the graph to all referenced expressions, adjusted by node level."""
        common_type = self.get_type()
        node_levels = list(self.iterate_node_levels())
        level_types = {level: self._apply_reference_level(common_type, level) for _, level in node_levels}
        for node, level in node_levels:
            logging.debug("type: %s, level: %s - %s", common_type, level, ", ".join(str(x) for x in node.references))
            for expression in node.references:
                expression._type = level_types[level]

    def get_type(self) -> Type:
        """Return the common type of all references, each normalized to reference level zero."""
        return TypePropagation.find_common_type(list(self._collect_types()))

    def get_node_containing(self, expression: Expression) -> Optional[TypeNode]:
        """Return the first node referencing the given expression, or None if there is none."""
        for node in self:
            if expression in node:
                return node
        return None

    def iterate_node_levels(self) -> Iterator[Tuple[TypeNode, int]]:
        """
        Yield (node, level) pairs for all nodes reachable from the root.

        The root has level 0; the level of every other node is the level of its predecessor plus the
        level of the connecting TypeRelation. The graph is required to have exactly one root.
        """
        roots = self.get_roots()
        assert len(roots) == 1
        todo = [(0, roots[0])]
        while todo:
            level, node = todo.pop()
            yield node, level
            for edge in self.get_out_edges(node):
                # TypeRelation exposes its offset through the 'level' property.
                todo.append((level + edge.level, edge.sink))

    def _collect_types(self) -> Iterator[Type]:
        """Yield the type of every reference, shifted by the negated level of its node."""
        for node, level in self.iterate_node_levels():
            for reference in node.references:
                yield self._apply_reference_level(reference.type, -level)

    def _apply_reference_level(self, type: Type, level: int) -> Type:
        """
        Return the given type with the given amount of pointer levels added or removed.

        type -- the type to adjust; returned unchanged (not copied) for level 0
        level -- positive values wrap the type in that many Pointers, negative values dereference it
        """
        if level == 0:
            return type
        adjusted = type.copy()
        if level < 0:
            for _ in range(-level):
                assert isinstance(adjusted, Pointer), "Can only dereference pointer types"
                adjusted = adjusted.type
        else:
            for _ in range(level):
                # Wrap the type itself; wrapping adjusted.type would not add a pointer level.
                adjusted = Pointer(adjusted)
        return adjusted

    @classmethod
    def from_expression_graph(cls, graph: SubExpressionGraph) -> TypeGraph:
        """
        Build a TypeGraph from the given SubExpressionGraph.

        The edges leaving dereference and address operations are removed from the passed graph (it is
        modified in place). Each remaining component becomes one TypeNode, and the removed edges are
        re-added as TypeRelations between those nodes.
        """
        reference_edges = list(cls.find_reference_edges(graph))
        graph.remove_edges_from(reference_edges)
        typegraph = cls()
        for component in graph.iter_components():
            # Pass a list: the node is iterated and queried for membership multiple times.
            typegraph.add_node(TypeNode([node for node in component if not isinstance(node, Instruction)]))
        for edge in reference_edges:
            source = typegraph.get_node_containing(edge.source)
            sink = typegraph.get_node_containing(edge.sink)
            reference_level = cls.REFERENCE_TYPE[edge.source.operation]
            inverted_edge = typegraph.get_edge(sink, source)
            # Keep only one relation per node pair, replacing an opposite relation with a lower level.
            if not inverted_edge or inverted_edge.level < reference_level:
                typegraph.add_edge(TypeRelation(source, sink, reference_level))
                if inverted_edge:
                    typegraph.remove_edge(inverted_edge)
        return typegraph

    @staticmethod
    def find_reference_edges(graph: SubExpressionGraph) -> Iterator[BasicEdge]:
        """Lazily yield all edges whose source is a dereference or address operation."""
        return (
            edge
            for edge in graph.edges
            if isinstance(edge.source, UnaryOperation) and edge.source.operation in TypeGraph.REFERENCE_TYPE
        )
class TypePropagation(PipelineStage):
    """Implements type propagation based on a set of heuristics."""

    name = "type-propagation"

    def run(self, task: DecompilerTask):
        """
        Run type propagation on the given task object.

        We assume there are two types of propagation: Assignment (horizontal) and Operation (vertical).
        Operation-Level propagation is directly implemented into pseudo.Operands through a recursive lookup.
        """
        subexpression_graph = SubExpressionGraph.from_cfg(task.graph)
        for type_graph in self.split_graph(subexpression_graph):
            logging.debug("TypeGraph: %s", len(type_graph))
            # TypeGraph.propagate_type accepts no arguments besides self.
            type_graph.propagate_type()

    def split_graph(self, graph: SubExpressionGraph) -> Iterator[TypeGraph]:
        """Cut all edges blocking type propagation and yield one TypeGraph per remaining component."""
        self.cut_non_type_edges(graph)
        # Materialize the components before iterating - presumably to be safe against modifications
        # made while the TypeGraphs are built.
        for component in list(graph.iter_components()):
            logging.debug(", ".join(str(x) for x in component))
            subgraph = graph.subgraph(component)
            yield TypeGraph.from_expression_graph(subgraph)

    def cut_non_type_edges(self, graph: SubExpressionGraph):
        """Remove every edge whose source blocks type propagation from the given graph (in place)."""
        # Iterate over a copy, because edges are removed during the loop.
        for edge in list(graph.edges):
            if self.blocks_type_propagation(edge.source):
                graph.remove_edge(edge)

    @staticmethod
    def blocks_type_propagation(expression: Expression) -> bool:
        """Check whether the expression is a Condition or a cast, across which no type is propagated."""
        if isinstance(expression, Condition):
            return True
        return isinstance(expression, UnaryOperation) and expression.operation == OperationType.cast

    @staticmethod
    def find_common_type(types) -> Type:
        """
        Pick a single type representing the given types.

        Candidates are ordered by descending frequency (ties by descending string representation) and
        UnknownType is ignored. The first non-primitive candidate wins; if there is none, the first
        candidate is returned. UnknownType is returned if no candidate remains.
        """
        histogram = Counter(types)
        most_common_types = sorted(histogram.keys(), reverse=True, key=lambda x: (histogram[x], str(x)))
        if UnknownType() in most_common_types:
            most_common_types.remove(UnknownType())
        for filtered_type in filter(TypePropagation.is_non_primitive_type, most_common_types):
            return filtered_type
        if most_common_types:
            return most_common_types[0]
        return UnknownType()

    @staticmethod
    def is_non_primitive_type(type: Type) -> bool:
        """Check if the given type is non-primitive; non-float Integer types and void pointers are primitive."""
        if isinstance(type, Integer) and not isinstance(type, Float):
            return False
        if isinstance(type, Pointer) and type.type == CustomType.void():
            return False
        return True