DAG visitor overhaul
The `MultiFunction`-based `map_dag` visitors in UFL provide a caching (and hence efficient) API for implementing post-order visits of a UFL DAG. The core function is `map_expr_dags` (the relevant part is here: https://github.com/FEniCS/ufl/blob/master/ufl/corealg/map_dag.py#L64).
Most of the transformations in the UFL pipeline are bottom-up (post-order) and so are well served by this infrastructure. Some, however, are pre-order (or at least partially pre-order). One example is restriction propagation, which hand-codes a post-order visit of a restricted subtree (https://github.com/FEniCS/ufl/blob/master/ufl/algorithms/apply_restrictions.py#L36).
This can lead to unintended performance bugs: to get a pre-order visit one has, by necessity, to hand-code the DAG recursion. Without a cache, that recursion revisits shared subexpressions once per path that reaches them, which blows up the perceived size of the DAG. Avoiding exactly this was one of the major motivating factors for the `ReuseTransformer` -> `MultiFunction` + `map_expr_dag` transition.
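To make the blow-up concrete, here is a minimal sketch on a toy DAG. It does not use UFL at all: `Node`, `build_chain` and both counting functions are hypothetical names invented for illustration. Each level of the toy expression uses the level below twice, so a hand-coded recursion without a cache sees an exponentially large tree, while a cached traversal sees each node once.

```python
class Node:
    """Hypothetical stand-in for a UFL expression node."""

    def __init__(self, *operands):
        self.operands = operands


def count_visits_naive(node):
    """Hand-coded recursion without a cache: shared children are revisited."""
    return 1 + sum(count_visits_naive(op) for op in node.operands)


def count_visits_cached(node, seen=None):
    """Recursion that remembers visited nodes: each node is handled once."""
    if seen is None:
        seen = set()
    if id(node) in seen:
        return 0
    seen.add(id(node))
    return 1 + sum(count_visits_cached(op, seen) for op in node.operands)


def build_chain(depth):
    """Build a DAG in which every level uses the level below twice."""
    expr = Node()
    for _ in range(depth):
        expr = Node(expr, expr)
    return expr


expr = build_chain(10)
naive = count_visits_naive(expr)    # 2**11 - 1 = 2047 visits
cached = count_visits_cached(expr)  # 11 visits, one per distinct node
```

The DAG has only 11 distinct nodes, but the uncached recursion performs 2047 visits.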
Unfortunately, providing a caching pre-order visit requires that the visitor object sees the cache, whereas the post-order visit cache can be managed completely from the outside. In the post-order case the `MultiFunction` does not implement its own recursion and is instead used only for type-based method dispatch on the first argument.
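The division of labour in the post-order case can be sketched roughly as follows. This is a toy model, not UFL's implementation: `Leaf`, `Sum`, `Product`, `Evaluator` and `map_post_order` are hypothetical names, and the driver here is recursive for brevity. The point is only that the visitor never recurses and never sees the cache; it just dispatches on the node type and receives the already-computed results for the operands.

```python
class Leaf:
    operands = ()

    def __init__(self, value):
        self.value = value


class Sum:
    def __init__(self, a, b):
        self.operands = (a, b)


class Product:
    def __init__(self, a, b):
        self.operands = (a, b)


class Evaluator:
    """Hypothetical visitor: type-based dispatch only, no recursion."""

    def __init__(self):
        self.calls = 0  # number of handler invocations, for illustration

    def __call__(self, node, *operand_results):
        self.calls += 1
        handler = getattr(self, "visit_" + type(node).__name__)
        return handler(node, *operand_results)

    def visit_Leaf(self, node):
        return node.value

    def visit_Sum(self, node, a, b):
        return a + b

    def visit_Product(self, node, a, b):
        return a * b


def map_post_order(visitor, root):
    """Post-order driver that owns the cache; the visitor never sees it."""
    cache = {}

    def visit(node):
        key = id(node)
        if key not in cache:
            results = [visit(op) for op in node.operands]
            cache[key] = visitor(node, *results)
        return cache[key]

    return visit(root)


x = Leaf(3)
s = Sum(x, x)
p = Product(s, s)

evaluator = Evaluator()
value = map_post_order(evaluator, p)  # (3 + 3) * (3 + 3)
```

Because the driver decides when handlers run, it can cache freely. The cost is that a handler cannot decide to treat its children differently before they are visited, which is what a pre-order transformation needs.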
In tsfc, this issue is solved by having a more general visitor implementation (https://github.com/firedrakeproject/tsfc/blob/master/gem/node.py#L187). One writes `singledispatch` recursive functions that take the cache manager as a parameter. Recursion is done by calling this parameter (through its `__call__`) on children as necessary.
Here's an example from tsfc using this pattern to manage `abs`-simplification (from https://github.com/firedrakeproject/tsfc/blob/master/tsfc/ufl_utils.py#L250):
```python
@singledispatch
def _simplify_abs(o, self, in_abs):
    """Single-dispatch function to simplify absolute values.

    :arg o: UFL node
    :arg self: Callback handler for recursion
    :arg in_abs: Is ``o`` inside an absolute value?

    When ``in_abs`` we must return a non-negative value, potentially
    by wrapping the returned node with ``Abs``.
    """
    raise AssertionError("UFL node expected, not %s" % type(o))


@_simplify_abs.register(Expr)
def _simplify_abs_expr(o, self, in_abs):
    # General case, only wrap the outer expression (if necessary)
    operands = [self(op, False) for op in o.ufl_operands]
    result = ufl_reuse_if_untouched(o, *operands)
    if in_abs:
        result = Abs(result)
    return result

...

@_simplify_abs.register(Abs)
def _simplify_abs_abs(o, self, in_abs):
    return self(o.ufl_operands[0], True)


def simplify_abs(expression):
    """Simplify absolute values in a UFL expression.  Its primary
    purpose is to "neutralise" CellOrientation nodes that are
    surrounded by absolute values and thus not at all necessary."""
    mapper = MemoizerArg(_simplify_abs)
    return mapper(expression, False)
```
The `MemoizerArg` manages the caching and recursion; the visitor just needs to implement appropriate handlers for all of the nodes (and recurse where necessary).
This supports caching in both pre- and post-order traversal because the dispatch functions see the memoizer object.
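For readers who don't want to follow the link, the pattern can be sketched without UFL or gem as below. This is an assumption-laden toy, not the gem code: `ArgMemoizer` is a hypothetical stand-in that keys its cache on `(node, arg)`, and the node classes (`Node`, `Symbol`, `Add`, `Neg`, `Abs`) and the `_simplify*` handlers are invented for illustration. The handlers receive the memoizer as `self` and recurse by calling it, so they can choose what state to pass down before a child is visited.

```python
from functools import singledispatch


class Node:
    """Hypothetical expression node; hashing is by identity."""

    def __init__(self, *operands):
        self.operands = operands


class Symbol(Node):
    def __init__(self, name):
        super().__init__()
        self.name = name


class Add(Node):
    pass


class Neg(Node):
    pass


class Abs(Node):
    pass


class ArgMemoizer:
    """Hypothetical cache manager: results are keyed on (node, arg)."""

    def __init__(self, function):
        self.function = function
        self.cache = {}

    def __call__(self, node, arg):
        key = (node, arg)
        try:
            return self.cache[key]
        except KeyError:
            result = self.function(node, self, arg)
            self.cache[key] = result
            return result


@singledispatch
def _simplify(node, self, in_abs):
    raise AssertionError("Node expected, not %s" % type(node))


@_simplify.register(Node)
def _simplify_node(node, self, in_abs):
    # General case: children are not inside an abs; wrap the result if needed.
    operands = [self(op, False) for op in node.operands]
    untouched = all(a is b for a, b in zip(operands, node.operands))
    result = node if untouched else type(node)(*operands)
    return Abs(result) if in_abs else result


@_simplify.register(Abs)
def _simplify_abs(node, self, in_abs):
    # Pre-order decision: drop this Abs and tell the child it is inside one.
    return self(node.operands[0], True)


@_simplify.register(Neg)
def _simplify_neg(node, self, in_abs):
    (operand,) = node.operands
    if in_abs:
        # |-x| == |x|, so the negation can be dropped.
        return self(operand, True)
    new = self(operand, False)
    return node if new is operand else Neg(new)


x = Symbol("x")
shared = Abs(Neg(x))          # the same subexpression is used twice below
mapper = ArgMemoizer(_simplify)
result = mapper(Add(shared, shared), False)
```

Here `shared` is simplified once for the key `(shared, False)` and the second use is a cache hit, so both operands of `result` are the same `Abs(x)` object. A node visited with a different `in_abs` value would get its own cache entry.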
We could implement something like this for UFL. I think one can probably provide backwards compatibility with the existing `MultiFunction` setup.
Any other better suggestions? The current functional interface of `map_expr_dag` makes it hard to stash the callbacks on the `MultiFunction` object.
This seems related to #34, and more precisely to this comment, although thus far not linked.
Related to this is also #69. The current setup doesn't provide good caching behaviour for visitors that do dispatch (e.g. in `apply_restrictions`). #69 provides a band-aid fix for this, but is a little hacky. The problem is that if at some point the recursion needs to pass state-dependent information down to transformed children, there's no way that caching can occur (because the `map_expr_dag` model doesn't allow for passing arguments down through the call tree).