Better Handling of Nested Loop with Remat
Hi, the forward pass of our training runs through a two-level nested loop (an outer loop and an inner loop).
The input data X of shape [T, F] is first reshaped into [outer_loop_num, inner_loop_num, F], so that T = outer_loop_num * inner_loop_num. The outer loop then iterates over the first dimension and the inner loop over the second.
Our goal is to apply jax.checkpoint to each inner loop, so that the intermediate values generated inside it are not saved for the backward pass and are recomputed instead.
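To make the setup concrete, here is a minimal sketch of the structure described above, written with jax.lax.scan for both loops. The sizes, the tanh step function, the weight matrix w, and the names inner_step, inner_loop, and forward are all hypothetical placeholders for illustration; this is not the code from our training run.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes, chosen only for illustration.
OUTER_LOOP_NUM, INNER_LOOP_NUM, F = 3, 4, 5


def inner_step(carry, x, w):
    # Placeholder per-step computation; the real model would go here.
    return jnp.tanh(carry @ w + x), None


def inner_loop(carry, x_block, w):
    # x_block has shape [INNER_LOOP_NUM, F]; scan iterates over its first axis.
    carry, _ = lax.scan(lambda c, x: inner_step(c, x, w), carry, x_block)
    return carry


def forward(w, X, *, remat=True):
    # X has shape [T, F] with T = OUTER_LOOP_NUM * INNER_LOOP_NUM.
    xs = X.reshape(OUTER_LOOP_NUM, INNER_LOOP_NUM, F)
    # Wrap the whole inner loop so its intermediates are rematerialized
    # during the backward pass instead of being stored.
    body = jax.checkpoint(inner_loop) if remat else inner_loop

    def outer_step(carry, x_block):
        return body(carry, x_block, w), None

    carry, _ = lax.scan(outer_step, jnp.zeros((F,)), xs)
    return jnp.sum(carry ** 2)


w = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (F, F))
X = jax.random.normal(jax.random.PRNGKey(1), (OUTER_LOOP_NUM * INNER_LOOP_NUM, F))
grad_remat = jax.grad(forward)(w, X, remat=True)
grad_plain = jax.grad(forward)(w, X, remat=False)
```

In this sketch the two gradients should agree up to floating-point noise; the intended difference is only in what is kept in memory between the forward and backward passes.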
So far we have identified two ways to do this, each with potential problems. Please see this example Colab: https://colab.research.google.com/drive/1TLLQbzIdSX1SYSujmGmdrPO28aa3ZyFb#scrollTo=m3aI--XBMINi
Thank you very much for your time and help!