flax
[NVIDIA] Rename fp8 custom dtype to `fp32_max_grad`
This PR renames the original `fm32` dtype to `fp32_max_grad` to better express its purpose: the dtype stores fp32 values, and its gradients are accumulated with `max` instead of being summed.
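To make the accumulation rule concrete, here is a minimal sketch in plain Python. It is not Flax's implementation, and the function names and per-use values are hypothetical. It assumes a meta parameter (for example an FP8 scaling value) is read in several places, so autodiff produces one cotangent per use. A regular fp32 parameter would sum those cotangents, while a `fp32_max_grad` parameter keeps the largest one.

```python
from functools import reduce
import operator


def accumulate_fp32(cotangents):
    """Default fp32 behavior: cotangents from multiple uses are summed."""
    return reduce(operator.add, cotangents)


def accumulate_fp32_max_grad(cotangents):
    """Hypothetical max rule: keep the largest cotangent instead of summing."""
    return reduce(max, cotangents)


# Hypothetical per-use "gradients" of one meta parameter, e.g. an updated
# scale computed independently in two microbatches.
per_use = [2.0, 3.0]
print(accumulate_fp32(per_use))           # 5.0
print(accumulate_fp32_max_grad(per_use))  # 3.0
```

Summing inflates the value (5.0 is larger than anything either use observed), whereas taking the max keeps a value that one of the uses actually produced.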
cc @nouiz