alpa
[WIP][PERF] offload rng computation in remat
In `jax.remat`, constant values and random numbers are generated in the forward part and stored until the backward part. An example is this. To reduce memory consumption, we rematerialize this part as well. This PR:
- Moves `offload_remat` to `util.py` as an essential step after tracing the jaxpr in `process_remat`.
- Stores the rng result and reuses it as an input in the offloaded part.
- Adds test cases for remat with an rng seed.