torch-mlir

[TORCH] Add support for aten.rms_norm op

Open sharavana20 opened this issue 10 months ago • 2 comments

I would like to request support for the torch.rms_norm operation in the Torch dialect of Torch-MLIR.
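For reference, PyTorch documents RMS normalization (see `torch.nn.RMSNorm`) as follows. The mean is taken over the trailing dimensions given by `normalized_shape` (D elements in total), ε is `eps`, and w is the optional elementwise `weight`:

```latex
y_i = \frac{x_i}{\sqrt{\dfrac{1}{D}\sum_{j=1}^{D} x_j^{2} + \epsilon}} \, w_i,
\qquad D = \prod_k \texttt{normalized\_shape}[k]
```

In the reproduction, `normalized_shape = [3, 4]` gives D = 3 · 4 = 12. The (1, 2, 3, 4) input therefore contains 1 · 2 = 2 slices of 12 elements, and each slice is normalized independently.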

I tested torch.rms_norm with fx.export_and_import, and it fails with the following error:

[Screenshot of the error output]

Minimal Reproduction

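import torch
import torch.nn as nn
from torch_mlir import fx


def run(f):
    # Print the test name as a header, then run the test immediately.
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
def test_rms_norm():
    class RMSNorm(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # x is an (input, weight) tuple; normalize over the last two dims.
            normalized_shape = [3, 4]
            inp, weight = x
            return torch.rms_norm(inp, normalized_shape, weight, eps=0.8)

    # Export the module and import it into the Torch dialect (output_type="torch").
    exported = fx.export_and_import(
        RMSNorm(), (torch.randn(1, 2, 3, 4), torch.randn(3, 4)), output_type="torch"
    )
    print(exported)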
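The sketch below shows the arithmetic that a lowering of this op would need to reproduce. It is written in plain Python so it runs without PyTorch or Torch-MLIR, and it is not Torch-MLIR's implementation. It assumes the input has been flattened in row-major order, so each consecutive run of `group_size` elements is one normalization slice. `rms_norm_reference` and `group_size` are hypothetical names, and the `eps` default is an arbitrary illustrative value, not PyTorch's default.

```python
import math
from typing import List, Optional, Sequence


def rms_norm_reference(
    x: Sequence[float],
    group_size: int,
    weight: Optional[Sequence[float]] = None,
    eps: float = 1e-6,  # illustrative default only (assumption)
) -> List[float]:
    """Hypothetical helper: RMS-normalize consecutive groups of `group_size` values.

    `x` is a row-major flattened tensor whose trailing dims equal
    normalized_shape, so each group of `group_size` elements is one slice.
    `weight`, if given, is the flattened weight with `group_size` elements.
    """
    if group_size <= 0 or len(x) % group_size != 0:
        raise ValueError("len(x) must be a positive multiple of group_size")
    if weight is not None and len(weight) != group_size:
        raise ValueError("weight must have group_size elements")

    out: List[float] = []
    for start in range(0, len(x), group_size):
        group = x[start:start + group_size]
        # Mean of squares over the normalized dims, then 1 / sqrt(mean + eps).
        mean_sq = sum(v * v for v in group) / group_size
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        for i, v in enumerate(group):
            y = v * inv_rms
            # Apply the elementwise weight when one is provided.
            out.append(y * weight[i] if weight is not None else y)
    return out


# Same layout as the repro: (1, 2, 3, 4) input, normalized_shape [3, 4] -> 12.
values = [float(v) for v in range(24)]
result = rms_norm_reference(values, group_size=12, eps=0.8)
print(len(result))
```

For the repro, `normalized_shape=[3, 4]` gives `group_size=12`, and the flattened input splits into two independent slices. Computing the inverse RMS once per slice and then multiplying corresponds to a square, mean, add-eps, rsqrt, multiply sequence. That is one plausible way to decompose the op, assumed here rather than taken from Torch-MLIR.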

sharavana20, May 26 '25 06:05

@vivekkhandelwal1 I would like to take this up and implement it.

sharavana20, May 26 '25 06:05


@sharavana20 Assigned it to you.

vivekkhandelwal1, May 27 '25 05:05