
DAG visitor overhaul

Open wence- opened this issue 4 years ago • 2 comments

The MultiFunction-based map_dag visitors in UFL provide a caching (and hence efficient) API for implementing post-order visits of a UFL DAG. The core function is map_expr_dags (the relevant part is at https://github.com/FEniCS/ufl/blob/master/ufl/corealg/map_dag.py#L64).

Most of the transformations in the UFL pipeline are bottom-up (post-order) and so are well served by this infrastructure. Some, however, are pre-order (or at least partially pre-order). One example is restriction propagation, which hand-codes a post-order visit of a restricted subtree (https://github.com/FEniCS/ufl/blob/master/ufl/algorithms/apply_restrictions.py#L36).

This can lead to unintended performance bugs when one, by necessity, hand-codes the DAG recursion to get a pre-order visit. Without a cache, shared subexpressions are revisited once per path that reaches them, which blows up the perceived size of the DAG. Avoiding this blow-up was one of the major motivating factors for the ReuseTransformer -> MultiFunction + map_expr_dag transition.
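To make the blow-up concrete, here is a small sketch on a toy DAG. The `Node` class and both counting functions are hypothetical stand-ins, not UFL code; the point is only to compare a hand-coded recursion with one that remembers which nodes it has seen.

```python
class Node:
    """Toy DAG node (hypothetical stand-in, not a UFL class)."""

    def __init__(self, *operands):
        self.operands = operands


def build_chain(depth):
    """Build a DAG in which every level uses the level below twice.

    The DAG has depth + 1 distinct nodes, but 2**(depth + 1) - 1 nodes
    once it is unrolled into a tree.
    """
    node = Node()
    for _ in range(depth):
        node = Node(node, node)
    return node


def count_naive(node):
    """Hand-coded recursion without a cache: shared nodes are revisited."""
    return 1 + sum(count_naive(op) for op in node.operands)


def count_cached(node, seen=None):
    """Same traversal, but every distinct node is visited only once."""
    if seen is None:
        seen = set()
    if id(node) in seen:
        return 0
    seen.add(id(node))
    return 1 + sum(count_cached(op, seen) for op in node.operands)


root = build_chain(10)
naive_visits = count_naive(root)    # grows exponentially with depth
cached_visits = count_cached(root)  # grows linearly with depth
```

With only 11 distinct nodes the uncached recursion already performs 2047 visits, and the gap doubles with every extra level.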

Unfortunately, providing a caching pre-order visit requires that the visitor object sees the cache (the post-order visit cache can be managed completely from the outside). In the latter case the MultiFunction does not implement its own recursion and instead is used for type-based method dispatch on the first argument.
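As a rough illustration of why the post-order cache can live entirely outside the visitor, consider the sketch below. It is not the actual map_expr_dags implementation: `Node`, `map_post_order` and `to_string` are made-up names, and the traversal is recursive only for brevity. The handler receives a node plus its already-transformed operands and never touches the cache or recurses itself.

```python
class Node:
    """Toy DAG node (hypothetical stand-in, not a UFL class)."""

    def __init__(self, label, *operands):
        self.label = label
        self.operands = operands


def map_post_order(handler, root):
    """Apply handler(node, *mapped_operands) bottom-up.

    The cache is owned by this driver; the handler never sees it.
    """
    cache = {}

    def visit(node):
        key = id(node)  # nodes stay alive via root, so ids are stable
        if key not in cache:
            mapped = [visit(op) for op in node.operands]
            cache[key] = handler(node, *mapped)
        return cache[key]

    return visit(root)


calls = []


def to_string(node, *operands):
    """Example handler: pure dispatch target, no recursion of its own."""
    calls.append(node.label)
    if not operands:
        return node.label
    return "%s(%s)" % (node.label, ", ".join(operands))


x = Node("x")
shared = Node("sin", x)
root = Node("add", shared, shared)
text = map_post_order(to_string, root)
```

The handler runs once per distinct node even though `shared` appears twice. A pre-order handler, by contrast, has to decide for itself how to descend into children, so it needs some handle on the cache.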

In tsfc, this issue is solved by having a more general visitor implementation (https://github.com/firedrakeproject/tsfc/blob/master/gem/node.py#L187). One writes singledispatch recursive functions that take the cache-manager as a parameter. Recursion is done by __call__ing this parameter on children as necessary.

Here's an example from tsfc that uses this pattern to manage abs-simplification (from https://github.com/firedrakeproject/tsfc/blob/master/tsfc/ufl_utils.py#L250):

@singledispatch
def _simplify_abs(o, self, in_abs):
    """Single-dispatch function to simplify absolute values.

    :arg o: UFL node
    :arg self: Callback handler for recursion
    :arg in_abs: Is ``o`` inside an absolute value?

    When ``in_abs`` we must return a non-negative value, potentially
    by wrapping the returned node with ``Abs``.
    """
    raise AssertionError("UFL node expected, not %s" % type(o))

@_simplify_abs.register(Expr)
def _simplify_abs_expr(o, self, in_abs):
    # General case, only wrap the outer expression (if necessary)
    operands = [self(op, False) for op in o.ufl_operands]
    result = ufl_reuse_if_untouched(o, *operands)
    if in_abs:
        result = Abs(result)
    return result

...

@_simplify_abs.register(Abs)
def _simplify_abs_abs(o, self, in_abs):
    return self(o.ufl_operands[0], True)

def simplify_abs(expression):
    """Simplify absolute values in a UFL expression.  Its primary
    purpose is to "neutralise" CellOrientation nodes that are
    surrounded by absolute values and thus not at all necessary."""
    mapper = MemoizerArg(_simplify_abs)
    return mapper(expression, False)

The MemoizerArg manages the caching and recursion, the visitor just needs to implement appropriate handlers for all of the nodes (and recurse where necessary).

This supports caching in both pre- and post-order traversal because the dispatch functions see the memoizer object.
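For readers who do not want to dig through gem, the following is a minimal sketch of the idea on toy classes. `ArgMemoizer` is a hypothetical stand-in written for this example and should not be read as the exact gem `MemoizerArg` code; `Expr`, `Symbol`, `Sum` and `Abs` here are toy classes, not the UFL ones.

```python
from functools import singledispatch


class Expr:
    """Toy expression node (hypothetical, not ufl.classes.Expr)."""

    def __init__(self, *operands):
        self.operands = operands


class Symbol(Expr):
    def __init__(self, name):
        super().__init__()
        self.name = name


class Sum(Expr):
    pass


class Abs(Expr):
    pass


class ArgMemoizer:
    """Cache results on (node, arg) and pass ourselves to the function."""

    def __init__(self, function):
        self.function = function
        self.cache = {}
        self.misses = 0

    def __call__(self, node, arg):
        key = (node, arg)  # toy nodes hash by identity
        if key not in self.cache:
            self.misses += 1
            self.cache[key] = self.function(node, self, arg)
        return self.cache[key]


@singledispatch
def strip_nested_abs(o, self, in_abs):
    raise AssertionError("Expr expected, not %s" % type(o))


@strip_nested_abs.register(Expr)
def _strip_expr(o, self, in_abs):
    # General case: recurse through the memoizer, then wrap if required.
    operands = [self(op, False) for op in o.operands]
    result = type(o)(*operands) if operands else o
    return Abs(result) if in_abs else result


@strip_nested_abs.register(Abs)
def _strip_abs(o, self, in_abs):
    # Pre-order decision: drop this Abs and tell the child it is inside one.
    return self(o.operands[0], True)


x = Symbol("x")
inner = Abs(Abs(x))
root = Sum(inner, inner)
mapper = ArgMemoizer(strip_nested_abs)
result = mapper(root, False)
```

Here `Abs(Abs(x))` collapses to a single `Abs(x)`, and because the handlers recurse through `mapper`, the shared `inner` subtree is processed once and both operands of the result are the same object.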

We could implement something like this for UFL. I think one can probably provide backwards compatibility with the existing MultiFunction setup.

Any better suggestions? The current functional interface of map_expr_dag makes it hard to stash the callbacks on the MultiFunction object.

wence-, Jul 21 '20 10:07

This seems related to #34, and more precisely to this comment, although thus far not linked.

miklos1, Aug 25 '20 16:08

Related to this is also #69. The current setup doesn't provide good caching behaviour for visitors that do dispatch (e.g. in apply_restrictions). #69 provides a band-aid fix to this, but is a little hacky. The problem is that if at some point the recursion needs to pass state-dependent information down to transformed children, there's no way that caching can occur (because the map_expr_dag model doesn't allow for passing arguments down through the call tree).
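To illustrate what caching with state passed down could look like, here is a toy sketch. It is not what #69 or apply_restrictions does: `Node` and `propagate` are hypothetical, and restrictions are modelled simply as nodes labelled "+" or "-". The key point is that the cache key combines the node with the state inherited from its parents.

```python
class Node:
    """Toy DAG node (hypothetical stand-in, not a UFL class)."""

    def __init__(self, label, *operands):
        self.label = label
        self.operands = operands


def propagate(node, restriction, cache):
    """Push a restriction ('+', '-' or None) down to the leaves.

    The result depends on the node and on the inherited restriction,
    so both go into the cache key.
    """
    key = (node, restriction)
    if key in cache:
        return cache[key]
    if node.label in ("+", "-"):
        # Restriction node: drop it and pass its side down to the child.
        result = propagate(node.operands[0], node.label, cache)
    elif not node.operands:
        # Leaf: apply whatever restriction was inherited, if any.
        if restriction is None:
            result = node.label
        else:
            result = "%s(%s)" % (node.label, restriction)
    else:
        args = [propagate(op, restriction, cache) for op in node.operands]
        result = "%s[%s]" % (node.label, ", ".join(args))
    cache[key] = result
    return result


u = Node("u")
g = Node("grad", u)  # shared by all three restricted terms below
root = Node("add", Node("+", g), Node("+", g), Node("-", g))
cache = {}
text = propagate(root, None, cache)
```

The shared `grad` node is transformed once per distinct restriction (twice in total) rather than once per occurrence, which is the behaviour a cache keyed on the node alone cannot provide.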

wence-, Oct 15 '21 11:10