dynamax
Fix `run_sgd` to extract iterator from compiled scan
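The `_sample_minibatches` function produces a generator that mutates internal state under the hood. This doesn't work inside `lax.scan`: scan traces its body function once, so `next()` is called on the generator only during tracing, and the minibatch it returns is baked into the compiled loop. Instead of raising an error, it silently returns the first value of the iterator on every iteration.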
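As a rough illustration of this failure mode (not the linked repro, and with a hypothetical `make_batches` generator standing in for `_sample_minibatches`), consuming a generator inside a `lax.scan` body yields the same batch on every step, while a plain Python loop advances it as expected:

```python
import jax
import jax.numpy as jnp


def make_batches():
    # Hypothetical stand-in for _sample_minibatches: a Python generator
    # that yields a different "minibatch" each time next() is called.
    for i in range(100):
        yield jnp.full((2,), float(i))


def run_with_scan(num_steps):
    batches = make_batches()

    def step(carry, _):
        # next() only runs while lax.scan traces `step`, so the batch it
        # returns becomes a constant inside the compiled loop.
        batch = next(batches)
        return carry, batch

    _, seen = jax.lax.scan(step, 0.0, None, length=num_steps)
    return seen


def run_with_python_loop(num_steps):
    batches = make_batches()
    # Each Python iteration calls next() again, so the generator advances.
    return jnp.stack([next(batches) for _ in range(num_steps)])


print(run_with_scan(3)[:, 0])         # the same batch repeated every step
print(run_with_python_loop(3)[:, 0])  # a new batch on every step
```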
Here's a minimal script to repro the failure: https://gist.github.com/slinderman/65a1d55697972d5d766a88425428f633.
I think a reasonable solution is to jit only the body of the innermost for loop, and implement the for loops themselves in Python.
Pros:
- lets us use general data loaders, like those from PyTorch. @ezhang94 used this in her stochastic EM implementation (https://github.com/probml/ssm-jax/pull/124)
Cons:
- Can't vmap the `run_sgd` function anymore...
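A minimal sketch of the proposed structure, assuming a toy least-squares objective and plain gradient descent rather than dynamax's actual model or optimizer; `loss_fn`, `sgd_step`, `sample_minibatches`, and `run_sgd_python_loop` are hypothetical names for illustration. Only the single update step is compiled, and the epoch and minibatch loops stay in Python, so any iterable (including a PyTorch-style data loader) could feed batches:

```python
import numpy as np
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Hypothetical least-squares loss standing in for the model's objective.
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)


@jax.jit
def sgd_step(params, batch, learning_rate):
    # Only this inner step is compiled; it is a pure function of its inputs.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return params - learning_rate * grads, loss


def sample_minibatches(x, y, batch_size, rng):
    # Ordinary Python generator: safe here because it is consumed by a
    # Python loop rather than traced inside lax.scan.
    perm = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]


def run_sgd_python_loop(params, x, y, num_epochs=50, batch_size=16,
                        learning_rate=0.1, seed=0):
    rng = np.random.default_rng(seed)
    losses = []
    for _ in range(num_epochs):  # Python loop over epochs
        for batch in sample_minibatches(x, y, batch_size, rng):  # Python loop over batches
            params, loss = sgd_step(params, batch, learning_rate)
            losses.append(float(loss))
    return params, np.array(losses)


# Usage on synthetic, noiseless data.
data_rng = np.random.default_rng(1)
x = data_rng.normal(size=(128, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = x @ true_w
params, losses = run_sgd_python_loop(jnp.zeros(3), x, y)
print(params)
```

The trade-off matches the pros and cons above: batches come from arbitrary Python iterators, but every step is a separate dispatch to the compiled `sgd_step`, and the outer loop (with its `float(loss)` conversions) only runs on concrete values, so it cannot be vmapped as a whole.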