
NotImplementedError: aten.linear.default not implemented when using MXTensor

Open Ali-Flt opened this issue 1 year ago • 7 comments

Hey, I'm using the MX datatypes. It seems like aten.linear.default has not been implemented, which causes the linear layers inside the attention layers to fail when using the MX datatypes.
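
Roughly the kind of call that hits this, for reference (an illustrative sketch, not my exact code; the fp8 element dtype and block size 32 are just example values):

import torch
import torch.nn.functional as F
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(64, 32, dtype=torch.bfloat16)

# Quantize both operands to an MX format (element dtype and block size are assumptions)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
w_mx = MXTensor.to_mx(w, torch.float8_e4m3fn, 32)

with torch.no_grad():
    # If dispatch reaches aten.linear.default instead of aten.mm/aten.addmm,
    # mx_ops.py has no override for it and NotImplementedError is raised
    y = F.linear(x_mx, w_mx)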

Can you please implement this function in mx_ops.py? Thanks!

Ali-Flt avatar Sep 03 '24 15:09 Ali-Flt

Does this look like a correct implementation?

@implements([aten.linear.default])
def mx_linear(aten_op, args, kwargs=None):
    # aten.linear.default args are (input, weight, bias); bias is optional
    a = args[0]
    b = args[1]
    c = args[2] if len(args) > 2 else None
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    # Dequantize both MX operands back to their original high-precision dtypes,
    # then run the linear op in high precision
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    res = aten_op(a_hp, b_hp, c)
    return res

Ali-Flt avatar Sep 03 '24 17:09 Ali-Flt

Would you have a repro of what specifically is not working for you?

We do have overrides for aten.mm and aten.addmm here: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/mx_ops.py; they are called from the __torch_dispatch__ extension point.
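
For anyone unfamiliar with how those overrides get picked up, the pattern is roughly this (a simplified sketch of the mechanism, not the exact torchao source):

import torch

# Table mapping aten ops to MX override functions
MX_OPS_TABLE = {}

def implements(aten_ops):
    """Register a function as the MX override for the given aten ops."""
    def decorator(func):
        for op in aten_ops:
            MX_OPS_TABLE[op] = func
        return func
    return decorator

# The tensor subclass intercepts aten ops via __torch_dispatch__ and routes each
# one to a registered override; anything unregistered raises NotImplementedError,
# which is what this issue hits for aten.linear.default.
class MXTensorSketch(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func in MX_OPS_TABLE:
            return MX_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{func} not implemented for MXTensorSketch")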

vkuzo avatar Sep 03 '24 18:09 vkuzo

@vkuzo I did see the aten.mm and aten.addmm implementations. But for some reason, in my case, when F.linear() is called in MXLinear, aten.linear.default is used instead of aten.addmm.

I don't know what decides whether aten.addmm or aten.linear gets called, but nonetheless, adding an aten.linear.default implementation should be an easy fix, right?

Ali-Flt avatar Sep 04 '24 01:09 Ali-Flt

I don't know what decides whether aten.addmm or aten.linear gets called

Agreed. Could you share a repro so we can dig into why you aren't hitting the mm/addmm functions? Adding a linear implementation sounds potentially reasonable; I just wanted to understand in more detail what exactly you are doing to hit this condition.

vkuzo avatar Sep 04 '24 15:09 vkuzo

@vkuzo I was using MXLinear to quantize the inference of Llama3.1-8B. Maybe the reason is that I was calling F.linear in a torch.no_grad() context?

I finally decided to avoid calling any operation in MX format after all, so I no longer have the code I encountered the error with. I now do the quantization this way (note that I don't care about memory savings when quantizing; I just want to incorporate the quantization error, hence I bring the weights and activations back to their original dtypes):

...
# Quantizing weights: quantize to MX, then immediately dequantize back to the
# original dtype so only the quantization error is baked into the weights
orig_dtype = linear_layer.weight.data.dtype
weight_float = linear_layer.weight.data.float()
weight_q = MXTensor.to_mx(weight_float, self.quant_dtype, self.group_size)
linear_layer.weight.data = weight_q.to_dtype(orig_dtype)
...
    def forward(self, x):
        # Quantizing activations: same quantize/dequantize round trip
        x_float = x.float()  # MXTensor only accepts float32 and bfloat16
        x_q = MXTensor.to_mx(x_float, self.elem_dtype, self.block_size)
        x_q = x_q.to_dtype(self.weight.dtype)
        # The linear op itself runs entirely in the original dtype
        y = F.linear(x_q, self.weight, self.bias)
        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
        return y
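
For illustration, the weight side of this packaged as a standalone helper might look like the following (a sketch using the same MXTensor API shown above; fake_quantize_weight_ is a hypothetical name, and the fp8 element dtype and block size 32 are just example values):

import torch
import torch.nn as nn
from torchao.prototype.mx_formats.mx_tensor import MXTensor

def fake_quantize_weight_(linear_layer: nn.Linear, elem_dtype, block_size: int):
    # Quantize the weight to MX and dequantize it back in place,
    # so only the rounding error remains in the original dtype
    orig_dtype = linear_layer.weight.data.dtype
    weight_q = MXTensor.to_mx(linear_layer.weight.data.float(), elem_dtype, block_size)
    linear_layer.weight.data = weight_q.to_dtype(orig_dtype)

# Example: fake-quantize every nn.Linear in a model
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8)).to(torch.bfloat16)
for module in model.modules():
    if isinstance(module, nn.Linear):
        fake_quantize_weight_(module, torch.float8_e4m3fn, 32)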

Ali-Flt avatar Sep 04 '24 16:09 Ali-Flt

I see, thanks for that context. Adding an override for linear sgtm, let me know if you are interested in putting up a PR, otherwise we can take care of it. Thanks for the report!

vkuzo avatar Sep 04 '24 16:09 vkuzo

@vkuzo Created the PR :+1:

Ali-Flt avatar Sep 04 '24 16:09 Ali-Flt