dynamax

Fix `run_sgd` to extract iterator from compiled scan

Open slinderman opened this issue 3 years ago • 0 comments

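The `_sample_minibatches` function produces a generator that, under the hood, mutates internal state. This doesn't play nicely within `lax.scan`: the scan body is traced only once, so the generator is advanced only during tracing. Instead of raising an error, it silently fails and returns the first value of the iterator on every iteration.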

Here's a minimal script to repro the failure: https://gist.github.com/slinderman/65a1d55697972d5d766a88425428f633.
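Separately from the gist, here is a small self-contained sketch of the same failure mode. The generator `sample_minibatches` below is a hypothetical stand-in, not the actual `_sample_minibatches` code, and the toy data and `step` function are assumptions made only for illustration.

```python
import jax.numpy as jnp
from jax import lax

def sample_minibatches(data, batch_size):
    """Hypothetical stand-in: a stateful Python generator over minibatches."""
    for start in range(0, data.shape[0], batch_size):
        yield data[start:start + batch_size]

data = jnp.arange(6.0)  # minibatches: [0, 1], [2, 3], [4, 5]

# Plain Python loop: the generator advances on every iteration.
python_sums = [float(batch.sum()) for batch in sample_minibatches(data, 2)]
# python_sums == [1.0, 5.0, 9.0]

# lax.scan: `step` is traced once, so `next` runs only during tracing
# and the first minibatch is baked into the compiled loop body.
batches = sample_minibatches(data, 2)

def step(carry, _):
    batch = next(batches)  # Python side effect, invisible to the compiled scan
    return carry, batch.sum()

_, scanned_sums = lax.scan(step, jnp.zeros(()), None, length=3)
# Every entry of scanned_sums is the sum of the first minibatch (1.0).
```

No error is raised in the scanned version, which is what makes the bug easy to miss.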

I think a reasonable solution is to `jit` only the body of the innermost for loop, and have the for loops themselves implemented in Python.
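A rough sketch of what that could look like, assuming a plain gradient-descent update. The names `loss_fn`, `sgd_step`, `run_sgd_python_loop`, and `data_loader_fn` are hypothetical and do not reflect the actual `run_sgd` signature; the quadratic loss is a placeholder.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Hypothetical objective: mean squared error between a scalar and the batch.
    return jnp.mean((batch - params) ** 2)

@jax.jit
def sgd_step(params, batch, learning_rate):
    # Only this innermost update is compiled.
    loss, grad = jax.value_and_grad(loss_fn)(params, batch)
    return params - learning_rate * grad, loss

def run_sgd_python_loop(params, data_loader_fn, num_epochs, learning_rate=0.1):
    # The epoch and minibatch loops stay in Python, so any iterable of
    # batches works, including stateful generators.
    losses = []
    for _ in range(num_epochs):
        for batch in data_loader_fn():
            params, loss = sgd_step(params, batch, learning_rate)
            losses.append(loss)
    return params, jnp.stack(losses)

data = jnp.arange(6.0)
loader = lambda: (data[i:i + 2] for i in range(0, 6, 2))  # fresh generator per epoch
params, losses = run_sgd_python_loop(jnp.zeros(()), loader, num_epochs=20)
```

Here the generator is advanced by the Python interpreter on every iteration, and `sgd_step` is compiled once and reused as long as the batch shapes stay the same.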

Pros:

  • Lets us use general data loaders, like those from PyTorch. @ezhang94 used this in her stochastic EM implementation (https://github.com/probml/ssm-jax/pull/124)

Cons:

  • Can't `vmap` the `run_sgd` function anymore...

slinderman · Aug 07 '22 00:08