
ConvertToLinearPass is not sound when transposes are elided

Open · metascroy opened this issue 8 months ago · 3 comments

🐛 Describe the bug

The linear pattern has executorch_exir_dialects_edge__ops_aten_permute_copy_default and executorch_exir_dialects_edge__ops_aten_addmm_default:

input_1 = input
aten_permute_copy_default: "f32[4, 8]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(p_weight, [1, 0]);  p_weight = None
aten_addmm_default: "f32[1, 8]" = executorch_exir_dialects_edge__ops_aten_addmm_default(p_bias, input_1, aten_permute_copy_default);  p_bias = input_1 = aten_permute_copy_default = None
return (aten_addmm_default,)

The ConvertToLinearPass tries to reconstruct linear from these ops. It does this correctly when the permute is present, but it also incorrectly constructs linear from a bare executorch_exir_dialects_edge__ops_aten_addmm_default when the permute is not present. This can happen if the transpose is elided (e.g., by constant propagation or RemoveRedundantTransposes): the weight constant then already has the [in, out] layout that addmm expects, so rebuilding linear from it transposes the weight a second time and produces the reduction-dim mismatch shown below.
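To see the mismatch at the tensor level, here is a minimal plain-PyTorch sketch (just the shape arithmetic, not the pass itself):

import torch

x = torch.randn(1, 4)    # input: [batch, in_features]
w_t = torch.randn(4, 8)  # weight with the permute already folded in: [in, out]

# addmm multiplies x @ w_t directly, so the folded layout is fine:
torch.addmm(torch.zeros(8), x, w_t)  # ok: [1, 4] @ [4, 8] -> [1, 8]

# linear transposes its weight internally (x @ w.T), so handing it the
# already-transposed constant fails with a reduction-dim mismatch:
torch.nn.functional.linear(x, w_t)  # RuntimeError: [1, 4] @ [8, 4]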

Repro (modified from backends/xnnpack/test/passes/test_convert_to_linear):

import unittest

import torch
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class TestConvertToLinear(unittest.TestCase):
    PassStage = RunPasses([ConvertToLinearPass])

    def setUp(self):
        torch._dynamo.reset()

    def test_fp32_convert_to_linear(self):
        in_sizes = [1, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [8, 17, 37]
        bias_vals = [True, True, False]

        for i, _ in enumerate(in_sizes):
            torch._dynamo.reset()
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            linear = torch.nn.Linear(input_size, output_size, bias=bias_vals[i])
            inputs = (torch.randn(in_size, input_size),)

            to_edge_stage = Tester(linear, inputs).export().to_edge()
            # Fold the weight permute into the constant, eliding the transpose
            # that ConvertToLinearPass expects to match.
            constant_prop_pass(to_edge_stage.stages["ToEdge"].artifact.exported_program())
            (
                to_edge_stage
                .run_passes(self.PassStage)
                .check_count(
                    {"executorch_exir_dialects_edge__ops_aten_linear_default": 1}
                )
                .run_method_and_compare_outputs()
            )

This fails with:

RuntimeError: a and b must have same reduction dim, but got [1, 4] X [8, 4].

I have seen this issue in the export llama script when linear ops are not delegated to XNNPACK.

Versions

NA

cc @digantdesai @mcr229 @cbilgin

metascroy avatar Apr 26 '25 00:04 metascroy

cc @digantdesai @mcr229

metascroy avatar Apr 26 '25 00:04 metascroy

I see; there have been some issues with convert_to_linear, and in its place I've been trying to use to_edge_transform_and_lower instead because that's a bit easier to exercise. What is your use case for convert_to_linear? I can prioritize the fix if necessary.
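For reference, a minimal sketch of that flow, assuming the standard XNNPACK partitioner and a toy linear model (adapt to your setup):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = torch.nn.Linear(4, 8).eval()
ep = torch.export.export(model, (torch.randn(1, 4),))

# Partition and lower in one step; the partitioner recognizes linear without
# relying on ConvertToLinearPass to stitch permute + addmm back together.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()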

mcr229 avatar Apr 28 '25 17:04 mcr229

> I see; there have been some issues with convert_to_linear, and in its place I've been trying to use to_edge_transform_and_lower instead because that's a bit easier to exercise. What is your use case for convert_to_linear? I can prioritize the fix if necessary.

I'm not trying to use it directly, but it is being used in the llama builder and I have seen issues with it.

metascroy avatar Apr 28 '25 17:04 metascroy

The llama builder uses it because we have an optimized op_linear implementation that gets used for bfloat16 models.

swolchok avatar Apr 28 '25 18:04 swolchok

> This can happen if you elide the transpose

I think we didn't assume that case because we don't remove permute before passing the graph to the XNNPACK backend, which this pass was originally written for. But adding this case to the pass should be easy enough.
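At the tensor level the fix is just to transpose addmm's second operand back before handing it to linear; a runnable plain-PyTorch sketch of that invariant (not the pass code itself):

import torch

x = torch.randn(1, 4)       # input
bias = torch.zeros(8)
w = torch.randn(8, 4)       # original nn.Linear weight: [out, in]
mat2 = w.t().contiguous()   # addmm operand once the permute is folded in

# The buggy reconstruction would call linear(x, mat2, bias) and fail
# ([1,4] @ [8,4]); transposing back recovers a linear that matches addmm:
fixed = torch.nn.functional.linear(x, mat2.t(), bias)
assert torch.allclose(fixed, torch.addmm(bias, x, mat2))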

Alternatively, in some cases where we have linear implemented in the optimized kernels (like for bf16), we can just not decompose it - like this - instead of using this pass.
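One way the "don't decompose" route might look with torch.export's decomposition table (a hedged sketch; the approach referenced above may differ):

import torch
from torch.export import default_decompositions, export

model = torch.nn.Linear(4, 8, dtype=torch.bfloat16)
ep = export(model, (torch.randn(1, 4, dtype=torch.bfloat16),))

# Drop aten.linear from the table so run_decompositions preserves it as a
# single op instead of splitting it into permute_copy + addmm.
decomp_table = default_decompositions()
decomp_table.pop(torch.ops.aten.linear.default)
ep = ep.run_decompositions(decomp_table)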

digantdesai avatar Apr 29 '25 19:04 digantdesai