transformer_nuggets
Transpose throws an annoying wrench in the mix
Summary
When you call torch.nn.functional.linear(), a transpose gets called on the weight (the op computes input @ weight.t()).
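As a quick sanity check of that standard PyTorch behavior (not part of the original note): the weight is stored as (out_features, in_features), so every forward pass goes through a transpose of the weight.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
w = torch.randn(16, 8)   # weight is stored as (out_features, in_features)
b = torch.randn(16)

# F.linear computes x @ w.t() + b, so the weight is transposed on every call
assert torch.allclose(F.linear(x, w, b), x @ w.t() + b)
```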
One solution is to compute the transpose lazily: mark that the matrix needs to be transposed and only apply it when a realizing op is encountered, which in this case would be addmm and mm. I did not handle double transposes correctly because I think that's unlikely in practice.
https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_ops.py#L44
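A minimal sketch of the lazy-transpose idea, assuming a toy wrapper tensor subclass (`LazyTransposeTensor` and its `realize` helper are made-up names for illustration, not the Float8Tensor code linked above): `aten.t` only flips a flag, and the deferred transpose is materialized when a realizing op (`mm`/`addmm`) consumes the tensor.

```python
import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten

class LazyTransposeTensor(torch.Tensor):
    """Toy wrapper subclass: t() only flips a flag; the transpose is applied
    when a realizing op (mm/addmm) actually consumes the tensor."""

    @staticmethod
    def __new__(cls, elem, transposed=False):
        shape = tuple(reversed(elem.shape)) if transposed else elem.shape
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem, transposed=False):
        self._elem = elem
        self._transposed = transposed

    def realize(self):
        # Apply the deferred transpose right before the realizing op runs
        return self._elem.t() if self._transposed else self._elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.t.default:
            x = args[0]
            # Lazy path: don't touch the data, just record the pending transpose
            return LazyTransposeTensor(x._elem, not x._transposed)
        if func in (aten.mm.default, aten.addmm.default):
            # Realizing ops: unwrap operands, materializing any pending transpose
            def unwrap(t):
                return t.realize() if isinstance(t, LazyTransposeTensor) else t
            return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        raise NotImplementedError(f"{func} not supported in this sketch")


x = torch.randn(4, 8)
w = LazyTransposeTensor(torch.randn(16, 8))  # Linear-style (out, in) weight
b = torch.randn(16)
out = torch.nn.functional.linear(x, w, b)    # decomposes into t + addmm
assert out.shape == (4, 16)
```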
Right now we have this "emulate" field where whenever you write an op you have to make sure that emulate gets set correctly on the newly created Float8Tensor.
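To illustrate the bookkeeping (toy code with made-up names standing in for Float8Tensor and the registration helper in float8_ops.py): every op implementation has to remember to thread the flag from its inputs onto the output it constructs.

```python
from dataclasses import dataclass
import torch

@dataclass
class ToyFloat8Tensor:
    """Toy stand-in for Float8Tensor: just enough state to show the bookkeeping."""
    data: torch.Tensor    # raw (or emulated) float8 payload
    scale: torch.Tensor   # per-tensor scale
    emulate: bool         # whether matmuls should be emulated in higher precision

OPS_TABLE = {}

def implements(aten_ops):
    """Register a kernel for a list of aten ops (hypothetical helper)."""
    def decorator(fn):
        for op in aten_ops:
            OPS_TABLE[op] = fn
        return fn
    return decorator

@implements([torch.ops.aten.t.default])
def toy_float8_t(x: ToyFloat8Tensor) -> ToyFloat8Tensor:
    # The smell: every kernel must remember to copy `emulate` onto the new
    # tensor it constructs; forget it once and the flag silently resets.
    return ToyFloat8Tensor(x.data.t(), x.scale, emulate=x.emulate)
```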
I just did something similar for NF4 and that works, but it feels kind of code smelly to me. It's essentially a lazy transpose, but since we only need to support these 3 ops it should work.
@rohan-varma Okay, I added this and it works; I still need to add compile tests and then do the memory profiling things.