
support scan

Open · GallagherCommaJack opened this issue 2 years ago • 3 comments

It would be really nice to be able to, e.g., take models implemented in JAX with `jax.lax.scan` and port them over to torch without having to unroll scans over modules.

GallagherCommaJack · Sep 30 '22
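For context, `jax.lax.scan(f, init, xs)` threads a carry through a step function `f(carry, x) -> (carry, y)` over the leading axis of `xs` and returns the final carry together with the stacked per-step outputs. The sketch below is a minimal pure-Python rendering of those semantics. The `scan` helper and `step` function are hypothetical names for illustration, and unlike JAX this version returns `ys` as a plain list instead of stacking them into an array. It is essentially the unrolled loop one has to write by hand in torch today.

```python
from typing import Callable, Iterable, List, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def scan(f: Callable[[C, X], Tuple[C, Y]], init: C, xs: Iterable[X]) -> Tuple[C, List[Y]]:
    """Hypothetical reference scan: thread `carry` through `f`, collecting each step's output."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = f(carry, x)  # new carry feeds the next step
        ys.append(y)            # per-step output is collected
    return carry, ys


def step(total: int, x: int) -> Tuple[int, int]:
    """Running-sum step: the carry is the total so far, the output is the new prefix sum."""
    total = total + x
    return total, total


final, prefix_sums = scan(step, 0, [1, 2, 3, 4])
print(final, prefix_sums)  # 10 [1, 3, 6, 10]
```

A native scan op would let a framework trace `step` once instead of unrolling the Python loop, which is the point of the feature request.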

Thanks for the feature request, @GallagherCommaJack. This is definitely on our radar. Out of curiosity (and to serve as test cases), do you have examples of models in JAX that use scan that you want to port over to torch?

zou3519 · Oct 03 '22

I'm not the issue author, but one very useful test case is a sequential model such as a custom RNN. For example, this page implements a recurrent model using scan.

hbenazha · Jan 05 '23
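To make the RNN test case concrete, a recurrent cell maps naturally onto a scan body: the hidden state is the carry and each time step's input is `x`. The sketch below is a toy scalar Elman-style RNN in pure Python; the weights `W_H`, `W_X`, `B`, `W_Y` and the functions `rnn_cell` and `run_rnn` are hypothetical and stand in for tensor-valued modules in a real model.

```python
import math
from typing import List, Tuple

# Hypothetical scalar weights for a toy RNN; a real model would use tensors/modules.
W_H, W_X, B, W_Y = 0.5, 1.0, 0.0, 2.0


def rnn_cell(h: float, x: float) -> Tuple[float, float]:
    """One recurrent step in scan form: (carry, input) -> (new carry, output)."""
    h_new = math.tanh(W_H * h + W_X * x + B)  # update the hidden state
    return h_new, W_Y * h_new                 # emit a per-step output


def run_rnn(h0: float, xs: List[float]) -> Tuple[float, List[float]]:
    """Hand-unrolled driver loop; with a scan primitive this would be a single scan call over rnn_cell."""
    h, ys = h0, []
    for x in xs:
        h, y = rnn_cell(h, x)
        ys.append(y)
    return h, ys


h_final, outputs = run_rnn(0.0, [1.0, 0.0, -1.0])
print(len(outputs), h_final)
```

Only `rnn_cell` is model-specific; the driver loop is identical for any recurrence, which is why a generic scan is attractive for RNNs and RL rollouts alike.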

+1 for this. I find it extremely useful for implementing rollouts in RL and RNNs.

subho406 · Feb 13 '23