functorch
support scan
It would be really nice to be able to, e.g., take models implemented in JAX with jax.lax.scan
and port them over to torch without having to unroll the scans over modules.
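For reference, jax.lax.scan(f, init, xs) threads a carry through f. It applies f to each leading-axis slice of xs and stacks the per-step outputs, where f has the shape (carry, x) -> (new_carry, y). Below is a minimal sketch of those semantics in plain PyTorch. It assumes xs is a single non-empty tensor rather than the arbitrary pytrees JAX accepts. `torch_scan` is a hypothetical helper, not a torch or functorch API. It still runs as a Python loop, so it is exactly the unrolled form the request wants to avoid. It only illustrates the behavior a native scan would have to match.

```python
import torch

def torch_scan(f, init, xs):
    """Hypothetical loop-based analogue of jax.lax.scan(f, init, xs).

    f(carry, x) -> (new_carry, y); xs is scanned along dim 0.
    Returns (final_carry, ys) with ys stacked along a new leading dim.
    Assumes xs is a single non-empty tensor (no pytrees, no `length`/`reverse`).
    """
    carry = init
    ys = []
    for x in xs:  # iterating a tensor yields slices along dim 0
        carry, y = f(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)

# Usage: a cumulative sum expressed as a scan.
def cumsum_step(carry, x):
    new = carry + x
    return new, new  # carry and per-step output are the same here

final, ys = torch_scan(cumsum_step, torch.tensor(0.0), torch.arange(1.0, 5.0))
# final -> tensor(10.), ys -> tensor([1., 3., 6., 10.])
```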
Thanks for the feature request, @GallagherCommaJack. This is definitely on our radar. Out of curiosity (and to serve as test cases), do you have examples of models in JAX that use scan that you want to port over to torch?
I'm not the author, but one very useful test case is a sequential model, such as a custom RNN. For example, this page implements a recurrent model using scan.
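For a sense of what such a test case might look like, here is a minimal sketch of an Elman-style RNN written as a scan step function. The carry is the hidden state, and each step also emits that hidden state as its output. The parameter names (W_x, W_h, b) and sizes are illustrative assumptions, not taken from the linked page. Without a native scan, the step still has to be driven by the explicit Python loop shown here.

```python
import torch

torch.manual_seed(0)
input_size, hidden_size, seq_len = 3, 4, 5

# Illustrative (hypothetical) parameters for a minimal Elman-style RNN cell.
W_x = torch.randn(hidden_size, input_size)
W_h = torch.randn(hidden_size, hidden_size)
b = torch.zeros(hidden_size)

def rnn_step(h, x):
    """One scan step: (carry, x) -> (new_carry, y), as jax.lax.scan expects."""
    h_new = torch.tanh(W_x @ x + W_h @ h + b)
    return h_new, h_new

xs = torch.randn(seq_len, input_size)
h = torch.zeros(hidden_size)
outputs = []
for x in xs:  # the unrolled loop a native scan would replace
    h, y = rnn_step(h, x)
    outputs.append(y)
hs = torch.stack(outputs)  # shape (seq_len, hidden_size)
```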
+1 for this. I find it extremely useful for implementing rollouts in RL and RNNs.