TorchSharp
`GradScaler` and mixed-precision training
https://pytorch.org/docs/stable/notes/amp_examples.html
Currently, bfloat16 works well without grad scaling. But to train with fp16, and eventually fp8 (once support for Hopper/40XX GPUs lands), one needs to scale gradients: the narrower dynamic range of these formats causes small gradient values to underflow to zero unless the loss is scaled up before the backward pass.
This would be an awesome contribution from someone who knows how to do this right.
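For reference, below is a minimal sketch of the pattern from the linked PyTorch AMP docs that a TorchSharp `GradScaler` would need to mirror. The model, optimizer, and synthetic data here are placeholders for illustration only, not part of any existing TorchSharp API.

```python
import torch

# Stand-ins so the sketch is self-contained; in practice these come from
# the user's own model and data pipeline.
model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")

    optimizer.zero_grad()

    # Forward pass under autocast: eligible ops run in fp16.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # Scale the loss so small fp16 gradients don't underflow to zero,
    # then backprop on the scaled loss.
    scaler.scale(loss).backward()

    # step() unscales the gradients and skips the update if any are inf/NaN.
    scaler.step(optimizer)

    # update() adjusts the scale factor for the next iteration.
    scaler.update()
```

A TorchSharp port would presumably expose the same four-call surface (`scale`, `step`, `update`, plus unscaling for gradient clipping) alongside the existing autocast support.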