
Implement new Loop and Scan operators

Open ricardoV94 opened this issue 2 years ago • 8 comments

Related to #189

This PR implements a new low-level Loop Op that can easily be transpiled to Numba (the Python perform method takes 9 lines; yay for not having to support C in the future).

It also implements a new higher-level Scan Op, which returns both the final states and the intermediate states of a looping operation as outputs. This Op cannot be evaluated directly and must be rewritten as a Loop Op for the Python/Numba backends. For the JAX backend it's probably fine to transpile directly from this representation into a lax.scan, as the signatures are pretty much identical. That was not done in this PR.
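For comparison (not part of this PR), `jax.lax.scan` already has essentially that signature: the step function returns a carry plus a per-step value, and the call returns the final carry together with the stacked per-step values. A minimal, self-contained sketch:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next state, value collected into the trace)

final_carry, trace = jax.lax.scan(step, init=0.0, xs=jnp.arange(5.0))
# final_carry == 10.0; trace == [0., 1., 3., 6., 10.]
```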

The reason for the two types of outputs is that they are useful in different contexts. Final states are sometimes all one needs, whereas intermediate states are generally needed for backpropagation (not implemented yet). This lets us choose which of the outputs (or both) we want during compilation, without having to do complicated graph analysis.

The existing save_mem_new_scan is used to convert a general scan into a loop that only returns the last computed state. It's... pretty complicated (although it also covers cases where more than one but fewer than all steps are requested; OTOH it can't handle while loops #178):

https://github.com/pymc-devs/pytensor/blob/8ad33179b12a4c0207b2a654badc608e211e8bb9/pytensor/scan/rewriting.py#L1119

Taking that as a reference, I would say the new conversion rewrite from Scan to Loop is much, much simpler. Most of it is boilerplate code for defining the right trace inputs and the new FunctionGraph.


Both Ops expect a FunctionGraph as input. This should probably be created by a user-facing helper that accepts a callable like scan does now. ~~That was not done yet, as I first wanted to discuss the general design.~~ Done
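For readers unfamiliar with `FunctionGraph`: the inner graph the Ops consume is just an ordinary PyTensor graph wrapped in a `FunctionGraph`. A minimal sketch of building one for a cumulative-sum step (variable names are illustrative; the exact inner-graph conventions, e.g. where the while-condition goes, are defined by the Ops in this PR and not shown here):

```python
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph

# one step of a cumulative sum: state `acc`, per-step input `x`
acc = pt.scalar("acc")
x = pt.scalar("x")
new_acc = acc + x

inner_fgraph = FunctionGraph(inputs=[acc, x], outputs=[new_acc], clone=True)
```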

Design issues

~1. The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.~

  1. Scan and Loop can now take random types as inputs (scan can't return them as a sequence). This makes random seeding much more explicit than in the old Scan, which relied on default updates of shared variables. However, it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to __call__, so that it doesn't hide the next rng state output? (See the sketch after this list.)

  2. Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?

  3. What do we want to do in terms of inplacing optimizations?
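On point 1 above, the awkwardness is that `RandomVariable.__call__` only returns the draw, so today the updated RNG has to be dug out of the node's outputs. A small sketch of the current workaround (the seed value is arbitrary):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(123))
x = pt.random.normal(rng=rng)      # __call__ hands back only the draw...
next_rng = x.owner.outputs[0]      # ...the next RNG state lives in the node's first output
```

A `return_rng_update` flag would make that second line unnecessary.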

TODO

If people are on board with the approach

  • [ ] Implement Numba dispatch
  • [x] Implement JAX dispatch
  • [ ] Implement L_op and R_op
  • [x] Implement friendly user facing functions
  • [ ] Decide on which meta-parameters to preserve (mode, truncate_gradient, reverse and so on)
  • [x] Add rewrite that replaces trace[-1] by the first set of outputs (final state). That way we can keep the old API, while retaining the benefit of doing while Scans without tracing when it's not needed.

ricardoV94 avatar Jan 10 '23 14:01 ricardoV94

> The current implementation of Loop assumes there are as many states as outputs of the inner function. This does not make sense for mapping or "filling" operations such as filling a tensor with random values. In one of the tests I had to create a dummy x input to accommodate this restriction. Should we use NoneConst to represent outputs that don't feed into the next state? I think there is something similar being done with the old Scan where the outputs_info must explicitly be None in these cases.

Wouldn't a fill loop look something like this?

```python
import pytensor.tensor as pt

# rough sketch; shape, dtype, maxval, and rng are assumed to be defined elsewhere
state = (pt.constant(0), pt.empty(shape, dtype), rng)

def update(idx, values, rng):
    value, rng = rng.normal()  # not exactly the api...
    values = pt.set_subtensor(values[idx], value)
    return (idx + 1, values, rng, idx < maxval)  # last entry is the while condition
```

(and we'd very much need inplace rewrites for good performance...)

> Scan and Loop can now take random types as inputs (scan can't return them as a sequence). This makes random seeding much more explicit than in the old Scan, which relied on default updates of shared variables. However, it highlights the awkwardness of the random API when we want to access the next random state. Should we perhaps add a return_rng_update to __call__, so that it doesn't hide the next rng state output?

Good question... Don't know either :-)

> Do we want to be able to represent empty Loop / Sequences? If so, how should we go about that? IfElse is one option, but perhaps it would be nice to represent it in the same Loop Op?

I think one rewrite that gets easier with the if-else-do-while approach is loop-invariant code motion. Let's say we have a loop like this:

```
# pseudocode
x = bigarray...
if not_empty:
    val = 0
    do:
        val = (val + x.sum()) ** 2
    while val < 10
```

```
# rewrite to: hoist the loop-invariant x.sum() out of the do-while
x = bigarray...
if not_empty:
    val = 0
    x_sum = x.sum()
    do:
        val = (val + x_sum) ** 2
    while val < 10
```

we could move x.sum() out of the loop. But with a while loop we can't do that as easily, because we only want to compute x.sum() if the loop is not empty, and where would we then put that computation?

> What do we want to do in terms of inplacing optimizations?

Well, I guess we really need those :-) I'm thinking it might be worth it to copy the initial state, and then donate the state to the inner function? And I guess we need to make sure rewrites are actually running on inner graphs as well...
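A rough NumPy-level sketch of the "copy the initial state once, then let the inner function reuse the buffer" idea (purely illustrative, not the PR's implementation; `inner_fn` is a placeholder):

```python
import numpy as np

def run_loop(init_state, inner_fn, n_steps):
    state = np.copy(init_state)   # one defensive copy so the caller's buffer is untouched
    for _ in range(n_steps):
        state = inner_fn(state)   # inner_fn may overwrite `state` in place ("donated" buffer)
    return state
```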

aseyboldt avatar Jan 11 '23 22:01 aseyboldt

> we could move x.sum() out of the loop. But with a while loop we can't do that as easily, because we only want to compute x.sum() if the loop is not empty, and where would we then put that computation?

Why can't we move it even if it's empty? Sum works fine. Are you worried about Ops that we know will fail with empty inputs?

About the filling Ops: yeah, I don't see that as a problem anymore. It just felt awkward to create the dummy input when translating from Scan to Loop. I am okay with it now.

ricardoV94 avatar Jan 12 '23 07:01 ricardoV94

That would change the behavior. If we move it out and don't prevent it from being executed, things could fail, for instance if there's an assert somewhere or some other error happens during its evaluation. Also, it could potentially be very costly (let's say "solve an ODE").

(somehow I accidentally edited your comment instead of writing a new one, no clue how, but fixed now)

aseyboldt avatar Jan 12 '23 16:01 aseyboldt

In my last commit, sequences are demoted from special citizens to just another constant input of the Scan Op. The user-facing helper builds the right indexing graph and passes it to the user-provided function.
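A small sketch of what that indexing looks like (names are illustrative): the full sequence enters the inner graph as a regular input, and the element selected by the loop counter is what the user function actually sees.

```python
import pytensor.tensor as pt

xs = pt.vector("xs")                    # the whole sequence, now just another constant input
idx = pt.scalar("idx", dtype="int64")   # the loop-counter state
x_t = xs[idx]                           # the per-step element handed to the user function
```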

I have reverted converting the constant inputs to dummies before calling the user function, which allows the examples in the Jacobian documentation to work, including the one that didn't work before (because both are now equivalent under the hood :))

https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html#computing-the-jacobian
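For context, the Jacobian recipe on that page is (roughly) the following scan over output indices, which is the pattern that now works because constant inputs are no longer swapped out for dummies:

```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = x ** 2
J, updates = pytensor.scan(
    lambda i, y, x: pytensor.grad(y[i], x),
    sequences=pt.arange(y.shape[0]),
    non_sequences=[y, x],
)
f = pytensor.function([x], J, updates=updates)
f([4, 4])  # roughly [[8., 0.], [0., 8.]]
```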

I reverted too much, though: I still need to pass dummy inputs for the state variables, since it doesn't make sense for the user function to introspect the graph beyond the initial state (it's only valid for the initial state).

ricardoV94 avatar Jan 12 '23 16:01 ricardoV94

Added a simple JAX dispatcher; it works in the few examples I tried.

ricardoV94 avatar Jan 13 '23 16:01 ricardoV94

I just found out about TypedLists in PyTensor. Those should allow us to trace any type of Variable, including RandomTypes :exploding_head:
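A rough sketch of the typed-list API (inherited from Theano; the exact import locations and op names are assumptions here, so treat this as illustrative only):

```python
import pytensor.tensor as pt
from pytensor.typed_list import TypedListType, append  # assumed import path

# a symbolic list whose elements are float64 vectors
vector_list = TypedListType(pt.TensorType("float64", shape=(None,)))()
longer_list = append(vector_list, pt.vector("x"))  # symbolically append an element
```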

Pushed a couple of commits that rely on this.

ricardoV94 avatar Jan 16 '23 10:01 ricardoV94

Codecov Report

Merging #191 (5bc7070) into main (958cd14) will increase coverage by 0.06%. The diff coverage is 89.11%.

Additional details and impacted files


```
@@            Coverage Diff             @@
##             main     #191      +/-   ##
==========================================
+ Coverage   80.03%   80.09%   +0.06%
==========================================
  Files         170      173       +3
  Lines       45086    45435     +349
  Branches     9603     9694      +91
==========================================
+ Hits        36085    36392     +307
- Misses       6789     6818      +29
- Partials     2212     2225      +13
```
| Impacted Files | Coverage Δ |
| --- | --- |
| `pytensor/compile/mode.py` | 84.47% <ø> (ø) |
| `pytensor/loop/basic.py` | 81.44% <81.44%> (ø) |
| `pytensor/loop/op.py` | 90.29% <90.29%> (ø) |
| `pytensor/link/jax/dispatch/__init__.py` | 100.00% <100.00%> (ø) |
| `pytensor/link/jax/dispatch/loop.py` | 100.00% <100.00%> (ø) |
| `pytensor/link/utils.py` | 60.30% <100.00%> (+0.12%) :arrow_up: |
| `pytensor/typed_list/basic.py` | 89.27% <100.00%> (+0.38%) :arrow_up: |
| `pytensor/link/jax/dispatch/extra_ops.py` | 74.62% <0.00%> (-20.90%) :arrow_down: |
| `pytensor/link/jax/dispatch/shape.py` | 80.76% <0.00%> (-7.70%) :arrow_down: |
| `pytensor/link/jax/dispatch/basic.py` | 79.03% <0.00%> (-4.84%) :arrow_down: |
| ... and 11 more | |

codecov-commenter avatar Jan 20 '23 16:01 codecov-commenter

This Discourse thread is a great reminder of several Scan design issues that are fixed here: https://discourse.pymc.io/t/hitting-a-weird-error-to-do-with-rngs-in-scan-in-a-custom-function-inside-a-potential/13151/15

Namely:

  • Going to the root to find missing non-sequences (instead of using truncated_graph_inputs; see the sketch after this list)
  • Gradient only works by indexing non-sequences
  • Scans are very difficult to manipulate!!!
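
For reference, a minimal sketch of the `truncated_graph_inputs` alternative mentioned above (the `pytensor.graph.basic` location and the `ancestors_to_include` argument are assumptions here):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import truncated_graph_inputs  # assumed location

x = pt.vector("x")
intermediate = pt.exp(x)
out = intermediate.sum()

# stop the input search at `intermediate` instead of walking back to the root `x`
inputs = truncated_graph_inputs([out], ancestors_to_include=[intermediate])
```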

ricardoV94 avatar Oct 23 '23 10:10 ricardoV94