
[not for land] example of simple FP8 UEX with stateful scaling

Open · vkuzo opened this issue 1 year ago · 0 comments

This is an example of implementing basic fp8 support with a Python tensor subclass.

tl;dr:

  1. FP8Tensor is a Python object that contains raw fp8 data (stored as torch.bits8), a scale, and a flavor (e4m3/e5m2).
  2. FP8Tensor.__torch_dispatch__ knows how to add gradients, but converts to fp32 for everything else.
  3. FP8Linear is a module that can do stateful delayed scaling. Users are expected to manually swap their linear layers for something like this.
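The stateful delayed scaling that FP8Linear is meant to perform is not spelled out above. As a rough, framework-free illustration of what "delayed" usually means, here is a Python sketch. It assumes a per-tensor scale of `fp8_max / amax`, where amax is the maximum absolute value over a rolling history of *previous* steps. `DelayedScaler`, `history_len`, and the max-over-history policy are hypothetical choices for illustration, not code from this PR.

```python
from collections import deque

E4M3_MAX = 448.0    # largest finite float8 e4m3 value (OCP "fn" variant)
E5M2_MAX = 57344.0  # largest finite float8 e5m2 value


class DelayedScaler:
    """Hypothetical helper: derives the scale for the next cast from amaxes seen in earlier steps."""

    def __init__(self, fp8_max: float = E4M3_MAX, history_len: int = 16):
        self.fp8_max = fp8_max
        self.amax_history = deque(maxlen=history_len)  # rolling window of past amaxes
        self.scale = 1.0  # initial scale before any history exists

    def cast_scale(self) -> float:
        # Scale to use for the current step; depends only on past steps.
        return self.scale

    def update(self, values) -> None:
        # Record this step's amax, then recompute the scale for the *next* step.
        amax = max((abs(v) for v in values), default=0.0)
        self.amax_history.append(amax)
        hist_max = max(self.amax_history)
        if hist_max > 0.0:
            self.scale = self.fp8_max / hist_max
```

Because the scale for step t only depends on amaxes recorded before step t, it is available before the cast without an extra reduction over the current tensor. The trade-off is that a sudden jump in magnitude can saturate for a step before the history catches up.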

Note: E4M3 support has not been numerically validated, and E5M2 is not supported at all.

Note: No testing has been done beyond the bare-bones test at the bottom of the PR.

Note: Scaling is not implemented yet; currently every scale is 1.0.
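To see why fixed scales of 1.0 matter, the following pure-Python sketch simulates rounding to E4M3 and a scale/unscale round trip. It assumes the E4M3 variant from the OCP FP8 spec (PyTorch's `torch.float8_e4m3fn`: largest finite value 448, no infinities) and a *saturating* cast. `round_to_e4m3` and `fake_quant` are hypothetical helper names; this is not the PR's implementation, which stores raw bits in `torch.bits8`.

```python
import math

E4M3_MAX = 448.0          # largest finite e4m3 value
E4M3_MIN_NORMAL_EXP = -6  # smallest normal is 2**-6; subnormals share its spacing
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest e4m3 value (round-half-to-even), saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E4M3_MAX:
        return sign * E4M3_MAX  # saturate instead of producing NaN
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_NORMAL_EXP)
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)  # spacing of e4m3 values near a
    q = round(a / step) * step  # Python's round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


def fake_quant(values, scale: float):
    # Quantize (x * scale -> e4m3 grid), then dequantize by dividing by the same scale.
    return [round_to_e4m3(v * scale) / scale for v in values]


print(fake_quant([896.0, 3.0], 1.0))  # [448.0, 3.0]: 896 saturates
print(fake_quant([896.0, 3.0], 0.5))  # [896.0, 3.0]: scale = 448 / amax
```

With scale 1.0 the 896.0 entry saturates to 448.0, while an amax-based scale of 448 / 896 = 0.5 lets it round-trip exactly. Small values have the opposite problem: values well below 2**-9 (the smallest E4M3 subnormal) round to zero unless the scale lifts them into range.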

vkuzo, Apr 28 '23 21:04