
Concatting tensors causes resulting values to be different


What happened?

This bug is hard to describe precisely, since it appears to originate late in the compile pipeline and shows up in multiple forms, but the gist is that the same code run through PyTorch and through IREE produces substantially different results.

I have attempted to shrink the MLIR as much as possible; unfortunately it is still ~500 lines after compiling to the input phase.

The bug was isolated to a function that looks conceptually like:

def attn(b):
    x = f(b)
    ...
    a = ...(x, ...)
    return a + b

This also produced wildly different results. To help track down the issue, two cases were created, which map to the two MLIR files provided. A case (passing.mlir) that produced fairly close results:

def attn(b):
    x = f(b)
    ...
    a = ...(x, ...)
    return torch.cat((a, a), dim=-1)

A case (failing.mlir) that produced wildly different results:

def attn(b):
    x = f(b)
    ...
    a = ...(x, ...)
    return torch.cat((a, b), dim=-1)  # concat with b
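
For concreteness, here is a minimal runnable PyTorch sketch of the two cases. The bodies of f and of the computation producing a are elided in the issue, so the softmax and matmul below are placeholders, chosen only so that the shapes line up with the symbolic shapes in the MLIR (b : [3, s0*32, 32] and a : [3, s0*32, s0*32]):

import torch

s = 2                                   # stand-in for the symbolic dim s0
b = torch.randn(3, s * 32, 32)          # matches !torch.vtensor<[3,?,32],f32>

def f(b):
    # Placeholder body; the real f is elided in the issue.
    return torch.softmax(b, dim=-1)

def attn_passing(b):
    x = f(b)
    a = x @ b.transpose(-1, -2)         # placeholder math; a : [3, s*32, s*32]
    return torch.cat((a, a), dim=-1)    # last dim s0 * 64      -> passing.mlir

def attn_failing(b):
    x = f(b)
    a = x @ b.transpose(-1, -2)
    return torch.cat((a, b), dim=-1)    # last dim s0 * 32 + 32 -> failing.mlir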

The difference between these two cases at the MLIR level (both in the provided MLIR and after compiling to the input phase) is trivial: the second concat operand switches from %219 (a concatenated with itself) to %43 (b, whose last dimension is the static 32), which changes the symbolic last dimension of the result from s0 * 64 (= s0 * 32 + s0 * 32) to s0 * 32 + 32:

721c721
<     %220 = torch.prim.ListConstruct %219, %219 : (!torch.vtensor<[3,?,?],f32>, !torch.vtensor<[3,?,?],f32>) -> !torch.list<vtensor>
---
>     %220 = torch.prim.ListConstruct %219, %43 : (!torch.vtensor<[3,?,?],f32>, !torch.vtensor<[3,?,32],f32>) -> !torch.list<vtensor>
724c724
<     torch.bind_symbolic_shape %221, [%10], affine_map<()[s0] -> (3, s0 * 32, s0 * 64)> : !torch.vtensor<[3,?,?],f32>
---
>     torch.bind_symbolic_shape %221, [%10], affine_map<()[s0] -> (3, s0 * 32, s0 * 32 + 32)> : !torch.vtensor<[3,?,?],f32>

The values of a and b in the failing case should match the reference results within tolerance. They do not. (The reference results come from the torch.cat((a, b), dim=-1) version.)

Results of running the repro script (differences versus the reference results):

PASSING: a1
        Max abs diff: 2.899e-04
        Max rel diff: 1.924e-06
PASSING: a2
        Max abs diff: 2.899e-04
        Max rel diff: 1.924e-06
FAILING: a
        Max abs diff: 3.637e+02
        Max rel diff: 1.000e+00
FAILING: b
        Max abs diff: 0.000e+00
        Max rel diff: 0.000e+00

Concatenating a with itself or with b should not change the value of a!
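
For reference, here is a sketch of how diff metrics like these are typically computed; the actual comparison lives in repro.py inside the zip, so this is only an assumed equivalent:

import torch

def report(name, result, reference, eps=1e-9):
    # Elementwise absolute difference against the reference tensor.
    diff = (result - reference).abs()
    print(name)
    print(f"\tMax abs diff: {diff.max().item():.3e}")
    print(f"\tMax rel diff: {(diff / (reference.abs() + eps)).max().item():.3e}")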

All required files are in the provided repro_files.zip.

Steps to reproduce your issue

python repro.py

What component(s) does this issue relate to?

Compiler

Version information

iree-3.5.0rc20250610

Additional context

No response

Alex-Vasile avatar Jun 12 '25 18:06 Alex-Vasile

@MaheshRavishankar I suspect this is related to the dispatch that now fails to compile after https://github.com/iree-org/iree/pull/21063, especially since that fix was meant to resolve two very similar correctness issues.

IanWood1 avatar Jun 13 '25 17:06 IanWood1

Are these going down the tile-and-fuse pipeline?

MaheshRavishankar avatar Jun 13 '25 17:06 MaheshRavishankar

Pre-patch they go down the warp-reduction pipeline (possible miscompile); after the patch they go down tile-and-fuse (compilation failure).

IanWood1 avatar Jun 13 '25 18:06 IanWood1

@qedawkins could you look into this please? (I am assuming you aren't deep in some other context yet, so fixing this before that would hopefully be easy.)

MaheshRavishankar avatar Jun 13 '25 19:06 MaheshRavishankar

yes, I can take a look

qedawkins avatar Jun 13 '25 20:06 qedawkins

The issue is another case of a consumer using two results of the same producer:

%18:2 = scf.forall ... // two-result tiled producer; body elided
%19 = linalg.generic {
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d3)>,              // %13, broadcast along d0..d2
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,  // %18#1
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,      // %18#0, broadcast along d3
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], // result
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  ins(%13, %18#1, %18#0 : tensor<32xf32>, tensor<3x?x32x32xf32>, tensor<3x?x32xf32>)
  ... // outs and region elided

This fails to fuse during workgroup distribution. One of the inputs is also broadcast, but the broadcast dimension isn't distributed to workgroups, so that part is fine. The easiest fix is to move the workgroup tile sizes to the consumer; the other option is to support these multi-use consumer cases. A conceptual sketch of the dataflow follows below.
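
As a conceptual analogy only (not the actual dispatch), the pattern is a producer with two results that both feed a single elementwise consumer, with broadcasts matching the indexing maps above; the combining arithmetic below is a placeholder:

import torch

def producer():
    # Stand-ins for the two scf.forall results.
    small = torch.randn(3, 5, 32)        # %18#0 : tensor<3x?x32xf32>, with ? = 5
    big   = torch.randn(3, 5, 32, 32)    # %18#1 : tensor<3x?x32x32xf32>
    return small, big

bias = torch.randn(32)                   # %13 : tensor<32xf32>, indexed by d3 only
small, big = producer()

# One consumer reading BOTH producer results; bias broadcasts along d0..d2 and
# small along d3, as in the indexing maps. The real elementwise body is not
# shown, so the add/multiply here is purely a placeholder.
out = bias + big * small[..., None]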

qedawkins avatar Jun 16 '25 03:06 qedawkins

https://github.com/llvm/llvm-project/pull/145193 should fix the core issue. I need to try it in IREE next.

MaheshRavishankar avatar Jun 21 '25 23:06 MaheshRavishankar

@Alex-Vasile can you check if https://github.com/iree-org/iree/pull/21171 fixes the issue.

MaheshRavishankar avatar Jun 24 '25 06:06 MaheshRavishankar

@Alex-Vasile can you check if #21171 fixes the issue.

Seems to have fixed it!

PASSING: a1
        Max abs diff: 2.899e-04
        Max rel diff: 1.924e-06
PASSING: a2
        Max abs diff: 2.899e-04
        Max rel diff: 1.924e-06
FAILING: a
        Max abs diff: 2.899e-04
        Max rel diff: 1.924e-06
FAILING: b
        Max abs diff: 0.000e+00
        Max rel diff: 0.000e+00

Alex-Vasile avatar Jun 24 '25 14:06 Alex-Vasile