Implement proper JAX/Numba dispatch for IfElse
Description
IfElse is only lazy in the default backend because the function virtual machine handles it (via the "lazy" attribute). In Numba/JAX it currently provides no laziness, because both branches reach it already computed.
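To make the difference concrete, here is a minimal pure-Python sketch (not PyTensor code). The names `eager_ifelse`, `lazy_ifelse`, `expensive_true` and `expensive_false` are hypothetical and exist only for illustration: the eager version mimics an Op that receives both branch values pre-computed, while the lazy version receives callables and only evaluates the one it needs.

```python
calls = []  # records which branches were actually evaluated


def expensive_true():
    calls.append("true")
    return 1.0


def expensive_false():
    calls.append("false")
    return -1.0


def eager_ifelse(cond, true_value, false_value):
    # Both values were already computed by the caller; this only selects one.
    return true_value if cond else false_value


def lazy_ifelse(cond, true_thunk, false_thunk):
    # Only the thunk for the selected branch is ever called.
    return true_thunk() if cond else false_thunk()


eager_result = eager_ifelse(True, expensive_true(), expensive_false())
eager_calls = list(calls)

calls.clear()
lazy_result = lazy_ifelse(True, expensive_true, expensive_false)
lazy_calls = list(calls)
```

Both calls return the same value, but `eager_calls` records both branches while `lazy_calls` records only the selected one.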
During compilation we could specialize IfElse into a LazyIfElse Op that contains two inner graphs, one corresponding to each branch. These graphs should contain all variables that lead to the inputs of IfElse and are not used by any output variable except through the outputs of IfElse. This depends on which function is being compiled and can't be known ahead of time.
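One possible way to find those variables, sketched on a toy graph rather than on a real PyTensor FunctionGraph: the graph is a plain dict mapping each node name to its inputs, and `ancestors` and `split_branches` are hypothetical helpers. The sketch also assumes that variables shared by both branches stay in the outer graph, which is a design choice the proposal leaves open.

```python
def ancestors(graph, roots):
    """Return every node reachable from `roots` via input edges (roots included)."""
    seen = set()
    stack = list(roots)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(graph.get(node, ()))
    return seen


def split_branches(graph, outputs, ifelse_node, cond, true_inputs, false_inputs):
    # Walk back from the function outputs without entering the IfElse branches:
    # whatever is reached this way is needed no matter which branch is taken.
    pruned = dict(graph)
    pruned[ifelse_node] = (cond,)
    always_needed = ancestors(pruned, outputs)

    true_anc = ancestors(graph, true_inputs)
    false_anc = ancestors(graph, false_inputs)

    # Assumption: variables shared by both branches stay in the outer graph.
    true_only = true_anc - always_needed - false_anc
    false_only = false_anc - always_needed - true_anc
    return true_only, false_only


# Toy graph: out = ifelse(c, sq, log_y), and sum_y is a second function output.
graph = {
    "x": (),
    "y": (),
    "c": ("x",),
    "exp_x": ("x",),
    "sq": ("exp_x",),
    "log_y": ("y",),
    "sum_y": ("y",),
    "out": ("c", "sq", "log_y"),
}

true_only, false_only = split_branches(
    graph,
    outputs=["out", "sum_y"],
    ifelse_node="out",
    cond="c",
    true_inputs=["sq"],
    false_inputs=["log_y"],
)
```

Here `x` and `y` stay outside the branches because the condition and `sum_y` need them regardless. With a different set of function outputs the split would change, which is why it can only be done at compilation time.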
The current implementation of jax_funcify_IfElse (https://github.com/pymc-devs/pytensor/blob/4235ccc3f4243c5179178a206c15d84c4cda2e79/pytensor/link/jax/dispatch/basic.py#L60-L70) would instead look something like this (pseudo-code):
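import jax
from pytensor.link.jax.dispatch.basic import jax_funcify

@jax_funcify.register(LazyIfElse)  # LazyIfElse is the proposed Op; it does not exist yet
def jax_funcify_LazyIfElse(op, **kwargs):
    # Convert each branch's inner graph into a JAX-callable function
    true_fn = jax_funcify(op.true_fgraph)
    false_fn = jax_funcify(op.false_fgraph)
    # Assumes both inner graphs have the same number of outputs
    n_outs = len(op.true_fgraph.outputs)

    def ifelse(cond, *args):
        res = jax.lax.cond(cond, true_fn, false_fn, *args)
        return res if n_outs > 1 else res[0]

    return ifelse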
This could even provide a nicer dprint, by showing the two inner graphs. Right now it's not always obvious which operations are computed lazily and which are not.