
Better Handling of Nested Loop with Remat

Open LeoXinhaoLee opened this issue 1 year ago • 0 comments

Hi, the forward pass of our training step runs through a two-level nested loop (an outer loop and an inner loop).

The input data X, of shape [T, F], is first reshaped to [outer_loop_num, inner_loop_num, F]. The outer loop then iterates over the first dimension and the inner loop over the second.
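For reference, here is a minimal sketch of this structure using nested `jax.lax.scan`. It is not the code from the Colab. `inner_step` is a hypothetical per-step update, and the toy sizes `OUTER`, `INNER`, `F` stand in for outer_loop_num, inner_loop_num, and F, with T assumed to equal OUTER * INNER:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes; T must equal OUTER * INNER for the reshape to work.
OUTER, INNER, F = 4, 8, 16
T = OUTER * INNER

def inner_step(carry, x_t):
    # Hypothetical per-step update: carry and x_t both have shape [F].
    new_carry = jnp.tanh(carry + x_t)
    return new_carry, new_carry  # (next carry, per-step output)

def inner_loop(carry, x_chunk):
    # x_chunk has shape [INNER, F]; this scan runs over the second dimension of the reshaped X.
    carry, ys = jax.lax.scan(inner_step, carry, x_chunk)
    return carry, ys

def forward(X, init):
    # [T, F] -> [OUTER, INNER, F]
    X = X.reshape(OUTER, INNER, F)
    # The outer scan runs over the first dimension; each iteration executes one full inner loop.
    carry, ys = jax.lax.scan(inner_loop, init, X)
    return carry, ys  # ys has shape [OUTER, INNER, F]

X = jnp.linspace(-1.0, 1.0, T * F).reshape(T, F)
final, ys = forward(X, jnp.zeros(F))
print(ys.shape)  # (4, 8, 16)
```

Because the carry is threaded from one inner loop into the next, this computes the same result as a single flat scan over all T steps. The reshape only groups the steps into chunks.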

Our goal is to apply jax.checkpoint to each inner loop so that none of the intermediate values computed inside it are saved for the backward pass.
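One common way to express this goal is sketched below. It is not necessarily either of the approaches in the Colab. The idea is to wrap the whole inner loop (the body of the outer scan) in `jax.checkpoint`. In the backward pass, the outer scan then keeps only each inner loop's inputs (the carry at the chunk boundary and the chunk of X) and recomputes the inner loop's intermediates. `inner_step`, the loss, and the sizes are hypothetical placeholders:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes; T = OUTER * INNER.
OUTER, INNER, F = 4, 8, 16
T = OUTER * INNER

def inner_step(carry, x_t):
    # Hypothetical per-step update; per-step outputs are not needed here.
    return jnp.tanh(carry + x_t), None

@jax.checkpoint  # do not store residuals of this inner loop; recompute them during backprop
def inner_loop(carry, x_chunk):
    # x_chunk: [INNER, F]
    carry, _ = jax.lax.scan(inner_step, carry, x_chunk)
    return carry, None

def loss(X, init):
    X = X.reshape(OUTER, INNER, F)
    # Outer scan over chunks; only its per-iteration inputs are saved for the backward pass.
    carry, _ = jax.lax.scan(inner_loop, init, X)
    return jnp.sum(carry ** 2)  # hypothetical scalar loss

grad_fn = jax.jit(jax.grad(loss))  # gradient with respect to X

X = jnp.linspace(-1.0, 1.0, T * F).reshape(T, F)
init = jnp.zeros(F)
g = grad_fn(X, init)  # same shape as X: [T, F]
```

Checkpointing changes only what is stored versus recomputed, not the gradient values. The `jax.checkpoint` documentation also notes that its default `prevent_cse=True` is unnecessary when the checkpointed function is used inside a `scan`, so `jax.checkpoint(inner_loop, prevent_cse=False)` may be worth trying for performance.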

So far we have found two ways to do this, but each has potential problems. Please see this example Colab for details: https://colab.research.google.com/drive/1TLLQbzIdSX1SYSujmGmdrPO28aa3ZyFb#scrollTo=m3aI--XBMINi

Thank you very much for your time and help!

LeoXinhaoLee, Apr 23 '24 00:04