Support custom extraction with cost functions
Currently, the extraction in egglog is rather limited. It does a tree-based extraction (meaning that if a node shows up twice, it will be counted twice) and requires static costs per function.
The first issue, the type of extractor, could be alleviated by using some of the extractors from extraction-gym. The second, supporting custom costs per item, could be addressed upstream in egglog (https://github.com/egraphs-good/egglog/issues/294) but is not on the immediate roadmap.
Either way, it would also be nice to support fully custom extraction: being able to iterate through the e-graph and do whatever you want with it.
Currently, it's "possible" by serializing the e-graph to JSON. But this is not ideal: you have to dig through JSON with arbitrary keys that might not map to your Python function names, and none of it is type safe. Yeah, it's just a real pain!
So I think it would make sense to add an interface that allows:
- Using custom extractors from extraction-gym without leaving the Python bindings, by building them into egglog.
- Setting custom costs per record before extracting.
- Querying costs and writing your own extractor in Python.
- Doing all of this while keeping static type safety.
- Reducing serialization and wrapping/unwrapping overhead as much as possible.
Possible Design
Here is a possible API design for the extractors:
"""
Examples using egglog.
"""
from __future__ import annotations
from typing import Literal, Protocol, TypeVar
from egglog import Expr
EXPR = TypeVar("EXPR", bound=Expr)
class EGraph:
    def serialize(self) -> SerializedEGraph:
        """
        Serializes the e-graph into a format that can be passed to an extractor.
        """
        raise NotImplementedError

    def extract(
        self, x: EXPR, /, extractor: Extractor, include_cost: Literal["tree", "dag"] | None = None
    ) -> EXPR | tuple[EXPR, int]:
        """
        Extracts the given expression using the given extractor, optionally including the cost of the extraction.
        """
        extract_result = extractor.extract(self.serialize(), [x])
        res = extract_result.chosen(x)
        if include_cost is None:
            return res
        cost = extract_result.tree_cost([x]) if include_cost == "tree" else extract_result.dag_cost([x])
        return res, cost

class SerializedEGraph:
    def equivalent(self, x: EXPR, /) -> list[EXPR]:
        """
        Returns all equivalent expressions, i.e. all expressions in the same e-class.
        """
        raise NotImplementedError

    def self_cost(self, x: Expr, /) -> int:
        """
        Returns the cost of just that function, not including its children.
        """
        raise NotImplementedError

    def set_self_cost(self, x: Expr, cost: int, /) -> None:
        """
        Sets the cost of just that function, not including its children.
        """
        raise NotImplementedError

class Extractor(Protocol):
    def extract(self, egraph: SerializedEGraph, roots: list[Expr]) -> ExtractionResult: ...

class ExtractionResult:
    """
    An extraction result is a mapping from an e-class to chosen nodes, paired with the original extracted e-graph.

    Based on https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/mod.rs but with e-classes removed,
    since they are not present in the Python bindings. Instead, you pass in any member of an e-class and get out
    representative nodes.
    """

    egraph: SerializedEGraph

    def __init__(self, egraph: SerializedEGraph) -> None: ...

    def choose(self, class_: EXPR, chosen_node: EXPR, /) -> None:
        """
        Chooses `chosen_node` as the representative of the e-class that `class_` belongs to.
        """
        raise NotImplementedError

    def chosen(self, x: EXPR, /) -> EXPR:
        """
        Given an expr that is in the e-graph, recursively returns the chosen expression in each e-class.
        """
        raise NotImplementedError

    def check(self) -> None:
        """
        Checks the extraction result for consistency.
        """
        raise NotImplementedError

    def find_cycles(self, roots: list[Expr]) -> list[Expr]:
        """
        Returns all classes that are in a cycle reachable from the roots.
        """
        raise NotImplementedError

    def tree_cost(self, roots: list[Expr]) -> int:
        """
        Returns the "tree cost" (counting duplicate nodes twice) of the trees rooted at the given roots.
        """
        raise NotImplementedError

    def dag_cost(self, roots: list[Expr]) -> int:
        """
        Returns the "dag cost" (counting duplicate nodes once) of the dag rooted at the given roots.
        """
        raise NotImplementedError
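To make the tree vs dag distinction concrete, here is a minimal, egglog-free sketch. It assumes a hypothetical `Node` dataclass standing for an already-chosen node with its chosen children, plus a plain dict of per-op self costs. It is not an implementation of `ExtractionResult`, just the two cost definitions applied to f(a, a).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical chosen node: an operator name plus its (already chosen) children."""
    op: str
    children: tuple[Node, ...] = ()


def tree_cost(root: Node, self_cost: dict[str, int]) -> int:
    """Sum self costs over the tree, so a shared child is counted once per use."""
    return self_cost[root.op] + sum(tree_cost(c, self_cost) for c in root.children)


def dag_cost(root: Node, self_cost: dict[str, int]) -> int:
    """Sum self costs over the distinct nodes reachable from root, so sharing is counted once."""
    seen: set[Node] = set()
    stack = [root]
    total = 0
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        total += self_cost[node.op]
        stack.extend(node.children)
    return total


costs = {"f": 1, "a": 10}  # made-up self costs
a = Node("a")
term = Node("f", (a, a))  # f(a, a): the same child used twice
print(tree_cost(term, costs))  # 21
print(dag_cost(term, costs))   # 11
```

Structural equality on the frozen dataclass plays the role of hash-consing here: two equal children are treated as one shared node by `dag_cost`.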
Using this interface, you could use the default costs from egglog with a custom extractor, as shown in the helper extract method.
However, you could also set custom costs on the serialized e-graph before extracting, overriding any from egglog:
serialized = egraph.serialize()
serialized.set_self_cost(x, 10)
serialized.set_self_cost(y, 100)
extractor.extract(serialized, [x]).chosen(x)
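A toy model of what overriding costs buys you, assuming nothing about the real SerializedEGraph: made-up default costs (standing in for egglog's) sit underneath an overrides dict via collections.ChainMap, and two equivalent terms are compared by the summed cost of the ops they use. Raising the cost of pow flips the choice.

```python
from __future__ import annotations

from collections import ChainMap

# Hypothetical default costs, standing in for the ones egglog would report.
default_costs = {"pow": 1, "mul": 1, "x": 0, "lit": 0}

# Overrides take precedence over defaults without mutating them.
overrides: dict[str, int] = {}
costs = ChainMap(overrides, default_costs)

# Two equivalent terms in the same e-class, written as the ops they use (tree cost).
candidates = {
    "x ** 4": ["pow", "x", "lit"],
    "x * x * x * x": ["mul", "mul", "mul", "x", "x", "x", "x"],
}


def cheapest(candidates: dict[str, list[str]], costs) -> str:
    """Pick the candidate whose ops have the lowest summed cost."""
    return min(candidates, key=lambda name: sum(costs[op] for op in candidates[name]))


print(cheapest(candidates, costs))  # x ** 4
overrides["pow"] = 100  # analogous to serialized.set_self_cost(...)
print(cheapest(candidates, costs))  # x * x * x * x
```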
How would you be able to traverse an expression at runtime and see its children? I think with three small additions, we could do this with our current API:
- Primitives: Allow any primitive to be converted to a Python object, e.g. int(i64(0)).
- User Defined Constants: Allow bool(eq(x).to(y)), which will resolve to whether the two sides are exactly syntactically equal.
- User Defined Functions: Support a new way to get the args in a type safe manner based on a function, e.g. fn_matches(x, f) would return a boolean saying whether the function matches, and then fn_args(x, f) would return a list of the args. They could be typed like this (combining both into a single call):
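from __future__ import annotations
from typing import Callable, Generic, TypeVar
from typing_extensions import TypeVarTuple, Unpack  # available in `typing` itself on Python 3.11+
from egglog import Expr

EXPR = TypeVar("EXPR", bound=Expr)
EXPRS = TypeVarTuple("EXPRS")  # must be defined before it is used below

class _FnMatchesBuilder(Generic[EXPR]):
    def fn(self, y: Callable[[Unpack[EXPRS]], EXPR], /) -> tuple[Unpack[EXPRS]] | None:
        """
        Returns the tuple of args if the expression was built with `y`, otherwise None.
        """
        raise NotImplementedError
def matches(x: EXPR) -> _FnMatchesBuilder[EXPR]:
    raise NotImplementedError

# Compare against None: a function with no args would return an empty (falsy) tuple.
if (args := matches(x).fn(y)) is not None:
    a, b, c = args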
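The typing above says nothing about runtime behavior, so here is a small runtime sketch of the intended semantics over a hypothetical Term class (not egglog's Expr): fn returns the args tuple when the head function matches and None otherwise. The nullary case shows why the check should be `is not None` rather than truthiness.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Term:
    """Hypothetical stand-in for an egglog expression: a head function plus its args."""
    fn: Callable[..., Any]
    args: tuple[Any, ...]


def call(fn: Callable[..., Any], *args: Any) -> Term:
    # Record which function built the term instead of evaluating it.
    return Term(fn, args)


class _FnMatches:
    def __init__(self, x: Term) -> None:
        self._x = x

    def fn(self, f: Callable[..., Any]) -> tuple[Any, ...] | None:
        """Return the args if `x` was built with `f`, else None."""
        return self._x.args if self._x.fn is f else None


def matches(x: Term) -> _FnMatches:
    return _FnMatches(x)


# Hypothetical constructors; only their identity matters here.
def pow_(base, exp): ...
def mul(a, b): ...
def zero(): ...


e = call(pow_, "x", 4)
if (args := matches(e).fn(pow_)) is not None:
    base, exp = args
    print(base, exp)  # x 4
print(matches(e).fn(mul))  # None
print(matches(call(zero)).fn(zero))  # (): falsy, hence the `is not None` check
```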
Alternatively, how would you create a custom extractor? We would want to add one more way to traverse expressions, this time not caring about which particular expression they are, just their args and a way to re-create them with different args. Using that, we could write a simple tree-based extractor:
from __future__ import annotations

from typing import Callable, TypeVar

from egglog import Expr

EXPR = TypeVar("EXPR", bound=Expr)
# Assumed alias for illustration: a callable that rebuilds an EXPR from its args.
ReturnsExpr = Callable[..., EXPR]


def decompose(x: EXPR, /) -> tuple[ReturnsExpr[EXPR], list[Expr]]:
    """
    Decomposes an expression into a callable that will reconstruct it based on its args.

    For all expressions, constants or functions, this should hold:

    >>> fn, args = decompose(x)
    >>> assert fn(*args) == x

    This can be used to change the args of a function and then reconstruct it.
    If you are looking for a type safe way to deal with a particular constructor, you can use either `eq(x).to(y)` for
    constants or `matches(x).fn(y)` for functions to get their args in a type safe manner.
    """
    raise NotImplementedError


def tree_based_extractor(serialized: SerializedEGraph, expr: EXPR, /) -> tuple[EXPR, int]:
    """
    Returns the lowest cost equivalent expression and the cost of that expression, based on a tree based extraction.

    Note: this naive recursion does not handle cycles; on a cyclic e-graph it recurses without bound.
    """
    min_expr, min_cost = expr_cost(serialized, expr)
    for equivalent_expr in serialized.equivalent(expr):
        new_expr, new_cost = expr_cost(serialized, equivalent_expr)
        if new_cost < min_cost:
            min_cost = new_cost
            min_expr = new_expr
    return min_expr, min_cost


def expr_cost(serialized: SerializedEGraph, expr: EXPR, /) -> tuple[EXPR, int]:
    """
    Returns the cheapest version of the given expression (keeping its top-level constructor) and its cost.
    """
    cost = serialized.self_cost(expr)
    constructor, children = decompose(expr)
    best_children = []
    for child in children:
        # Pick the cheapest equivalent of each child independently (tree cost).
        best_child, child_cost = tree_based_extractor(serialized, child)
        cost += child_cost
        best_children.append(best_child)
    return constructor(*best_children), cost
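The decompose invariant is straightforward for tree-shaped terms. This sketch uses a hypothetical Term dataclass and functools.partial to capture the head operator; transform is an illustrative helper (not a proposed API) that rebuilds a term bottom-up, which is the kind of "change the args and reconstruct" use the docstring mentions.

```python
from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Callable


@dataclass(frozen=True)
class Term:
    """Hypothetical expression node: an operator name and child terms."""
    op: str
    args: tuple[Term, ...] = ()


def make(op: str, *args: Term) -> Term:
    return Term(op, args)


def decompose(x: Term) -> tuple[Callable[..., Term], list[Term]]:
    """Split x into a constructor and its args such that constructor(*args) == x."""
    return partial(make, x.op), list(x.args)


def transform(x: Term, f: Callable[[Term], Term]) -> Term:
    """Rebuild x bottom-up, applying f to every node after its children are rebuilt."""
    constructor, children = decompose(x)
    return f(constructor(*(transform(c, f) for c in children)))


x = make("add", make("var_a"), make("mul", make("var_a"), make("var_b")))
fn, args = decompose(x)
assert fn(*args) == x  # the invariant from the docstring

# Rename every var_a to var_c while keeping the rest of the structure.
renamed = transform(x, lambda t: make("var_c") if t.op == "var_a" else t)
print(renamed)
```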
I think we would also need a way to get all parent nodes of a node in the serialized format, for custom cost traversal. For example, you might record the length of a vec and then want to look that up when computing costs.
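Parent lookups can be answered by inverting the child edges once. The sketch below assumes a JSON-like dict loosely modeled on the serialized e-graph (a nodes map whose entries have op, children given as node ids, and eclass); treat that layout and the sample data as assumptions, not the exact format. length_of is a hypothetical helper showing the vec-length lookup described above.

```python
from __future__ import annotations

from collections import defaultdict

# Made-up data in a shape similar to the serialized e-graph (an assumption, not the exact format).
serialized = {
    "nodes": {
        "n0": {"op": "vec-of", "children": ["n1", "n2"], "eclass": "c0"},
        "n1": {"op": "1", "children": [], "eclass": "c1"},
        "n2": {"op": "2", "children": [], "eclass": "c2"},
        "n3": {"op": "vec-length", "children": ["n0"], "eclass": "c3"},
        "n4": {"op": "3", "children": [], "eclass": "c3"},
    }
}


def parents_by_eclass(data: dict) -> dict[str, list[str]]:
    """Map each e-class id to the ids of the nodes that use it as a child."""
    node_class = {nid: node["eclass"] for nid, node in data["nodes"].items()}
    parents: dict[str, list[str]] = defaultdict(list)
    for nid, node in data["nodes"].items():
        for child in node["children"]:
            parents[node_class[child]].append(nid)
    return dict(parents)


def length_of(data: dict, vec_class: str) -> int | None:
    """Hypothetical lookup: find a vec-length parent, then an integer literal in its e-class."""
    for pid in parents_by_eclass(data).get(vec_class, []):
        if data["nodes"][pid]["op"] == "vec-length":
            length_class = data["nodes"][pid]["eclass"]
            for node in data["nodes"].values():
                if node["eclass"] == length_class and node["op"].isdigit():
                    return int(node["op"])
    return None


parents = parents_by_eclass(serialized)
print(parents["c0"])  # ['n3']
print(length_of(serialized, "c0"))  # 3
```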
For your reference, I wrote a PoC of custom extraction with a custom cost model at https://github.com/sklam/prototypes_2025/commit/983b81d6e7f69a4444749039f605c670c96bd505. It is heavily inspired by Tensat (https://github.com/uwplse/tensat/blob/master/src/optimize.rs#L57C12-L57C25).
The example expands x ** 4 to x * x * x * x, so an AST-size cost will not work (it would prefer the smaller Pow form). The custom cost model penalizes Pow heavily, so it selects the Mul variant.
> How would you be able to traverse an expression at runtime and see its children?
This is currently hard to do, so I omitted it from my PoC. I would want to know that it is a Pow(x, 4) and compute the cost knowing the 4.
In the short term, is there a way to map a node in the serialized JSON back to the egglog-python Expr object?
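A sketch of the kind of cost function described here, one that peeks at a literal child. Terms are hypothetical nested tuples, and the rule that a pow with literal exponent n costs 10 * n is an arbitrary assumption chosen to make Pow lose to the Mul chain.

```python
from typing import Union

# Hypothetical terms: a variable name, an int literal, or (op, *args).
Term = Union[str, int, tuple]


def cost(t: Term) -> int:
    """Toy cost model where pow's cost depends on the literal exponent it is applied to."""
    if isinstance(t, (str, int)):
        return 0  # variables and literals are free in this toy model
    op, *args = t
    child_cost = sum(cost(a) for a in args)
    if op == "pow" and isinstance(args[1], int):
        # Peek at the literal exponent: assume a pow with known exponent n costs 10 * n.
        return 10 * args[1] + child_cost
    return 1 + child_cost  # every other op costs 1 plus its children


pow_form = ("pow", "x", 4)
mul_form = ("mul", ("mul", "x", "x"), ("mul", "x", "x"))
print(cost(pow_form), cost(mul_form))  # 40 3
```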
> Alternatively, how would you create a custom extractor?
I'm very interested in decompose() and the constructor it returns. Our workflow is compiler IR -> egglog -> compiler IR, so we need the extracted result to be translated back into IR nodes.
Is using an extractor such as an ILP extractor currently supported in upstream egglog? Or would support be needed both there and in the Python bindings?
If it is of any interest, I built a small and (hopefully) extensible Python library that supports custom extraction and custom cost models. You can implement a cost model or an extractor independently, and I plan to port some of the extraction-gym extractors for my needs. Then, using some ugly deserialization, I convert the result back to an egglog object for further processing or code generation.
It is based on @sklam's PoC, i.e. it parses the serialized JSON and uses networkx.
It does not support the matching logic you described in the issue without serializing. I believe the proposed matching might be too "specific", in the sense that I don't always want to match against a full expression; sometimes I want to match a partial one (e.g. I raise to a power and there is an addition somewhere in the exponent). In addition, how would it support cycles in the graph?
It is not fully in the spirit of the library, where everything is strongly typed, but it was necessary for me to support at least a custom cost model that can "peek" at the children of nodes.
Oh cool! Yeah, definitely of interest, feel free to post a link. This is what I would like to work on next, so in particular, if there are test cases or use cases that use the logic you wrote, those would be super helpful to look at to see if I can get something upstreamed here that covers them...
I have been making progress with my POC. Here's a long notebook that goes through compiling a Python function, encoding it into egraph, custom extraction, and MLIR codegen: https://numba.github.io/sealir/demo_geglu_approx_mlir.html#extracting-an-optimized-representation
I have changed the extraction logic since the PoC, because the previous code mishandled cycles. The new algorithm approximates the solution using iterative relaxation, and it naturally deals with cycles without much complexity: https://github.com/sklam/sealir/blob/05e85a1ed2c7ff4c6f5ec5f7dd901833aa3516ee/sealir/eqsat/rvsdg_extract.py#L145-L202. For cost modeling, my next step is to look into a Pareto curve of time cost vs power cost (or vs precision). The new algorithm is easier (for me) to reason about in that kind of scenario, so I'll want the cost to be an N-D vector soon.
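For readers unfamiliar with relaxation-based extraction, here is a generic sketch of the idea, not the sealir algorithm linked above: every e-class starts at infinite cost, and each pass lowers a class's best cost using its nodes' current child costs until nothing changes. Cycles are harmless because a node whose children are still infinite never wins. It uses additive tree costs on a made-up e-graph and assumes strictly positive self costs so the loop terminates.

```python
import math

# Made-up e-graph: e-class -> list of (op, self_cost, child e-classes).
# Class "a" contains a cycle, a = f(a), alongside a leaf node.
egraph = {
    "a": [("f", 1.0, ["a"]), ("leaf", 5.0, [])],
    "b": [("g", 2.0, ["a", "a"]), ("h", 20.0, [])],
}


def relax(egraph):
    """Return {e-class: (best cost, index of best node)} via repeated relaxation."""
    best = {c: (math.inf, None) for c in egraph}
    changed = True
    while changed:
        changed = False
        for cls, nodes in egraph.items():
            for i, (_, self_cost, children) in enumerate(nodes):
                # If any child class is still infinite, total is infinite and cannot win.
                total = self_cost + sum(best[ch][0] for ch in children)
                if total < best[cls][0]:
                    best[cls] = (total, i)
                    changed = True
    return best


best = relax(egraph)
print(best["a"], best["b"])  # (5.0, 1) (12.0, 0)
```

Because child costs are summed per use, this computes tree cost; a DAG-cost objective needs a different (and harder) search.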
There is now support for adding a cost per row in #343. The next step here is to expose the custom cost interface directly in Python, to allow even more flexibility in defining the cost of a node. Note that the extraction doesn't emit the custom costs yet: https://github.com/egraphs-good/egglog-experimental/issues/23
@sklam I would love to see if I could get something like https://github.com/sklam/sealir/blob/main/sealir/eqsat/rvsdg_extract.py to use the builtin extraction and then the type safe parsing added in https://github.com/egraphs-good/egglog-python/pull/319, so that you don't have to serialize at all or use those private methods. I'll try to take a look, but also if you know of things that would be needed to be able to do that feel free to share.
@saulshanabrook, I wonder if the builtin extraction can maintain the stable cost computation. I have tests for it at https://github.com/numba/sealir/blob/main/sealir/tests/test_cost_extraction.py. IMO, any approach that allows custom Python-implemented extraction in a serialization-free manner will be useful in general.
Additionally, my team is going to investigate costs that are more than just scalars. We want to investigate using some ML optimization technique to determine the cost function. We have a case (notebook link) that rewrites the tensor expression a @ (b + c) @ d into something that involves np.hstack and np.vstack. Depending on the input shapes, the speedup can be >50x, but in a lot of cases it is a slowdown. We might end up with costs that are N-D vectors. Our current thinking is that memory (cache effects) can carry a performance bonus or penalty that is non-local to the operation.
That is why I think my team will need a custom extractor.
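With N-D cost vectors there is usually no single cheapest candidate, so an extractor has to keep the non-dominated (Pareto-optimal) set. A minimal dominance filter, with made-up (time, memory) costs for illustration:

```python
from __future__ import annotations


def dominates(a: tuple[float, ...], b: tuple[float, ...]) -> bool:
    """a dominates b if it is no worse in every dimension and strictly better in at least one."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def pareto_front(costs: dict[str, tuple[float, ...]]) -> list[str]:
    """Return the candidates whose cost vectors are not dominated by any other candidate."""
    return [
        name
        for name, c in costs.items()
        if not any(dominates(other, c) for other_name, other in costs.items() if other_name != name)
    ]


# Made-up (time, memory) costs for three equivalent rewrites.
candidates = {
    "a @ (b + c) @ d": (10.0, 1.0),
    "hstack/vstack variant": (2.0, 3.0),
    "worse variant": (12.0, 4.0),
}
print(pareto_front(candidates))  # ['a @ (b + c) @ d', 'hstack/vstack variant']
```

This is quadratic in the number of candidates, which is fine for a sketch; picking one point from the front still needs an external preference (weights, a budget, etc.).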
Thank you this is really helpful!
You might be able to get away with just a custom cost model. Luckily, the Rust code does support non-integer costs.
I'll use this example to guide exposing the custom cost model interface to Python, and also to allow non-integer cost values.
@sklam I have the custom cost models now working in Python in #357. I wanted to try out the PR by refactoring the sealir cost model to see if I can use it instead.
Does the SealIR extraction routine try to minimize graph cost (i.e. f(a, a) will only count the cost of a once) or tree cost? The new builtin extractor in egglog will return the optimal tree cost, so I was going to try hooking that up with the custom cost model to see if it works for you.
Looking at it more it does seem to use the graph cost.
Since this is the case, the builtin extractor won't cut it.
I do wonder if one of the rust implemented generic extraction algorithms in https://github.com/egraphs-good/extraction-gym might work?
However, I am not sure they would, because in your cost model there isn't a separate cost per item; some of it is inclusive of what's below it. For example, a loop computes a fixed multiple of the cost of what's below it, which isn't supported in the serialization format.
Do you know if tree cost wouldn't work?
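The loop case is a cost that is a function of the children's costs rather than a fixed per-node number plus a sum. A toy sketch over hypothetical tuple terms, where a loop's cost is its trip count times its body's cost (both the term shape and the formula are assumptions):

```python
# Hypothetical terms: ("loop", trip_count, body), ("add", lhs, rhs), or a leaf string.


def cost(t) -> float:
    """Toy non-additive cost: a loop costs its trip count times the cost of its body."""
    if isinstance(t, str):
        return 1.0  # leaves cost 1
    op, *args = t
    if op == "loop":
        trip_count, body = args
        # Inclusive cost: scales with everything below it, not just a per-node self cost.
        return 1.0 + trip_count * cost(body)
    return 1.0 + sum(cost(a) for a in args)


body = ("add", "x", "y")
print(cost(body), cost(("loop", 100, body)))  # 3.0 301.0
```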
One of the problems I found with tree cost is cases like a ** 2 being optimized to a * a, which should not count the cost of a twice. I captured this case in this test.
I just released a new version of egglog with a graph based extractor and the ability to use custom cost models. I tried it out in https://github.com/numba/sealir/pull/15, which seems to work on small use cases.
Ideally, the goal is to see if we can remove the need to serialize in order to extract, which isn't the most efficient approach and also means you lose some type safety.