moose
Make AbstractComputation nest-able
Functions written with the eDSL should be callable from within other computations, regardless of whether they've been wrapped with the `pm.computation` decorator. For example:
```python
import pymoose as pm

# `alice` is a placement defined elsewhere (not shown).


@pm.computation
def plus1(x: pm.Argument(alice, dtype=pm.float64)):
    with alice:
        one = pm.constant(1, dtype=pm.float64)
        return pm.add(x, one)


@pm.computation
def alice_add():
    with alice:
        x = pm.constant(3, dtype=pm.float64)
        x_plus_one = plus1(x)
    return x_plus_one


if __name__ == "__main__":
    # [...]
    runtime.set_default()
    val = alice_add()  # <-- will fail during tracing
```
When `alice_add` is called, the current behavior is the following:

- Inside a runtime context, `alice_add.__call__` invokes `trace(alice_add)`.
- `trace(alice_add)` then invokes `plus1.__call__`. For this call to succeed, `plus1` needs to return an `Expression` that can be used to trace the rest of `alice_add`.
- However, since the default runtime context is not `None`, `plus1` is executed against the default runtime's `evaluate_computation` with arguments of type `Expression`.
- The Rust runtime bindings try to interpret these `Expression` pyobjs as Moose values, which fails with a `TypeError` because they are not concrete values.
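The failure can be reproduced without Moose itself. The sketch below is a minimal model under stated assumptions: `Expression`, `constant`, `add`, `interpret`, `ToyRuntime` and the `contextvars`-based runtime slot are all hypothetical stand-ins, not the real pymoose classes. `ToyRuntime` only imitates the relevant property of the Rust bindings, namely that they reject `Expression` arguments. Because `trace` leaves the default runtime set, the nested `plus1(x)` call is routed to the runtime with a symbolic argument.

```python
import contextvars
import inspect

# Hypothetical stand-ins for the pymoose pieces named in the issue. They model
# only the dispatch logic, not the real pymoose implementation.


class Expression:
    """Symbolic node built during tracing: an op name, inputs, and a payload."""

    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, list(inputs), value


def constant(value):
    return Expression("constant", value=value)


def add(lhs, rhs):
    return Expression("add", inputs=[lhs, rhs])


def interpret(expr, bindings):
    """Evaluate a traced graph, substituting concrete values for arguments."""
    if expr.op == "constant":
        return expr.value
    if expr.op == "arg":
        return bindings[expr.value]
    if expr.op == "add":
        return interpret(expr.inputs[0], bindings) + interpret(expr.inputs[1], bindings)
    raise ValueError(f"unknown op {expr.op!r}")


# Hypothetical "default runtime" slot; None means no runtime context.
_runtime_var = contextvars.ContextVar("runtime", default=None)


def get_current_runtime():
    return _runtime_var.get()


def trace(comp):
    """Call the Python function on placeholder arguments to build a graph."""
    n_params = len(inspect.signature(comp.func).parameters)
    placeholders = [Expression("arg", value=i) for i in range(n_params)]
    return comp.func(*placeholders)  # the runtime stays set while tracing


class ToyRuntime:
    """Like the Rust bindings, accepts only concrete (non-Expression) values."""

    def evaluate_computation(self, comp, arguments):
        if any(isinstance(a, Expression) for a in arguments):
            raise TypeError("argument is an Expression, not a concrete value")
        return interpret(trace(comp), arguments)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        # Current behavior: always dispatch to the runtime when one is set.
        return get_current_runtime().evaluate_computation(self, args)


@AbstractComputation
def plus1(x):
    return add(x, constant(1.0))


@AbstractComputation
def alice_add():
    x = constant(3.0)
    return plus1(x)  # x is an Expression here, but a runtime is still set


_runtime_var.set(ToyRuntime())  # plays the role of runtime.set_default()
try:
    alice_add()
except TypeError as err:
    print("tracing failed:", err)
```

Calling `plus1(2.0)` directly still works in this model, because its argument is concrete. Only the nested call made during tracing hits the `TypeError`.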
One solution for the user is to just drop the `pm.computation` decorator from `plus1`, so that it returns an `Expression` no matter what runtime context is around. But this makes it hard for users to call "standard library" computations that are already wrapped in an `AbstractComputation`, which is likely to be the common case.
I think the simplest solution here would be to do the following:

- Inside `pm.trace`, temporarily unset the default runtime context, so that `get_current_runtime` returns `None`.
- If `AbstractComputation.__call__` is invoked without a runtime context (i.e. `get_current_runtime` returns `None`), invoke `AbstractComputation.func.__call__`. This invocation maps `Expression`s to `Expression`s, so tracing can proceed normally.
- If `AbstractComputation.__call__` is invoked inside a runtime context, invoke `get_current_runtime().evaluate_computation(...)` with the computation as usual.
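The three steps can be sketched with the same kind of hypothetical stand-ins (`Expression`, `constant`, `add`, `interpret`, `ToyRuntime`). This is not pymoose code. The assumption is that the runtime context lives in a `contextvars.ContextVar`: `trace` sets it to `None` for the duration of the Python call and restores it afterwards. `AbstractComputation.__call__` falls back to calling `func` directly when no runtime is set.

```python
import contextvars
import inspect

# Hypothetical stand-ins modelling the proposed dispatch; not pymoose itself.


class Expression:
    """Symbolic node built during tracing: an op name, inputs, and a payload."""

    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, list(inputs), value


def constant(value):
    return Expression("constant", value=value)


def add(lhs, rhs):
    return Expression("add", inputs=[lhs, rhs])


def interpret(expr, bindings):
    """Evaluate a traced graph, substituting concrete values for arguments."""
    if expr.op == "constant":
        return expr.value
    if expr.op == "arg":
        return bindings[expr.value]
    if expr.op == "add":
        return interpret(expr.inputs[0], bindings) + interpret(expr.inputs[1], bindings)
    raise ValueError(f"unknown op {expr.op!r}")


_runtime_var = contextvars.ContextVar("runtime", default=None)


def get_current_runtime():
    return _runtime_var.get()


def trace(comp):
    """Build a graph with the runtime context temporarily unset."""
    n_params = len(inspect.signature(comp.func).parameters)
    placeholders = [Expression("arg", value=i) for i in range(n_params)]
    token = _runtime_var.set(None)  # nested calls now see no runtime
    try:
        return comp.func(*placeholders)
    finally:
        _runtime_var.reset(token)  # restore the caller's runtime


class ToyRuntime:
    """Accepts only concrete (non-Expression) values, like the Rust bindings."""

    def evaluate_computation(self, comp, arguments):
        if any(isinstance(a, Expression) for a in arguments):
            raise TypeError("argument is an Expression, not a concrete value")
        return interpret(trace(comp), arguments)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        runtime = get_current_runtime()
        if runtime is None:
            # No runtime (e.g. while tracing): Expressions in, Expressions out.
            return self.func(*args)
        # Inside a runtime context: evaluate as usual.
        return runtime.evaluate_computation(self, args)


@AbstractComputation
def plus1(x):
    return add(x, constant(1.0))


@AbstractComputation
def alice_add():
    x = constant(3.0)
    return plus1(x)  # runs plus1.func directly because tracing unset the runtime


_runtime_var.set(ToyRuntime())  # plays the role of runtime.set_default()
print(alice_add())
```

Using a `ContextVar` with `set`/`reset` rather than a plain module global is a design choice in this sketch. It restores the previous value even if tracing raises, and it keeps separate threads or async tasks from seeing each other's "unset" state.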
Some other options:

- Allow nesting runtime contexts, and create a new "dummy" `Runtime` class whose `evaluate_computation` simply forwards to `AbstractComputation.func.__call__`.
- Something more "moose-ier", e.g. accommodate `Expression` conversion in the Moose bindings and, in this case, execute symbolically, i.e. run the computation against a `SymbolicSession` instead of against the `AsyncSession` in `AsyncTestRuntime`.
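The first alternative can be sketched the same way. This is again a model built from hypothetical stand-ins (`runtime_context`, `DummyRuntime`, `ToyRuntime`, `interpret`), not the pymoose API. Runtime contexts become a nestable context manager. `trace` pushes a dummy runtime whose `evaluate_computation` just calls `comp.func`, and `AbstractComputation.__call__` keeps its single "always go through the current runtime" path.

```python
import contextlib
import contextvars
import inspect

# Hypothetical stand-ins modelling nestable runtime contexts; not pymoose itself.


class Expression:
    """Symbolic node built during tracing: an op name, inputs, and a payload."""

    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, list(inputs), value


def constant(value):
    return Expression("constant", value=value)


def add(lhs, rhs):
    return Expression("add", inputs=[lhs, rhs])


def interpret(expr, bindings):
    """Evaluate a traced graph, substituting concrete values for arguments."""
    if expr.op == "constant":
        return expr.value
    if expr.op == "arg":
        return bindings[expr.value]
    if expr.op == "add":
        return interpret(expr.inputs[0], bindings) + interpret(expr.inputs[1], bindings)
    raise ValueError(f"unknown op {expr.op!r}")


_runtime_var = contextvars.ContextVar("runtime", default=None)


def get_current_runtime():
    return _runtime_var.get()


@contextlib.contextmanager
def runtime_context(runtime):
    """Push a runtime; the enclosing one is restored on exit, so contexts nest."""
    token = _runtime_var.set(runtime)
    try:
        yield runtime
    finally:
        _runtime_var.reset(token)


class DummyRuntime:
    """Forwards straight to the Python function: Expressions in, Expressions out."""

    def evaluate_computation(self, comp, arguments):
        return comp.func(*arguments)


def trace(comp):
    """Build a graph while the dummy runtime is the innermost context."""
    n_params = len(inspect.signature(comp.func).parameters)
    placeholders = [Expression("arg", value=i) for i in range(n_params)]
    with runtime_context(DummyRuntime()):
        return comp.func(*placeholders)


class ToyRuntime:
    """Accepts only concrete (non-Expression) values, like the Rust bindings."""

    def evaluate_computation(self, comp, arguments):
        if any(isinstance(a, Expression) for a in arguments):
            raise TypeError("argument is an Expression, not a concrete value")
        return interpret(trace(comp), arguments)


class AbstractComputation:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        # Unchanged dispatch: always go through the innermost runtime.
        return get_current_runtime().evaluate_computation(self, args)


@AbstractComputation
def plus1(x):
    return add(x, constant(1.0))


@AbstractComputation
def alice_add():
    x = constant(3.0)
    return plus1(x)  # handled by DummyRuntime during tracing


with runtime_context(ToyRuntime()):
    print(alice_add())
```

Compared with unsetting the runtime inside `pm.trace`, this variant keeps `AbstractComputation.__call__` free of a `None` special case. The cost is that it needs a runtime stack, so an outer context is restored correctly when an inner one exits.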