Add aten.linear.default implementation to mx_ops
Fixes #796
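For context, here is a rough sketch of what an `aten.linear.default` override in `mx_ops` could look like, assuming the same dequantize-then-dispatch pattern as the existing `aten.mm` override; the decorator, helper names, and exact signature are assumptions for illustration, not the merged code.

```python
# Hypothetical sketch (not the merged implementation). Assumes it lives in
# torchao/prototype/mx_formats/mx_ops.py, where `implements`, `MXTensor`,
# and `torch` are already in scope; the exact signature is an assumption.
aten = torch.ops.aten

@implements([aten.linear.default])
def mx_linear(aten_op, args, kwargs=None):
    # F.linear(input, weight, bias): input and weight arrive as MXTensor here.
    a, b = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
    # Dequantize back to the original dtype and run the high-precision op;
    # the MX behavior is emulated by the quantize/dequantize round trip.
    a_hp = a.to_dtype(a._orig_dtype)
    b_hp = b.to_dtype(b._orig_dtype)
    return aten_op(a_hp, b_hp, bias)
```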
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/806
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: No Failures
As of commit 8b79caabb6434c879a95c5b47778a6204d009ece with merge base f5703b07acc683653556d04ef970709ba47dba10:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @Ali-Flt!
Thank you for your pull request and welcome to our community.
Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.
Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks! Could you also add a test to https://github.com/pytorch/ao/blob/main/test/prototype/mx_formats/test_mx_linear.py?
@vkuzo I tried different ways to trigger the use of aten.linear but couldn't.
I'd definitely recommend a test that fails before this PR and passes after this PR. Would it work to wrap the code snippet you were using (which ended up calling into the linear override) into a test?
Just curious: would it be better to implement F.linear() under __torch_function__() instead? Previously I also ran into confusing behavior around which aten ops F.linear() dispatches to, so implementing F.linear() directly would sidestep the problem.
> Would it be better to implement F.linear() under __torch_function__() instead?
As of ~months ago, __torch_dispatch__ was better supported with torch.compile, at least for the things we needed for float8. I haven't checked if torch.compile + __torch_function__ coverage is better now, would be good to check.
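Somewhat related, here is a small standalone sketch (not torchao code) that makes the difference visible: `TorchFunctionMode` sees the Python-level `F.linear` call, while `TorchDispatchMode` shows which aten op(s) it actually lowers to, which is the behavior being discussed here.

```python
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLogger(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print("torch_function:", func)   # sees F.linear itself
        return func(*args, **(kwargs or {}))

class DispatchLogger(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("torch_dispatch:", func)   # sees the aten op(s) F.linear lowers to
        return func(*args, **(kwargs or {}))

x, w = torch.randn(4, 8), torch.randn(16, 8)
with FunctionLogger():
    torch.nn.functional.linear(x, w)
with DispatchLogger():
    torch.nn.functional.linear(x, w)          # typically aten.addmm / aten.mm
    with torch.inference_mode():
        torch.nn.functional.linear(x, w)      # may show aten.linear.default instead
```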
+1, kudos for the fix! I can confirm I also ran into this same error without this PR.
A simple test that fails before this PR, on:
- torch==2.5.1+cu121
- torchao==0.6.1
```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear


class MLP(nn.Module):
    def __init__(self, in_features: int = 128, out_features: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        return x


model = MLP()
# Does not hit the error with swap_linear_with_mx_inference_linear
swap_linear_with_mx_linear(model, elem_dtype=torch.float8_e4m3fn, block_size=32)

input_tensor = torch.randn(10, 128)
with torch.inference_mode():
    _ = model(input_tensor)
```
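Following up on the test suggestion above, a rough sketch of how this repro could be folded into test_mx_linear.py; the test name and exact assertions are assumptions, not necessarily the test that ultimately landed.

```python
# Hypothetical test sketch, wrapping the repro above; name/placement assumed.
import torch
import torch.nn as nn
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

def test_linear_dispatch_under_inference_mode():
    # Before this PR, aten.linear.default hit an unsupported-op error under
    # torch.inference_mode() because mx_ops had no override for it.
    model = nn.Sequential(nn.Linear(128, 256), nn.GELU())
    swap_linear_with_mx_linear(model, elem_dtype=torch.float8_e4m3fn, block_size=32)
    x = torch.randn(10, 128)
    with torch.inference_mode():
        y = model(x)
    assert y.shape == (10, 256)
```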
@vkuzo Thanks for the great insight on __torch_dispatch__ vs __torch_function__; this is super helpful!