Implement new Loop and Scan operators
Related to #189
This PR implements a new low-level Loop Op that can be easily transpiled to Numba (the Python perform method takes 9 lines; yay for not having to support C in the future).
It also implements a new higher level Scan Op which returns as outputs the last states + intermediate states of a looping operation. This Op cannot be directly evaluated, and must be rewritten as a Loop Op in Python/Numba backends. For the JAX backend it's probably fine to transpile directly from this representation into a lax.scan as the signatures are pretty much identical. That was not done in this PR.
The reason for the two types of outputs is that they are useful in different contexts. Final states are sometimes all one needs, whereas intermediate states are generally needed for backpropagation (not implemented yet). This allows us to choose which of the outputs (or both) we want during compilation, without having to do complicated graph analysis.
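To make the two kinds of outputs concrete, here is a plain-Python sketch of the semantics described above. It is not the PyTensor implementation: `run_loop` and `run_scan` are hypothetical names, and the update function is assumed to return the next states followed by a continue flag. The scan is expressed as a loop whose state is extended with one trace per original state, which is roughly the shape of the Scan-to-Loop conversion.

```python
def run_loop(update, init_states, max_iters):
    """Hypothetical low-level loop: returns only the final states."""
    states = tuple(init_states)
    for _ in range(max_iters):
        # update returns (*next_states, keep_going)
        *next_states, keep_going = update(*states)
        states = tuple(next_states)
        if not keep_going:
            break
    return states


def run_scan(update, init_states, max_iters):
    """Hypothetical scan: final states plus one trace per state."""
    n = len(init_states)

    def traced_update(*args):
        states, traces = args[:n], args[n:]
        *new_states, keep_going = update(*states)
        # Append each new state to its trace (out of place, for clarity).
        new_traces = [trace + [s] for trace, s in zip(traces, new_states)]
        return (*new_states, *new_traces, keep_going)

    init = (*init_states, *([] for _ in range(n)))
    out = run_loop(traced_update, init, max_iters)
    return out[:n], out[n:]


def double(x):
    new_x = x * 2
    return new_x, new_x < 100  # keep going while below 100


final_only = run_loop(double, (1,), max_iters=50)
final, traces = run_scan(double, (1,), max_iters=50)
```

When only `final_only` is needed, the loop never has to allocate the traces, which is the point of being able to pick the outputs at compilation time.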
The existing save_mem_new_scan is used to convert a general scan into a loop that only returns the last computed state. It's... pretty complicated (although it also covers cases where more than 1 but fewer than all steps are requested; OTOH it can't handle while loops, see #178):
https://github.com/pymc-devs/pytensor/blob/8ad33179b12a4c0207b2a654badc608e211e8bb9/pytensor/scan/rewriting.py#L1119
Taking that as a reference, I would say the new conversion rewrite from Scan to Loop is much, much simpler. Most of it is boilerplate code for defining the right trace inputs and the new FunctionGraph.
Both Ops expect a FunctionGraph as input. This should probably be created by a user-facing helper that accepts a callable like scan does now. ~~That was not done yet, as I first wanted to discuss the general design.~~ Done
Design issues
1. ~~The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.~~
2. Scan and Loop can now take random types as inputs (scan can't return it as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables. However it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a `return_rng_update` to `__call__`, so that it doesn't hide the next rng state output?
3. Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? `IfElse` is one option, but perhaps it would be nice to represent it in the same `Loop` Op?
4. What do we want to do in terms of inplacing optimizations?
TODO
If people are on board with the approach
- [ ] Implement Numba dispatch
- [x] Implement JAX dispatch
- [ ] Implement L_op and R_op
- [x] Implement friendly user facing functions
- [ ] Decide on which meta-parameters to preserve (`mode`, `truncate_gradient`, `reverse` and so on)
- [x] Add rewrite that replaces `trace[-1]` by the first set of outputs (final state). That way we can keep the old API, while retaining the benefit of doing while Scans without tracing when it's not needed.
> The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.
Wouldn't a fill loop look something like this?
state = (pt.scalar(0), pt.empty(shape, dtype), rng)
def update(idx, values, rng):
    value, rng = rng.normal()  # not exactly the api...
    values = pt.set_subtensor(values[idx], value)
    return (idx + 1, values, rng, idx < maxval)
(and very much need inplace rewrites for good performance...)
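For readers who want to run the idea, here is a plain-Python analogue of the fill loop above. It is only a sketch: `lcg_next` is a toy, hypothetical generator standing in for a random Op that returns both a draw and the next rng state, and a Python list stands in for the tensor.

```python
def lcg_next(rng_state):
    """Toy linear congruential generator: returns (value in [0, 1), next state)."""
    next_state = (1103515245 * rng_state + 12345) % 2**31
    return next_state / 2**31, next_state


def fill(n, seed):
    """Fill a list of length n with draws, threading the rng state explicitly."""

    def update(idx, values, rng_state):
        value, rng_state = lcg_next(rng_state)
        # Out-of-place write, analogous to set_subtensor without inplace rewrites.
        values = values[:idx] + [value] + values[idx + 1:]
        return idx + 1, values, rng_state, idx + 1 < n

    idx, values, rng_state = 0, [0.0] * n, seed
    keep_going = n > 0  # an empty fill never enters the loop
    while keep_going:
        idx, values, rng_state, keep_going = update(idx, values, rng_state)
    return idx, values, rng_state


idx, values, rng_state = fill(4, seed=42)
```

The out-of-place write copies the whole buffer on every step, so the fill is quadratic in `n`; that is the cost the inplace rewrites would remove.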
> Scan and Loop can now take random types as inputs (scan can't return it as a sequence). This makes random seeding much more explicit compared to the old Scan, which was based on default updates of shared variables. However it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a `return_rng_update` to `__call__`, so that it doesn't hide the next rng state output?
Good question... Don't know either :-)
> Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?
I think one rewrite that gets easier with the if-else-do-while approach would be loop-invariant code motion. Let's say we have a loop like:
x = bigarray...
if not_empty:
    val = 0
    do:
        val = (val + x.sum()) ** 2
    while val < 10
# rewrite to
x = bigarray...
if not_empty:
    val = 0
    x_sum = x.sum()
    do:
        val = (val + x_sum) ** 2
    while val < 10
we could move x.sum() out of the loop. But with a while loop we can't as easily, because we only want to do x.sum() if the loop is not empty, and where would we then put that computation?
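A runnable Python version of the pseudocode above, as a sketch only. Python has no do-while, so it is emulated with `while True` and a `break`; `expensive_sum` and the `calls` counter are hypothetical helpers that make the number of evaluations visible.

```python
calls = {"sum": 0}


def expensive_sum(x):
    """Stand-in for a costly loop-invariant computation; counts its calls."""
    calls["sum"] += 1
    return sum(x)


def naive(x, not_empty):
    val = 0
    if not_empty:
        while True:  # do ... while val < 10
            val = (val + expensive_sum(x)) ** 2
            if not val < 10:
                break
    return val


def hoisted(x, not_empty):
    val = 0
    if not_empty:
        # Hoisted inside the guard: runs once, and only if the loop body runs.
        x_sum = expensive_sum(x)
        while True:  # do ... while val < 10
            val = (val + x_sum) ** 2
            if not val < 10:
                break
    return val
```

With `x = [1]` both versions go through three iterations, but `naive` evaluates the sum three times and `hoisted` once; with `not_empty=False` neither evaluates it at all.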
> What do we want to do in terms of inplacing optimizations?
Well, I guess we really need those :-) I'm thinking it might be worth it to copy the initial state, and then donate the state to the inner function? And I guess we need to make sure rewrites are actually running on inner graphs as well...
> we could move x.sum() out of the loop. But with a while loop we can't as easily, because we only want to do x.sum() if the loop is not empty, and where would we then put that computation?
Why can't we move it even if it's empty? Sum works fine. Are you worried about Ops that we know will fail with empty inputs?
About the filling Ops, yeah I don't see it as a problem anymore. Just felt awkward to create the dummy input when translating from scan to loop. I am okay with it now
That would change the behavior. If we move it out and don't prevent it from being executed, things could fail, for instance if there's an assert somewhere, or some other error happens during its evaluation. Also, it could potentially be very costly (let's say "solve an ODE").
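A small Python sketch of that behavior change, under the assumption that the hoisted computation can raise. `checked_mean` is a hypothetical stand-in for an Op with an assert inside it.

```python
def checked_mean(x):
    """Loop-invariant computation that fails on empty input."""
    if len(x) == 0:
        raise ValueError("mean of empty input")
    return sum(x) / len(x)


def original(x, n_steps):
    val = 0.0
    for _ in range(n_steps):  # may run zero times
        val += checked_mean(x)
    return val


def hoisted_unguarded(x, n_steps):
    m = checked_mean(x)  # evaluated even when n_steps == 0
    val = 0.0
    for _ in range(n_steps):
        val += m
    return val


def hoisted_guarded(x, n_steps):
    val = 0.0
    if n_steps > 0:  # only evaluate when the loop is not empty
        m = checked_mean(x)
        for _ in range(n_steps):
            val += m
    return val
```

With an empty `x` and `n_steps == 0`, `original` and `hoisted_guarded` return `0.0`, while `hoisted_unguarded` raises.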
(somehow I accidentally edited your comment instead of writing a new one, no clue how, but fixed now)
In my last commit, sequences are demoted from special citizens to just another constant input in the ScanOp. The user facing helper creates the right graph with indexing that is passed to the user provided function.
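Roughly, the idea is that the helper adds an index to the state and indexes the sequences on behalf of the user function. A plain-Python sketch follows; `scan_over_sequences` is a hypothetical name and this is not the actual helper. It assumes at least one sequence is passed and that the user function returns a tuple of new states.

```python
def scan_over_sequences(fn, init_states, sequences):
    """Call fn(*sequence_elements, *states) once per step, tracing every state."""
    n_steps = min(len(seq) for seq in sequences)  # assumes at least one sequence

    def update(idx, *states):
        # Sequences are just constant inputs; indexing happens here,
        # not in the looping machinery itself.
        elements = [seq[idx] for seq in sequences]
        new_states = fn(*elements, *states)
        return (idx + 1, *new_states)

    state = (0, *init_states)
    traces = [[] for _ in init_states]
    for _ in range(n_steps):
        state = update(*state)
        for trace, s in zip(traces, state[1:]):
            trace.append(s)
    return state[1:], traces


def cumsum_step(x, acc):
    return (acc + x,)


final, traces = scan_over_sequences(cumsum_step, (0,), [[1, 2, 3, 4]])
```

The user function never sees the index or the full sequence, only the current elements and the current states.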
I have reverted converting the constant inputs to dummies before calling the user function, which allows the example in the jacobian documentation to work, including the one that didn't work before (because both are now equivalent under the hood :))
https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html#computing-the-jacobian
I reverted too much, and I still need to pass dummy inputs as the state variables, since it doesn't make sense for the user function to introspect the graph beyond the initial state (since it's only valid for the initial state)
Added a simple JAX dispatcher, works in the few examples I tried
I just found out about TypedLists in PyTensor. That should allow us to trace any type of Variables, including RandomTypes :exploding_head:
Pushed a couple of commits that rely on this.
Codecov Report
Merging #191 (5bc7070) into main (958cd14) will increase coverage by 0.06%. The diff coverage is 89.11%.
Additional details and impacted files
@@ Coverage Diff @@
## main #191 +/- ##
==========================================
+ Coverage 80.03% 80.09% +0.06%
==========================================
Files 170 173 +3
Lines 45086 45435 +349
Branches 9603 9694 +91
==========================================
+ Hits 36085 36392 +307
- Misses 6789 6818 +29
- Partials 2212 2225 +13
| Impacted Files | Coverage Δ | |
|---|---|---|
| pytensor/compile/mode.py | 84.47% <ø> (ø) | |
| pytensor/loop/basic.py | 81.44% <81.44%> (ø) | |
| pytensor/loop/op.py | 90.29% <90.29%> (ø) | |
| pytensor/link/jax/dispatch/__init__.py | 100.00% <100.00%> (ø) | |
| pytensor/link/jax/dispatch/loop.py | 100.00% <100.00%> (ø) | |
| pytensor/link/utils.py | 60.30% <100.00%> (+0.12%) | :arrow_up: |
| pytensor/typed_list/basic.py | 89.27% <100.00%> (+0.38%) | :arrow_up: |
| pytensor/link/jax/dispatch/extra_ops.py | 74.62% <0.00%> (-20.90%) | :arrow_down: |
| pytensor/link/jax/dispatch/shape.py | 80.76% <0.00%> (-7.70%) | :arrow_down: |
| pytensor/link/jax/dispatch/basic.py | 79.03% <0.00%> (-4.84%) | :arrow_down: |
| ... and 11 more | | |
This Discourse thread is a great reminder of several Scan design issues that are fixed here: https://discourse.pymc.io/t/hitting-a-weird-error-to-do-with-rngs-in-scan-in-a-custom-function-inside-a-potential/13151/15
Namely:
- Going to the root to find missing non-sequences (instead of using `truncated_graph_inputs`)
- Gradient only works by indexing non-sequences
- Scans are very difficult to manipulate!!!