NotImplementedError: aten.linear.default not implemented when using MXTensor
Hey, I'm using the MX datatypes. It seems that aten.linear.default has not been implemented, which prevents the linear layers in the attention blocks from working with the MX datatypes.
Can you please implement this function in mx_ops.py? Thanks!
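Roughly, the call that triggers the error looks like this (just a sketch with placeholder shapes, elem_dtype and block size, not a runnable repro of my full setup):

    import torch
    import torch.nn.functional as F
    from torchao.prototype.mx_formats.mx_tensor import MXTensor

    x = torch.randn(16, 32, dtype=torch.bfloat16)
    w = torch.randn(64, 32, dtype=torch.bfloat16)

    # wrap both operands as MXTensor (elem_dtype/block_size are placeholders)
    x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
    w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn, 32)

    # in my run this dispatches to aten.linear.default, which has no override,
    # so __torch_dispatch__ raises NotImplementedError
    y = F.linear(x_mx, w_mx)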
Does this look like a correct implementation?
@implements([aten.linear.default])
def mx_mm(aten_op, args, kwargs=None):
    a = args[0]
    b = args[1]
    # bias is optional for linear
    c = args[2] if len(args) > 2 else None
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    # dequantize both operands to their original dtype and run the
    # high-precision op, mirroring the existing mm/addmm overrides
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    res = aten_op(a_hp, b_hp, c)
    return res
Would you have a repro of what specifically is not working for you?
We do have overrides for aten.mm and aten.addmm here: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/mx_ops.py; they are called into from the __torch_dispatch__ extension point.
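For context, the override mechanism follows roughly this pattern (a simplified sketch, not the exact code in mx_ops.py):

    import torch

    aten = torch.ops.aten

    # sketch: table mapping aten ops to override functions
    MX_OPS_TABLE = {}

    def implements(aten_ops):
        """Register an override for the given aten ops."""
        def decorator(fn):
            for op in aten_ops:
                MX_OPS_TABLE[op] = fn
            return fn
        return decorator

    @implements([aten.mm.default])
    def mx_mm(aten_op, args, kwargs=None):
        # dequantize to the original dtype and run the high-precision op
        a, b = args[0], args[1]
        return aten_op(a.to_dtype(a._orig_dtype), b.to_dtype(b._orig_dtype))

    # MXTensor.__torch_dispatch__ looks the incoming op up in this table;
    # ops without an entry (like aten.linear.default here) raise NotImplementedError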
@vkuzo I did see the aten.mm and aten.addmm implementations, but for some reason, in my case, when F.linear() is called in MXLinear, aten.linear.default is used instead of aten.addmm.
I don't know what decides whether aten.addmm or aten.linear gets called, but regardless, adding an aten.linear.default implementation should be an easy fix, right?
> I don't know what decides whether aten.addmm or aten.linear gets called
Agreed. Could you share a repro so we can dig into why you aren't hitting the mm/addmm functions? Adding a linear implementation sounds potentially reasonable, I just wanted to understand in more detail what exactly you are doing to hit this condition.
@vkuzo I was using MXLinear to quantize inference of Llama3.1-8B. Maybe the reason is that I was calling F.linear in a torch.no_grad() context?
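Something like this could be used to check which aten op F.linear actually lowers to under no_grad (a quick sketch using TorchDispatchMode, independent of MXTensor):

    import torch
    import torch.nn.functional as F
    from torch.utils._python_dispatch import TorchDispatchMode

    class LogOps(TorchDispatchMode):
        # print every aten op that reaches __torch_dispatch__
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)
            return func(*args, **(kwargs or {}))

    with LogOps(), torch.no_grad():
        F.linear(torch.randn(2, 8), torch.randn(4, 8), torch.randn(4))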
I eventually decided to avoid running any operation in MX format after all, so I no longer have the code I encountered the error with. I now do the quantization this way (note that I don't care about memory optimization when quantizing; I just want to incorporate the quantization error, which is why I bring the weights and activations back to their original dtypes):
...
# Quantizing weights:
orig_dtype = linear_layer.weight.data.dtype
weight_float = linear_layer.weight.data.float()
weight_q = MXTensor.to_mx(weight_float, self.quant_dtype, self.group_size)
linear_layer.weight.data = weight_q.to_dtype(orig_dtype)
...

def forward(self, x):
    # Quantizing activations
    x_float = x.float()  # MXTensor only accepts float32 and bfloat16
    x_q = MXTensor.to_mx(x_float, self.elem_dtype, self.block_size)
    x_q = x_q.to_dtype(self.weight.dtype)
    y = F.linear(x_q, self.weight, self.bias)
    y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
    return y
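As a sanity check (a hypothetical snippet reusing the same calls, with placeholder elem_dtype and block size), the round trip keeps the original dtype while carrying the quantization error:

    import torch
    from torchao.prototype.mx_formats.mx_tensor import MXTensor

    w = torch.randn(64, 32, dtype=torch.bfloat16)
    w_q = MXTensor.to_mx(w.float(), torch.float8_e4m3fn, 32).to_dtype(w.dtype)
    print(w_q.dtype)                               # stays bfloat16
    print((w.float() - w_q.float()).abs().max())   # quantization error is non-zero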
I see, thanks for that context. Adding an override for linear sgtm, let me know if you are interested in putting up a PR, otherwise we can take care of it. Thanks for the report!
@vkuzo Created the PR :+1: