
Transpose throws an annoying wrench in the mix

Open drisspg opened this issue 1 year ago • 2 comments

Summary

When you call `torch.nn.functional.linear()`, it calls transpose on the weight.

One solution is to compute the transpose lazily: mark that the matrix needs to be transposed, and only apply the transpose when a realizing op is encountered. In this case that would be `addmm` and `mm`. I did not handle repeated (even numbers of) transposes correctly, since I think that case is unlikely.

drisspg avatar Mar 02 '24 02:03 drisspg

https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_ops.py#L44

Right now we have this "emulate" field where, whenever you write an op, you have to make sure that emulate gets set correctly on the newly created Float8Tensor.

I just did something similar for NF4 and that works, but it feels kinda code-smelly to me. It's essentially a lazy transpose, but since we only need to support these 3 ops it should work.
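To illustrate the "emulate" bookkeeping burden: each op implementation has to remember to copy the flag onto its output. One way to centralize that is a small decorator, sketched below with a toy `Wrapped` stand-in for Float8Tensor (the names and the decorator are hypothetical, not the real float8_ops.py code).

```python
import torch
from dataclasses import dataclass

@dataclass
class Wrapped:
    """Toy stand-in for a Float8Tensor-like wrapper carrying a flag."""
    data: torch.Tensor
    emulate: bool = False

def preserves_flags(op):
    """Wrap an op so 'emulate' is propagated from inputs to the output,
    instead of every op implementation setting it by hand."""
    def inner(*args):
        emulate = any(a.emulate for a in args if isinstance(a, Wrapped))
        raw = [a.data if isinstance(a, Wrapped) else a for a in args]
        return Wrapped(op(*raw), emulate=emulate)
    return inner

# Example: a flag-preserving elementwise multiply.
wrapped_mul = preserves_flags(torch.mul)
```

This moves the repeated flag-setting into one place, though it still has the same smell: every new flag means touching the shared plumbing again.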

drisspg avatar Mar 03 '24 19:03 drisspg

@rohan-varma Okay, I added this and it works; still need to add compile tests and do the memory profiling.

drisspg avatar Mar 04 '24 17:03 drisspg