subclass_zoo
[not for land] example of simple FP8 UEX with stateful scaling
This is an example of implementing basic fp8 support with a Python tensor subclass.
tl;dr:
- FP8Tensor is the Python object which contains raw fp8 data (as torch.bits8), a scale, and a flavor (e4m3/e5m2)
- `FP8Tensor.__torch_dispatch__` knows how to add gradients, but converts to fp32 for everything else
- FP8Linear is a module which can do stateful delayed scaling. Users are expected to manually swap their linear layers for something like this.
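
To make "stateful delayed scaling" concrete, here is a minimal pure-Python sketch of the bookkeeping, not the code in this PR. `DelayedScale`, `step`, and `history_len` are hypothetical names. It assumes the common convention that the scale multiplies the high-precision values before the cast to fp8, and that the scale used on a given call comes from amax values seen on earlier calls. The max values are the largest finite magnitudes of the e4m3fn and e5m2 formats.

```python
from collections import deque

# Largest finite magnitude representable in each fp8 flavor
# (e4m3 here means the "fn" variant without infinities).
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


class DelayedScale:
    """Hypothetical helper: keeps a short amax history and derives a scale."""

    def __init__(self, flavor="e4m3", history_len=16):
        self.fp8_max = FP8_MAX[flavor]
        self.amax_history = deque(maxlen=history_len)
        self.scale = 1.0  # used until at least one amax has been observed

    def step(self, values):
        """Return the scale to apply to `values`, then update the state.

        The returned scale was computed from *previous* calls, which is
        what makes the scaling "delayed".
        """
        scale_for_this_call = self.scale
        amax = max((abs(v) for v in values), default=0.0)
        self.amax_history.append(amax)
        recent_amax = max(self.amax_history)
        if recent_amax > 0.0:
            # Map the largest recently seen magnitude onto the fp8 max.
            self.scale = self.fp8_max / recent_amax
        return scale_for_this_call


state = DelayedScale("e4m3", history_len=2)
first = state.step([1.0, -2.0])   # no history yet, so the default scale
second = state.step([0.5])        # derived from the amax of the first call
```

A module like FP8Linear would presumably hold one such piece of state per tensor it casts (input, weight, gradient), which is why the linears have to be swapped for a stateful module rather than handled purely inside the tensor subclass.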
Note: E4M3 support has not been numerically validated, and E5M2 support is not there at all.

Note: No testing other than the bare bones at the bottom of the PR has been done.

Note: Scaling is not implemented; currently all scales are just 1.0.