
Errors on in-place ops on tensor aliases unresolved by proxy substitution

Open shino16 opened this issue 1 month ago • 6 comments

Repro:

import torch, thunder

def f(a, b):
    return a.exp_() * b.tanh_()

def g(a, _):
    b = a.view(5,5)
    return a.exp_() * b.tanh_()

def h(a, _):
    b = a[0,0]
    return a.exp_() * b.tanh_()

for fn in [f, g, h]:
    jf = thunder.jit(fn)
    x = torch.randn(5, 5, device='cuda')
    x_ = x.detach().clone()
    out = jf(x, x[0, 0])
    out_ = fn(x_, x_[0, 0])

    torch.testing.assert_close(out, out_)
    # AssertionError on f, g and h

Found in https://github.com/Lightning-AI/lightning-thunder/pull/2760#issuecomment-3562836815 and https://github.com/Lightning-AI/lightning-thunder/pull/2760#issuecomment-3562688372 by @beverlylytle.

Trace of f after update_aliases.py:

# Constructed by Update aliases for in-place ops
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t6,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: 	    return a.exp_() * b.tanh_()
  t1 = ltorch.exp_(t6)  # t1: "cuda:0 f32[5, 5]"
    # t0 = ltorch.exp(t6)  # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.exp(t6)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t6, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((b,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: 	    return a.exp_() * b.tanh_()
  t3 = ltorch.tanh_(t7)  # t3: "cuda:0 f32[]"
    # t2 = ltorch.tanh(t7)  # t2: "cuda:0 f32[]"
      # t2 = prims.tanh(t7)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t7, grad_enabled=True)  # t3: "cuda:0 f32[]"
  t5 = ltorch.mul(t1, t3)  # t5: "cuda:0 f32[5, 5]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  return {'output': (t5,), 'flat_args': [t1, t3]}

Trace after fusion:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t10,) = update_aliases((a,))
  del a
  [t1] = nvFusion0(t10)
    # t0 = prims.exp(t10)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t10, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t10
  (t11,) = update_aliases((b,))
  del b
  [t3, t5] = nvFusion1(t11, t1)
    # t2 = prims.tanh(t11)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t11, grad_enabled=True)  # t3: "cuda:0 f32[]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  del t11
  return {'output': (t5,), 'flat_args': [t1, t3]}

The problem here is that nvFusion1 does not know that t11 and t1 share memory.

To guarantee that t3 = prims.copy_(t2, t11, grad_enabled=True) completes before t5 = prims.mul(t1, t4), we could insert prims.update_aliases before prims.mul. That would fix the bugs, because prims.update_aliases is unfusible and therefore forces a fusion break between the copy and the mul.

Such solutions create more fusion breaks, so we want to minimize the use of prims.update_aliases. Ideally, we hope to make prims.update_aliases a fusible op and let nvFuser handle memory aliases in its combined region.
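
For reference, a minimal eager-PyTorch illustration (my own example, not part of the repro) of the semantics we need to preserve: since b aliases a[0, 0], both in-place ops finish before the multiplication reads memory, so the product at (0, 0) must use the tanh'd value.

import torch

a = torch.randn(5, 5)
b = a[0, 0]                 # b aliases element (0, 0) of a's storage

out = a.exp_() * b.tanh_()  # both in-place ops complete before the mul reads memory

# After the two in-place ops, a[0, 0] holds tanh(exp(a00)), so out[0, 0] must be
# tanh(exp(a00)) ** 2; using a stale copy of the exp result would be wrong.
assert torch.allclose(out[0, 0], a[0, 0] * b)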

shino16 avatar Nov 24 '25 15:11 shino16

This is an example where inserting prims.update_aliases seems more appropriate:

import torch, thunder

def f(a, b):
    t2 = a.exp_()
    b.tanh_()
    return t2.sin()

jf = thunder.jit(f)
x = torch.randn(5, 5, device='cuda')
x_ = x.detach().clone()
out = jf(x, x[0, 0])
out_ = f(x_, x_[0, 0])

torch.testing.assert_close(out, out_)
# AssertionError on f

Trace before update_aliases

# Constructed by Remove context manager prims
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: 	    t2 = a.exp_()
  t2 = ltorch.exp_(a)  # t2: "cuda:0 f32[5, 5]"
    # t0 = ltorch.exp(a)  # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.exp(a)  # t0: "cuda:0 f32[5, 5]"
    # t2 = prims.copy_(t0, a, grad_enabled=True)  # t2: "cuda:0 f32[5, 5]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:5: 	    b.tanh_()
  t4 = ltorch.tanh_(b)  # t4: "cuda:0 f32[]"
    # t3 = ltorch.tanh(b)  # t3: "cuda:0 f32[]"
      # t3 = prims.tanh(b)  # t3: "cuda:0 f32[]"
    # t4 = prims.copy_(t3, b, grad_enabled=True)  # t4: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: 	    return t2.sin()
  t5 = ltorch.sin(t2)  # t5: "cuda:0 f32[5, 5]"
    # t5 = prims.sin(t2)  # t5: "cuda:0 f32[5, 5]"
  return {'output': (t5,), 'flat_args': [a, b]}

t2 and t4 share memory, so t4 = ltorch.tanh_(b) must happen before t5 = ltorch.sin(t2). This justifies inserting prims.update_aliases((t2, t4)).

Executed trace
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t10,) = update_aliases((a,))
  del a
  [t2] = nvFusion0(t10)
    # t0 = prims.exp(t10)  # t0: "cuda:0 f32[5, 5]"
    # t2 = prims.copy_(t0, t10, grad_enabled=True)  # t2: "cuda:0 f32[5, 5]"
  del t10
  (t11,) = update_aliases((b,))
  del b
  [t4, t5] = nvFusion1(t11, t2)
    # t3 = prims.tanh(t11)  # t3: "cuda:0 f32[]"
    # t4 = prims.copy_(t3, t11, grad_enabled=True)  # t4: "cuda:0 f32[]"
    # t5 = prims.sin(t2)  # t5: "cuda:0 f32[5, 5]"
  del t11
  return (t5,)
Fusion Definitions
# nvFusion0
def nvfuser_fusion(fd : FusionDefinition) -> None :
    tv0 = fd.define_tensor(shape=[5, 5], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    tv1 = fd.ops.exp(tv0)
    tv2 = fd.ops.cast(tv1, dtype=DataType.Float)
    fd.add_output(tv2, tv0)
    fd.add_output(tv2)

# nvFusion1
def nvfuser_fusion(fd : FusionDefinition) -> None :
    tv0 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False)
    tv1 = fd.define_tensor(shape=[5, 5], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    tv2 = fd.ops.tanh(tv0)
    tv3 = fd.ops.cast(tv2, dtype=DataType.Float)
    tv4 = fd.ops.sin(tv1)
    fd.add_output(tv3, tv0)
    fd.add_output(tv3)
    fd.add_output(tv4)

shino16 avatar Nov 24 '25 21:11 shino16

diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py
index 3a679bcf..a37cc40d 100644
--- a/thunder/core/update_aliases.py
+++ b/thunder/core/update_aliases.py
@@ -150,7 +150,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
             out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
             if _is_inplace_op(bsym):
                 inplace_inputs.add(in_tensor)
-                out_tensors = set()
+                if len(out_tensors) == 1 and all(len(group) <= 1 or in_tensor not in group for group in view_groups):
+                    out_tensors = set()
             for group in view_groups:
                 if in_tensor in group:
                     group.update(out_tensors)

This is a small patch on top of #2760 that I have in mind; it works for all the repros here, and I think it is a natural and reasonable thing to do.

To sort things out: we can skip recording out_tensors as aliases of in_tensor if and only if

  • in_tensor is going to be swapped with out_tensor by _update_swap_map in subsequent bsyms, and
  • in_tensor has no other aliases so far.

These were the hidden assumptions behind update_aliases.py, and the repros here are counterexamples to them. The if in the patch above tests whether these assumptions hold.

We eliminate input aliases by replacing one with another in replace_args_with_alias_map, and we eliminate aliases created by in-place ops by replacing in_tensor with out_tensor in subsequent bsyms using swap_map. But when view-creating ops create aliases, those are unresolvable, so we handle them using view_groups and prims.update_aliases. In #2760, the same treatment is applied when we come across unresolvable input aliases. So, when aliases created by in-place ops are unresolvable, it makes sense to put them into view_groups too.
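
To make this concrete, here is a toy model of the bookkeeping described above (a deliberate simplification of my own, not Thunder's actual data structures or function names):

# swap_map: resolvable aliases; the in-place result replaces its input in later bsyms.
# view_groups: unresolvable aliases; they force prims.update_aliases to be inserted.
swap_map: dict[str, str] = {}
view_groups: list[set[str]] = []

def record_inplace(in_name: str, out_name: str) -> None:
    # Resolvable case: later uses of in_name are rewritten to out_name.
    swap_map[in_name] = out_name

def record_view(base_name: str, view_name: str) -> None:
    # Unresolvable case: remember that these names share memory.
    for group in view_groups:
        if base_name in group:
            group.add(view_name)
            return
    view_groups.append({base_name, view_name})

# The patch above effectively falls back from record_inplace to record_view whenever
# the in-place op's input already sits in a non-trivial view group.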

We do aim to cause fewer fusion breaks, but in the repros here, avoiding a fusion break seems just as difficult as for the other unresolvable aliases. View-creating ops, unresolvable input aliases, and unresolvable in-place ops all seem to fall into the same category.

Note that the patch above does not introduce unnecessary fusion breaks for the in-place ops we have considered so far, which satisfy the assumptions above.

import torch, thunder

def fn(a, b):
    return a.exp_() * b.tanh_()

jf = thunder.jit(fn)
x = torch.randn(5, 5, device='cuda')
y = torch.randn(5, 5, device='cuda')
x_ = x.detach().clone()
y_ = y.detach().clone()
out = jf(x, y)
out_ = fn(x_, y_)

torch.testing.assert_close((x, y, out), (x_, y_, out_))
print(thunder.last_traces(jf)[-1])
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[5, 5]"
  (t9,) = update_aliases((a,))
  del a
  [t1] = nvFusion0(t9)
    # t0 = prims.exp(t9)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t9, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t9
  (t10,) = update_aliases((b,))
  del b
  [t3, t4] = nvFusion1(t10, t1)
    # t2 = prims.tanh(t10)  # t2: "cuda:0 f32[5, 5]"
    # t3 = prims.copy_(t2, t10, grad_enabled=True)  # t3: "cuda:0 f32[5, 5]"
    # t4 = prims.mul(t1, t3)  # t4: "cuda:0 f32[5, 5]"
  del t10
  return (t4,)

shino16 avatar Nov 25 '25 01:11 shino16

I'm a little confused by the fn defined in the last comment because there are no views of any tensors, and so the view groups created are uncomplicated, and the effect of the patch isn't obvious.

Consider instead something we looked at in another issue

def f(a, _):
    b = a[0,0]
    c = a.exp_()
    b.tanh_()
    return c.sin()

With the above patch we create the following trace

import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t73,) = update_aliases((a,))
  del a

  # /home/blytle/scratch/sym_mod.py:52:             b = a[0,0]
  t83 = torch_slice_prim_impl(t73, [0, 0], [1, 1], [1, 1])  # t83: "cuda:0 f32[1, 1]"
  t84 = torch.squeeze(t83, (0, 1))  # t84: "cuda:0 f32[]"
    # t84 = ltorch.squeeze(t83, (0, 1))  # t84: "cuda:0 f32[]"
      # t84 = prims.squeeze(t83, (0, 1))  # t84: "cuda:0 f32[]"
  del t83
  (t74, t75) = update_aliases((t73, t84))
  del t73, t84
  [t32] = nvFusion0(t74)
    # t29 = prims.exp(t74)  # t29: "cuda:0 f32[5, 5]"
    # t32 = prims.copy_(t29, t74, grad_enabled=True)  # t32: "cuda:0 f32[5, 5]"
  del t74
  (t76, t77, t78) = update_aliases((t32, t75, t32))
  del t76, t32, t75
  [t45] = nvFusion1(t77)
    # t40 = prims.tanh(t77)  # t40: "cuda:0 f32[]"
    # t45 = prims.copy_(t40, t77, grad_enabled=True)  # t45: "cuda:0 f32[]"
  del t77
  (t79, t80, t81, t82) = update_aliases((t78, t45, t78, t45))
  del t79, t80, t82, t78, t45
  [t55] = nvFusion2(t81)
    # t55 = prims.sin(t81)  # t55: "cuda:0 f32[5, 5]"
  return (t55,)

Note that (t76, t77, t78) = update_aliases((t32, t75, t32)) updates t32 twice, and the next update_aliases again updates each of its inputs twice. This makes the trace very confusing to read and reason about, and I believe it could potentially result in unexpected behavior. It comes from the fact that this patch adds both a and an alias of a to the view groups. One could filter out the repeated proxies when creating the update_aliases bsym, but it feels inefficient and wrong to add something to view groups and then filter it out.

Moreover, this patch does result in excess instances of update_aliases, some of them quite strange. The function

def f(a, _):
    b = a.view(5,5)
    c = a.exp_()
    b.tanh_()
    return c.sin() * 2 * b

produces the trace

import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t97,) = update_aliases((a,))
  del a

  # /home/blytle/scratch/sym_mod.py:52:             b = a.view(5,5)
  t111 = shallow_copy(t97)  # t111: "cuda:0 f32[5, 5]"
  (t98, t99) = update_aliases((t111, t97))
  del t111, t97
  [t33] = nvFusion0(t99)
    # t30 = prims.exp(t99)  # t30: "cuda:0 f32[5, 5]"
    # t33 = prims.copy_(t30, t99, grad_enabled=True)  # t33: "cuda:0 f32[5, 5]"
  del t99
  (t100, t101, t102) = update_aliases((t33, t98, t33))
  del t100, t33, t98
  [t46] = nvFusion1(t101)
    # t41 = prims.tanh(t101)  # t41: "cuda:0 f32[5, 5]"
    # t46 = prims.copy_(t41, t101, grad_enabled=True)  # t46: "cuda:0 f32[5, 5]"
  del t101
  (t103, t104, t105, t106) = update_aliases((t102, t46, t46, t102))
  del t103, t104, t102, t46
  (t107, t108, t109, t110) = update_aliases((t106, t105, t105, t106))
  del t107, t108, t105
  [t72] = nvFusion3(t106, t109)
    # t56 = prims.sin(t106)  # t56: "cuda:0 f32[5, 5]"
    # t60 = prims.mul(t56, 2.0)  # t60: "cuda:0 f32[5, 5]"
    # t72 = prims.mul(t60, t109)  # t72: "cuda:0 f32[5, 5]"
  del t106, t109
  return (t72,)

beverlylytle avatar Nov 25 '25 09:11 beverlylytle

As for the fn in my last comment, please ignore it; I only wanted to show that no additional update_aliases is introduced in such cases.

And indeed, having the same tensor appear more than once in update_aliases((...)) is pretty awkward (although I haven't come up with a problematic case). To prevent it, we would need some machinery to update the alias relationship from {a, b} to {c, b}, instead of just adding c to {a, b}, or take care of it in _get_update_bsym, which seems simpler.
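
As a toy illustration of the two bookkeeping strategies (a, b, c are placeholder names, not proxies from any trace above):

# Suppose the in-place op on a produces c, and {a, b} is the existing view group.
group = {"a", "b"}

# What the patch currently does: add the output, so a and its successor c coexist.
added = group | {"c"}                  # {"a", "b", "c"}

# What is proposed here: replace the name that is about to be swapped out.
replaced = (group - {"a"}) | {"c"}     # {"b", "c"}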

For the last case with consecutive update_aliases, I think there is yet another, separate issue, so I wrote #2768. Without my patch, we can reproduce the same behavior by replacing c = a.exp_() with c = a.view(5, 5):

def fn(a, _):
    b = a.view(5,5)
    c = a.view(5,5)
    b.tanh_()
    return c.sin() * 2 * b
# ...

# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:10: 	    b = a.view(5,5)
  b = ltorch.view(t7, 5, 5)  # b: "cuda:0 f32[5, 5]"
    # b = ltorch.reshape(t7, (5, 5))  # b: "cuda:0 f32[5, 5]"
      # b = prims.shallow_copy(t7)  # b: "cuda:0 f32[5, 5]"
  (t8, t9) = prims.update_aliases((t7, b))

  # /opt/pytorch/lightning-thunder/tmp/main.py:11: 	    c = a.view(5,5)
  c = ltorch.view(t8, 5, 5)  # c: "cuda:0 f32[5, 5]"
    # c = ltorch.reshape(t8, (5, 5))  # c: "cuda:0 f32[5, 5]"
      # c = prims.shallow_copy(t8)  # c: "cuda:0 f32[5, 5]"
  (t10, t11, t12) = prims.update_aliases((t8, c, t9))

  # /opt/pytorch/lightning-thunder/tmp/main.py:12: 	    b.tanh_()
  t3 = ltorch.tanh_(t12)  # t3: "cuda:0 f32[5, 5]"
    # t2 = ltorch.tanh(t12)  # t2: "cuda:0 f32[5, 5]"
      # t2 = prims.tanh(t12)  # t2: "cuda:0 f32[5, 5]"
    # t3 = prims.copy_(t2, t12, grad_enabled=True)  # t3: "cuda:0 f32[5, 5]"
  (t13, t14, t15) = prims.update_aliases((t10, t11, t3))

  # /opt/pytorch/lightning-thunder/tmp/main.py:13: 	    return c.sin() * 2 * b
  t4 = ltorch.sin(t14)  # t4: "cuda:0 f32[5, 5]"
    # t4 = prims.sin(t14)  # t4: "cuda:0 f32[5, 5]"
  t5 = ltorch.mul(t4, 2)  # t5: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t4, 2.0)  # t5: "cuda:0 f32[5, 5]"
  (t16, t17, t18) = prims.update_aliases((t13, t14, t15))

  # /opt/pytorch/lightning-thunder/tmp/main.py:13: 	    return c.sin() * 2 * b
  t6 = ltorch.mul(t5, t18)  # t6: "cuda:0 f32[5, 5]"
    # t6 = prims.mul(t5, t18)  # t6: "cuda:0 f32[5, 5]"
  return {'output': (t6,), 'flat_args': [t16]}

# ...

# Constructed by Unwrap the actual return value
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t31,) = update_aliases((a,))
  del a

  # /opt/pytorch/lightning-thunder/tmp/main.py:10: 	    b = a.view(5,5)
  t43 = Tensor.view(t31, 5, 5)  # t43: "cuda:0 f32[5, 5]"
    # t43 = ltorch.view(t31, 5, 5)  # t43: "cuda:0 f32[5, 5]"
      # t43 = ltorch.reshape(t31, (5, 5))  # t43: "cuda:0 f32[5, 5]"
        # t43 = prims.shallow_copy(t31)  # t43: "cuda:0 f32[5, 5]"
  (t32, t33) = update_aliases((t31, t43))
  del t31, t43

  # /opt/pytorch/lightning-thunder/tmp/main.py:11: 	    c = a.view(5,5)
  t44 = Tensor.view(t32, 5, 5)  # t44: "cuda:0 f32[5, 5]"
    # t44 = ltorch.view(t32, 5, 5)  # t44: "cuda:0 f32[5, 5]"
      # t44 = ltorch.reshape(t32, (5, 5))  # t44: "cuda:0 f32[5, 5]"
        # t44 = prims.shallow_copy(t32)  # t44: "cuda:0 f32[5, 5]"
  (t34, t35, t36) = update_aliases((t32, t44, t33))
  del t32, t44, t33
  [t3] = nvFusion0(t36)
    # t2 = prims.tanh(t36)  # t2: "cuda:0 f32[5, 5]"
    # t3 = prims.copy_(t2, t36, grad_enabled=True)  # t3: "cuda:0 f32[5, 5]"
  del t36
  (t37, t38, t39) = update_aliases((t34, t35, t3))
  del t34, t35, t3
  (t40, t41, t42) = update_aliases((t37, t38, t39))
  del t41, t37, t39
  [t6] = nvFusion2(t38, t42)
    # t4 = prims.sin(t38)  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t4, 2.0)  # t5: "cuda:0 f32[5, 5]"
    # t6 = prims.mul(t5, t42)  # t6: "cuda:0 f32[5, 5]"
  del t38, t42
  return (t6,)

And this is part of #2768. The view group {a, c, b} was already established, so prims.update_aliases was inserted before both multiplications.

shino16 avatar Nov 25 '25 16:11 shino16

Thanks for the detailed write-up @shino16, and for tracing this down with @beverlylytle!

Masato's approach is conceptually consistent. As Masato notes, this treats unresolvable in-place op aliases the same way we treat unresolvable view-creating ops and input aliases. They're all in the same category of "aliases we can't eliminate through substitution". The trade-off is that we accept fusion breaks for the problematic cases (which seems unavoidable without nvFuser-level aliasing support), while preserving fusion for the common case where tensors don't share memory. This looks good to me as a fix on top of #2760, even though the traces become quite verbose with all the update_aliases calls. However, let's revisit what we're trying to solve here.

The core problem is clear from the traces: as you say, nvFuser doesn't know that t11 and t1 share memory when they are passed as separate inputs to nvFusion1 (and to the Thunder-jitted function). Looking at the trace:

[t1] = nvFusion0(t10)
  # t0 = prims.exp(t10)
  # t1 = prims.copy_(t0, t10, grad_enabled=True)
...
[t3, t5] = nvFusion1(t11, t1)
  # t2 = prims.tanh(t11)
  # t3 = prims.copy_(t2, t11, grad_enabled=True)
  # t5 = prims.mul(t1, t4)

When nvFuser executes prims.copy_ to t11, it should invalidate t1, but it has no way to know they alias each other. This should be fixed, but how?

On alias detection at trace initialization

This repeats some of the information from the "Alias Detection" section of the design doc. The question of whether input tensors should be grouped as aliases at trace initialization is worth investigating. PyTorch provides Tensor.data_ptr() and Tensor.storage_offset() for this. However, as the NumPy mem_overlap.c implementation shows, detecting whether two strided arrays actually share memory elements (not just storage) is NP-hard in the general case.

For Thunder's purposes, I think we should consider two approaches:

  1. Conservative (may_share_memory style): Check whether memory bounds overlap using storage_offset and extent calculations. This is O(ndim) and gives us a safe over-approximation: some non-overlapping tensors might be grouped together, but we won't miss actual aliases. This is what NumPy's may_share_memory does with max_work=0, and it is better than conservatively assuming that all input tensors may share memory. (A minimal sketch of this check follows the list below.)

  2. Exact (shares_memory style): Solve the Diophantine equation to determine actual overlap. This is expensive and only needed if we want to avoid unnecessary fusion breaks for non-overlapping slices of the same storage.
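
For illustration, a minimal sketch of the conservative check (my sketch, not Thunder code), using only the storage pointer, storage_offset, shape, and strides; PyTorch strides are non-negative, which keeps the bound computation simple:

import torch

def may_share_memory(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Tensors backed by different storages can never alias.
    if a.untyped_storage().data_ptr() != b.untyped_storage().data_ptr():
        return False

    def byte_bounds(t: torch.Tensor) -> tuple[int, int]:
        # [lo, hi) byte range this tensor can address inside its storage.
        itemsize = t.element_size()
        lo = t.storage_offset() * itemsize
        extent = sum((size - 1) * stride for size, stride in zip(t.shape, t.stride()))
        return lo, lo + (extent + 1) * itemsize

    a_lo, a_hi = byte_bounds(a)
    b_lo, b_hi = byte_bounds(b)
    return a_lo < b_hi and b_lo < a_hi

x = torch.randn(5, 5)
assert may_share_memory(x, x[0, 0])        # overlapping views of the same storage
assert not may_share_memory(x, x.clone())  # distinct storages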

For correctness, option 1 is sufficient and cheap. We could detect at prologue time which input tensors may share memory, then:

  • Group them together so the aliasing relationship is known throughout the trace
  • Either insert update_aliases barriers before operations that read from potentially-modified aliases, or
  • Pass this grouping information to nvFuser so it can handle sequencing internally

The reshape replacement pattern works when tensors have the same numel, but for cases like x[0,0], we need a different mechanism. One option: could we pass aliasing metadata to FusionDefinitionWrapper, similar to how we pass shape/stride info? Then, nvFuser could at least sequence operations correctly within a fusion.

For the immediate fix, detecting may-share-memory at input processing time and inserting update_aliases before the mul would work (accepting the fusion break). The three test cases (f, g, h) cover the important scenarios and should be added as regression tests.
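
Such regression tests might look roughly like the following (the test name and parametrization are my guesses; the function bodies are just the repro above, plus a check that the inputs are mutated identically):

import pytest
import torch
import thunder

def f(a, b):
    return a.exp_() * b.tanh_()

def g(a, _):
    b = a.view(5, 5)
    return a.exp_() * b.tanh_()

def h(a, _):
    b = a[0, 0]
    return a.exp_() * b.tanh_()

@pytest.mark.parametrize("fn", [f, g, h])
def test_inplace_on_unresolved_aliases(fn):
    x = torch.randn(5, 5, device="cuda")
    x_ = x.detach().clone()
    out = thunder.jit(fn)(x, x[0, 0])
    out_ = fn(x_, x_[0, 0])
    torch.testing.assert_close(out, out_)
    torch.testing.assert_close(x, x_)  # the input must be mutated the same way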

@beverlylytle and @shino16, could you please check separately with @jjsjann123 on what it would take for nvFuser to handle aliased inputs within a fusion? If Thunder can provide all alias groups, would nvFuser be able to use them?

IvanYashchuk avatar Nov 25 '25 17:11 IvanYashchuk

I submit #2769 for your consideration as an alternative WIP to the patch suggested in the comment above.

#2768 is very interesting to me. I'm curious how to resolve that.

I am interested to see how the discussion with the nvFuser team resolves for situations where neither the above patch nor #2769 resolves the problem of aliased inputs to a fusion.

beverlylytle avatar Nov 25 '25 23:11 beverlylytle

nvFuser's segmentation behaviour could make it harder to support update_aliases inside fusion regions.

For example, the test function:

def f(a, b):
    result = a.exp_() * b.tanh_()
    a.add_(1)
    return result, a

shape = (5, 5)
a = torch.ones(shape, ...)
b = a[0, 0]

Thunder checked out at #2760 generates:

def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t20, t21) = update_aliases((b, a))
  del b, a
  [t1] = nvFusion0(t21)
    # t0 = prims.exp(t21)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t21, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t21
  (t22, t23) = update_aliases((t20, t1))
  del t20, t1
  [t3, result] = nvFusion1(t22, t23)
    # t2 = prims.tanh(t22)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t22, grad_enabled=True)  # t3: "cuda:0 f32[]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # result = prims.mul(t23, t4)  # result: "cuda:0 f32[5, 5]"
  del t22
  (t24, t25) = update_aliases((t3, t23))
  del t3, t23
  [t7] = nvFusion2(t25)
    # t6 = prims.add(t25, 1.0)  # t6: "cuda:0 f32[5, 5]"
    # t7 = prims.copy_(t6, t25, grad_enabled=True)  # t7: "cuda:0 f32[5, 5]"
  del t25
  return (result, t7)

Now, if we focus on the second fusion region, which generates the wrong output, the FusionDefinition looks like:

def nvfuser_fusion(fd : FusionDefinition) -> None :
    tv0 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False)
    tv1 = fd.define_tensor(shape=[5, 5], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    tv2 = fd.ops.tanh(tv0)
    tv3 = fd.ops.cast(tv2, dtype=DataType.Float)
    tv4 = fd.ops.broadcast(tv3, is_broadcast_dim=[True, True])
    tv5 = fd.ops.expand(tv4, shape=[5, 5])
    tv6 = fd.ops.mul(tv1, tv5)
    fd.add_output(tv3, tv0)
    fd.add_output(tv3)
    fd.add_output(tv6)

Due to its SSA structure, it is now clear that nvFuser does not know that tv1 is invalidated after the tanh.

If we analyze the generated kernels, we see that the in-place copy_ writeback is deferred to the last kernel executed for this fusion definition; moreover, it is handled separately from the main computation kernel:

// Codegen generated code
__global__ void nvfuser_no_op_f0_c1_r0_g0(Tensor<float, 0, 0> T8, Tensor<float, 0, 0> T0, Tensor<float, 0, 0> T9) {
  T9[0]
     = T8[0];
}

// Codegen generated code
__global__ void nvfuser_no_op_f0_c1_r0_g1(Tensor<float, 0, 0> T0, Tensor<float, 0, 0> T8) {
  Array<float, 1, 1> T2;
  T2[0] = 0;
  T2[0]
     = tanhf(T0[0]);
  Array<float, 1, 1> T3;
  T3[0]
     = T2[0];
  T8[0]
     = T3[0];
}

// Codegen generated code
__global__ void nvfuser_pointwise_f0_c1_r0_g2(Tensor<float, 0, 0> T0, Tensor<float, 2, 2> T1, Tensor<float, 2, 2> T6) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
  Array<float, 1, 1> T2;
  T2[0] = 0;
  T2[0]
     = tanhf(T0[0]);
  Array<float, 1, 1> T7;
  T7[0]
     = T2[0];
  Array<float, 1, 1> T5;
  T5[0]
     = T7[0];
  if ((i0 < 25)) {
    Array<float, 1, 1> T10;
    T10[0] = 0;
    T10[0]
       = T1[i0];
    Array<float, 1, 1> T11;
    T11[0]
      = T10[0]
      * T5[0];
    T6[i0]
       = T11[0];
  }
}

From nsys it is clear that the value written back by the in-place copy_ is computed separately from whatever happens in xxxx_g2:

(nsys timeline screenshot omitted)

This would result in a race condition if update_aliases could be fused, merging fusion1 and fusion2 (and potentially also fusion0), because the writeback by xxx_g0 uses a value precomputed separately from whatever happens in xxx_g2.

mattteochen avatar Nov 27 '25 15:11 mattteochen

@mattteochen What you write makes it clear that not all instances of update_aliases can be fused. However, there are some instances where what you point out is not a problem.

def f(x):
    y = x.view(x.shape)
    <fusible, out-of-place op with y as input>
    <another fusible, out-of-place op with y as input>
    <yet another fusible, out-of-place op with y as input>
    <in-place op on y>
    ...

update_aliases would be inserted before each of the fusible, out-of-place ops on y. These would trigger fusion breaks that do not protect against the race conditions you mention. What really matters is only that update_aliases prevents the in-place op on y from being reordered relative to the out-of-place ones.
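
For concreteness, one possible instance of that pattern (my own example, with ordinary elementwise ops standing in for the placeholders):

import torch, thunder

def f(x):
    y = x.view(x.shape)
    p = y.sin()    # fusible, out-of-place
    q = y.cos()    # fusible, out-of-place
    r = y.exp()    # fusible, out-of-place
    y.tanh_()      # in-place: must not be reordered before the reads above
    return p + q + r

jf = thunder.jit(f)
out = jf(torch.randn(5, 5))
# Per the comment above, an update_aliases is inserted before each out-of-place op
# on y, but only the one in front of y.tanh_() is needed for correctness.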

beverlylytle avatar Dec 01 '25 10:12 beverlylytle

Hi @beverlylytle, is there any ongoing work on https://github.com/Lightning-AI/lightning-thunder/pull/2769? I found a regression in PR https://github.com/Lightning-AI/lightning-thunder/pull/2772 that causes this issue in more cases (see https://github.com/Lightning-AI/lightning-thunder/pull/2772#issuecomment-3617793780).

To me, https://github.com/Lightning-AI/lightning-thunder/pull/2772 seems to be going in the opposite direction from https://github.com/Lightning-AI/lightning-thunder/pull/2769. Indeed, while #2769 fixed the regression, it caused the new tests in #2772 to fail.

FAILED thunder/tests/test_update_aliases.py::test_update_aliases_count_torch_cuda_thunder.dtypes.float32 - assert 4 == 1
FAILED thunder/tests/test_update_aliases.py::test_update_aliases_count_nvfuser_cuda_thunder.dtypes.float32 - assert 4 == 1
FAILED thunder/tests/test_update_aliases.py::test_update_aliases_count_torch_cpu_thunder.dtypes.float32 - assert 4 == 1

shino16 avatar Dec 05 '25 17:12 shino16

@shino16 Right, I was trying to point out that these two PRs are conceptually in conflict, and that we needed to be careful about how we combine them. You said you wanted to merge #2772 first, so I paused work on #2769 to see how #2772 would be finalized. Looking now at the most recent version of #2772, I see that view_groups is now being updated to contain the most recently created alias for any given tensor, rather than what happens in #2769, where view_groups is statically set to the original names of the tensors and everything compared against view_groups is mapped back to its original name. Either approach can be used. But what is not a good idea (because it will require cleaning up in several places) is including two aliases for the same tensor in a view group.

beverlylytle avatar Dec 05 '25 19:12 beverlylytle

Thank you for your consideration! I misunderstood your comment on https://github.com/Lightning-AI/lightning-thunder/pull/2772#pullrequestreview-3534603422 as "possible conflict with (some other PR)". I'm so sorry, and thank you for pausing your work in the meantime. There are so many PRs in the air...

Either approach can be used.

This makes sense! I figure the difference is whether we update view_groups forward or rewind bsym inputs backward; it should only be a matter of implementation details.

Consider the following repro of #2768

import torch, thunder

def print_simplified_trace(trace):
    for bsym in trace.bound_symbols:
        bsym.source_filename = None
        bsym.subsymbols.clear()
    print(trace.python(include_decorators=False))

def fn(a):
    a.tanh_()
    return a * a * a * a

jfn = thunder.jit(fn, fusion_type="dataflow")
out = jfn(torch.randn(5, 5, device='cuda'))
for trace in thunder.last_traces(jfn):
    if (pr := trace.get_provenance()) and pr.pss == "Update aliases for in-place ops":
        print_simplified_trace(prv_trace)
        print_simplified_trace(trace)
    prv_trace = trace

torch.testing.assert_close(out, fn(torch.randn(5, 5, device='cuda')))
def computation(a):
  t1 = ltorch.tanh_(a)  # t1: "cuda:0 f32[5, 5]"
  t2 = ltorch.mul(a, a)  # t2: "cuda:0 f32[5, 5]"
  t3 = ltorch.mul(t2, a)  # t3: "cuda:0 f32[5, 5]"
  t4 = ltorch.mul(t3, a)  # t4: "cuda:0 f32[5, 5]"
  return {'output': (t4,), 'flat_args': [a]}

The view group is {a}. On main, it checks _involves_viewed_args(ltorch.mul(..., a)), which is always true.

def computation(a):
  (t5,) = prims.update_aliases((a,))
  t1 = ltorch.tanh_(t5)  # t1: "cuda:0 f32[5, 5]"
  (t6,) = prims.update_aliases((t1,))
  t2 = ltorch.mul(t6, t6)  # t2: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((t6,))
  t3 = ltorch.mul(t2, t7)  # t3: "cuda:0 f32[5, 5]"
  (t8,) = prims.update_aliases((t7,))
  t4 = ltorch.mul(t3, t8)  # t4: "cuda:0 f32[5, 5]"
  return {'output': (t4,), 'flat_args': [t8]}

On #2772, it checks _involves_viewed_args(ltorch.mul(..., t1)), which is always false.

def computation(a):
  (t5,) = prims.update_aliases((a,))
  t1 = ltorch.tanh_(t5)  # t1: "cuda:0 f32[5, 5]"
  t2 = ltorch.mul(t1, t1)  # t2: "cuda:0 f32[5, 5]"
  t3 = ltorch.mul(t2, t1)  # t3: "cuda:0 f32[5, 5]"
  t4 = ltorch.mul(t3, t1)  # t4: "cuda:0 f32[5, 5]"
  return {'output': (t4,), 'flat_args': [t1]}

In merging #2769, we probably don't want to unswap t1 back to a here. Whether we go forward or backward, we have to be mindful of which renaming should (and should not) be applied. (And this is the reason for this comment https://github.com/Lightning-AI/lightning-thunder/pull/2772/commits/a0fbd456a351b8dc5973f5af24819c456bae651c#r2593737171)

shino16 avatar Dec 05 '25 20:12 shino16

I personally prefer the "backward approach", because updating view_groups forward for every update_aliases seems like a big deal. Plus, now that we have found that https://github.com/Lightning-AI/lightning-thunder/pull/2772 contains a regression, we should start on the safe side.

@beverlylytle Would you mind continuing your work on https://github.com/Lightning-AI/lightning-thunder/pull/2769? We can work on the issue https://github.com/Lightning-AI/lightning-thunder/issues/2768 afterwards, which would involve controlling the applicability of #2769. What I have in mind is to maintain a backward swap_map with controlled mappings instead of reverting all of swap_map.
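
As a rough illustration of that idea (purely hypothetical names and structure, not code from any PR):

# swap_map records, per in-place op, the renaming old_name -> new_name applied to
# subsequent bsyms.
swap_map = {"a": "t1", "t7": "t3"}

# Reverting all of swap_map would rename every in-place result back to its input.
full_revert = {new: old for old, new in swap_map.items()}

# A controlled backward map reverts only the entries we explicitly allow, e.g.
# because the input still has live, unresolvable aliases.
allowed = {"t7"}
backward_swap_map = {new: old for old, new in swap_map.items() if old in allowed}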

shino16 avatar Dec 05 '25 21:12 shino16

I've updated #2769, which fixes the conceptual error of trying to intersect elements of view_groups with the potentially aliased args of a given bsym. I have not added anything further, like a limited version of swap_map.

beverlylytle avatar Dec 08 '25 14:12 beverlylytle