[WIP][PERF] offload rng computation in remat

Open · ZYHowell opened this issue 2 years ago · 0 comments

In `jax.remat`, constant values and random numbers are generated in the forward pass and kept in memory until the backward pass. See this example. To reduce memory consumption, we rematerialize this part as well. This PR:

  • Move `offload_remat` to `util.py`, making it an essential step after tracing the jaxpr in `process_remat`.
  • Store the RNG result and reuse it as an input to the offloaded part.
  • Add test cases for remat with rng seed.
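For readers unfamiliar with the underlying issue, here is a minimal sketch of how random numbers interact with plain `jax.checkpoint` (the same transform as `jax.remat`). It uses a hypothetical dropout layer driven by JAX's functional PRNG keys. It is not Alpa's `offload_remat` and does not model the stateful RNG ops the PR deals with. Because the mask here is derived purely from `key`, checkpointing lets JAX recompute it in the backward pass instead of storing it, and the gradients match the non-rematerialized version.

```python
import jax
import jax.numpy as jnp

def dropout_layer(w, x, key, rate=0.5):
    # Hypothetical layer: a random mask is generated in the forward pass.
    # Without remat, values derived from it are kept as residuals until backward.
    h = jnp.tanh(x @ w)
    keep = jax.random.bernoulli(key, 1.0 - rate, h.shape)
    return jnp.where(keep, h / (1.0 - rate), 0.0)

def loss(w, x, key):
    return jnp.sum(dropout_layer(w, x, key) ** 2)

# jax.checkpoint (alias jax.remat) recomputes the layer, including the
# mask derived from `key`, during the backward pass instead of saving it.
remat_loss = jax.checkpoint(loss)

w = jnp.ones((4, 3)) * 0.1
x = jnp.arange(8.0).reshape(2, 4)
key = jax.random.PRNGKey(0)

grad_plain = jax.grad(loss)(w, x, key)
grad_remat = jax.grad(remat_loss)(w, x, key)
```

Recomputing from the key is only valid because JAX's PRNG is deterministic given the key. With a stateful RNG, a second call would draw different numbers, which is presumably why the PR stores the RNG result and feeds it into the offloaded part instead.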

ZYHowell, Jun 22 '22