Kirthi Shankar Sivamani
@nouiz `0.4.16` was released on 18 September 2023, so with this change we're effectively supporting only about 5 months of JAX releases, which seems too few.
@nouiz Could you resolve the conflicts so that we can merge this?
@timmoon10 Have you verified identical numerics with this change?
For testing CUDA graphs with FP8 caching, did you use the `noop_flag` in `transpose` and the `fp8_weight_caching` flag in `make_graphed_callables`?
/te-ci pytorch
@Teng-xu @yongyanrao These extra states are indeed part of the additional information needed for an FP8 training checkpoint. They can be explicitly removed, but the simplest method would be to...
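For illustration, a minimal sketch of explicitly stripping those entries from a checkpoint's state dict. The `_extra_state` key suffix and the helper name are assumptions here, not the library's confirmed API, and the suffix may differ between Transformer Engine versions:

```python
# Hypothetical sketch: drop FP8 extra-state entries from a checkpoint's
# state dict. Assumes the extra states live under keys ending in
# "_extra_state"; verify the actual key names in your checkpoint first.

def strip_fp8_extra_states(state_dict):
    """Return a copy of state_dict without FP8 extra-state entries."""
    return {
        key: value
        for key, value in state_dict.items()
        if not key.endswith("_extra_state")
    }

# Placeholder values stand in for real tensors:
ckpt = {
    "layer.weight": "w",
    "layer.bias": "b",
    "layer._extra_state": "fp8-metadata",
}
cleaned = strip_fp8_extra_states(ckpt)
print(sorted(cleaned))  # ['layer.bias', 'layer.weight']
```

Loading the cleaned dict into a non-FP8 model would then avoid unexpected-key errors, at the cost of discarding the FP8 scaling metadata.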
How are you currently doing activation checkpointing? Are you using an underlying toolkit such as [NeMo](https://github.com/NVIDIA/NeMo)?
@knowlsie Thanks for spotting this!