Desired inplace idioms (add yours here)

t-vi opened this issue 1 year ago • 12 comments

We do support some in-place operations thanks to @crcrpar's great patches. This has already enabled running additional models, but support is limited to a few cases of particular interest.

However, we want to see and prioritize other important in-place idioms, so we would appreciate it if you chimed in with the needs that you or your favourite model have.

Please do try to post

  • a minimal code snippet,
  • the name of a particular model that needs the idiom, if you are aware of one.

We don't want to support every corner case of in-place (e.g. "does reshape produce a view or not?"), but we do care a lot about enabling users and models. Thank you!

P.S.: "We don't need to support this, but it should error" is also a valid kind of feedback.


If you have traces as well as an idiom, could you please also paste thunder.last_traces(jitted)[1:3], i.e. the traces generated by functionalization?
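
For example, something along these lines (a minimal sketch; the in-place function and the input here are placeholders for your own case):

import torch
import thunder

def inplace_idiom_(x: torch.Tensor) -> torch.Tensor:
    return x.add_(1)  # the in-place pattern you want supported

jitted = thunder.jit(inplace_idiom_)
jitted(torch.ones(4))

# the traces produced by functionalization, as requested above
for trace in thunder.last_traces(jitted)[1:3]:
    print(trace)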

t-vi avatar Jun 26 '24 13:06 t-vi

thunder.jit-ing the following function with the nvFuser executor fails with the message below. By moving the copy for a += b to the end of the trace and replacing a += b with t0 = a + b, I would expect it to work (see the hand-functionalized sketch after the traceback).

import torch

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a += b
    c = torch.exp(a)
    d = torch.tanh(b)

    e = c.view(-1)
    e.add_(d.flatten())

    d.div_(a)
    return c, d, e / 2.0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 55, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 24, in main
    c, d, e = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 676, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 615, in get_computation_and_inputs
    thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 86, in _inplace_copy_sanity_check
    check(copy_to_out, "output")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 80, in check
    raise NotImplementedError(
NotImplementedError: t8 = prims.div(d, t1)  # t8: "cuda:0 f32[2, 2]" trying to use t1 (the output of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
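
For reference, here is a hand-functionalized sketch of what I mean (my own rewrite, not an actual trace; names are arbitrary): compute everything out-of-place and defer the single write-back into a to the end.

import torch

def f_functionalized(a: torch.Tensor, b: torch.Tensor):
    t0 = a + b                     # out-of-place replacement for `a += b`
    c = torch.exp(t0)
    d = torch.tanh(b)
    e = c.view(-1) + d.flatten()   # out-of-place replacement for `e.add_(d.flatten())`
    c = e.view_as(c)               # the add through the view also updates c
    d = d / t0                     # out-of-place replacement for `d.div_(a)`
    a.copy_(t0)                    # deferred write-back of the only argument mutation
    return c, d, e / 2.0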

crcrpar avatar Jun 28 '24 06:06 crcrpar

The failing loudly (and not obscurely) part is super important. Let's keep this issue open to track that aspect.

I would prioritize enumerating corner cases as much as we can and creating good unhappy paths for them (with sane error messages). Then we can address them and lift limitations where it makes sense. WDYT @crcrpar?

lantiga avatar Jun 28 '24 08:06 lantiga

Optimizers' in-place updates might be a concern when designing a scheme, since Adam already carries three copies of the model state algorithmically.
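
For concreteness (my own illustration, not from the thread): besides the parameters themselves, torch.optim.Adam keeps per-parameter exp_avg, exp_avg_sq, and step tensors in its state, all of which are updated in place on every step.

import torch

p = torch.randn(4, requires_grad=True)
opt = torch.optim.Adam([p])
p.grad = torch.randn(4)
opt.step()
print(sorted(opt.state[p].keys()))  # ['exp_avg', 'exp_avg_sq', 'step'] -- per-parameter state, mutated in place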

kevinstephano avatar Jul 01 '24 19:07 kevinstephano

This case works, but the signature of computation(res, elem, step_t, t_1_1) does not look great given that the input function g's signature is params: list[torch.Tensor], steps: list[torch.Tensor].

import torch
import thunder


@thunder.jit
def g(params, steps):
    for i, param in enumerate(params):
        step_t = steps[i]

        step_t.add_(1)

        param.add_(step_t)

    return params

with torch.device('cuda'):
    params = [torch.tensor(0), torch.tensor(0)]
    steps = [torch.tensor(0), torch.tensor(0)]
g(params, steps)

print(params, steps)
print(thunder.last_traces(g)[-1])
    ([tensor(1, device='cuda:0'), tensor(1, device='cuda:0')],
     [tensor(1, device='cuda:0'), tensor(1, device='cuda:0')])

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(res, elem, step_t, t_1_1):
  # res: "cuda:0 i64[]"
  # elem: "cuda:0 i64[]"
  # step_t: "cuda:0 i64[]"
  # t_1_1: "cuda:0 i64[]"
  nvFusion0(elem, res, step_t, t_1_1)
    # t0 = prims.add(step_t, 1)  # t0: "cuda:0 i64[]"
    # t2 = prims.add(res, t0)  # t2: "cuda:0 i64[]"
    # t4 = prims.add(t_1_1, 1)  # t4: "cuda:0 i64[]"
    # t6 = prims.add(elem, t4)  # t6: "cuda:0 i64[]"
    # prims.copy_(t0, step_t)
    # prims.copy_(t2, res)
    # prims.copy_(t4, t_1_1)
    # prims.copy_(t6, elem)
  del elem, res, step_t, t_1_1
  return [_torch_Tensor_0, _torch_Tensor_1]

shino16 avatar Jul 04 '24 07:07 shino16

Multiple in-place ops whose operand is the function's argument are not handled appropriately: the trace below writes an intermediate result (the exp) back into a instead of the final one (the sin). A functional sketch of the intended behavior follows the trace.

import torch

import thunder


def f(a):
    return a.exp_().sin_()


if __name__ == "__main__":
    x = torch.randn(4, device="cuda", requires_grad=False)
    x_ref = x.clone().detach().requires_grad_(False)

    y_ref = f(x_ref)
    jitted = thunder.jit(f)
    y = jitted(x)

    print(thunder.last_traces(jitted)[-1])

gives

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[4]"
  [t2] = nvFusion0(a)
    # t0 = prims.exp(a)  # t0: "cuda:0 f32[4]"
    # t2 = prims.sin(t0)  # t2: "cuda:0 f32[4]"
    # prims.copy_(t0, a)
  del a
  return t2

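For comparison, the eager-equivalent behavior is that a should end up holding sin(exp(a)); a hand-written functional sketch (mine, not thunder output) of what the two chained in-place ops should reduce to:

import torch

def f_functional(a: torch.Tensor) -> torch.Tensor:
    t0 = torch.exp(a)
    t2 = torch.sin(t0)
    a.copy_(t2)  # the single deferred write-back must use the final value, not the intermediate t0
    return a
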
crcrpar avatar Jul 04 '24 10:07 crcrpar

To make a jittable implementation of Adam, I want this to be sound.

import torch
import thunder
from torch import Tensor


def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        step_t.add_(1)

        exp_avg.mul_(0.9).add_(grad * 0.1)
        exp_avg_sq.mul_(0.999).addcmul_(grad, grad, value=0.001)

        bias_correction1 = 1 - 0.9**step_t
        bias_correction2 = 1 - 0.999**step_t

        step_size = 0.001 / bias_correction1

        bias_correction2_sqrt = bias_correction2.sqrt()
        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + 1e-8

        param.addcdiv_(exp_avg, denom, value=-step_size)


dim = 1024
n_param = 2
with torch.device('cuda'):
    params = [torch.randn(dim, dim) for _ in range(n_param)]
    grads = [torch.randn_like(param) for param in params]
    exp_avgs = [torch.zeros_like(param) for param in params]
    exp_avg_sqs = [torch.zeros_like(param) for param in params]
    steps = [torch.tensor(0.0) for param in params]

jitted_step = thunder.jit(_single_tensor_adam, disable_inplace_copy_check=True)
jitted_step(params, grads, exp_avgs, exp_avg_sqs, steps)

print(thunder.last_traces(jitted_step)[-1])
trace:
def computation(res, elem, grad, t_1_1, exp_avg, t_2_1, exp_avg_sq, t_3_1, step_t, t_4_1):
  # res: "cuda:0 f32[1024, 1024]"
  # elem: "cuda:0 f32[1024, 1024]"
  # grad: "cuda:0 f32[1024, 1024]"
  # t_1_1: "cuda:0 f32[1024, 1024]"
  # exp_avg: "cuda:0 f32[1024, 1024]"
  # t_2_1: "cuda:0 f32[1024, 1024]"
  # exp_avg_sq: "cuda:0 f32[1024, 1024]"
  # t_3_1: "cuda:0 f32[1024, 1024]"
  # step_t: "cuda:0 f32[]"
  # t_4_1: "cuda:0 f32[]"
  nvFusion0(elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1)
    # t0 = prims.add(step_t, 1.0)  # t0: "cuda:0 f32[]"
    # t2 = prims.mul(exp_avg, 0.9)  # t2: "cuda:0 f32[1024, 1024]"
    # result = prims.mul(grad, 0.1)  # result: "cuda:0 f32[1024, 1024]"
    # t5 = prims.add(t2, result)  # t5: "cuda:0 f32[1024, 1024]"
    # t7 = prims.mul(exp_avg_sq, 0.999)  # t7: "cuda:0 f32[1024, 1024]"
    # t9 = prims.mul(grad, grad)  # t9: "cuda:0 f32[1024, 1024]"
    # t10 = prims.mul(0.001, t9)  # t10: "cuda:0 f32[1024, 1024]"
    # t11 = prims.add(t7, t10)  # t11: "cuda:0 f32[1024, 1024]"
    # b = prims.pow(0.9, t0)  # b: "cuda:0 f32[]"
    # bias_correction1 = prims.sub(1.0, b)  # bias_correction1: "cuda:0 f32[]"
    # t15 = prims.pow(0.999, t0)  # t15: "cuda:0 f32[]"
    # bias_correction2 = prims.sub(1.0, t15)  # bias_correction2: "cuda:0 f32[]"
    # step_size = prims.div(0.001, bias_correction1)  # step_size: "cuda:0 f32[]"
    # bias_correction2_sqrt = prims.sqrt(bias_correction2)  # bias_correction2_sqrt: "cuda:0 f32[]"
    # a = prims.sqrt(t11)  # a: "cuda:0 f32[1024, 1024]"
    # t20 = prims.broadcast_in_dim(bias_correction2_sqrt, (1024, 1024), ())  # t20: "cuda:0 f32[1024, 1024]"
    # t21 = prims.div(a, t20)  # t21: "cuda:0 f32[1024, 1024]"
    # denom = prims.add(t21, 1e-08)  # denom: "cuda:0 f32[1024, 1024]"
    # tos2 = prims.neg(step_size)  # tos2: "cuda:0 f32[]"
    # t24 = prims.div(t5, denom)  # t24: "cuda:0 f32[1024, 1024]"
    # t25 = prims.broadcast_in_dim(tos2, (1024, 1024), ())  # t25: "cuda:0 f32[1024, 1024]"
    # t26 = prims.mul(t25, t24)  # t26: "cuda:0 f32[1024, 1024]"
    # t27 = prims.add(res, t26)  # t27: "cuda:0 f32[1024, 1024]"
    # t29 = prims.add(t_4_1, 1.0)  # t29: "cuda:0 f32[]"
    # t31 = prims.mul(t_2_1, 0.9)  # t31: "cuda:0 f32[1024, 1024]"
    # t33 = prims.mul(t_1_1, 0.1)  # t33: "cuda:0 f32[1024, 1024]"
    # t34 = prims.add(t31, t33)  # t34: "cuda:0 f32[1024, 1024]"
    # t36 = prims.mul(t_3_1, 0.999)  # t36: "cuda:0 f32[1024, 1024]"
    # t38 = prims.mul(t_1_1, t_1_1)  # t38: "cuda:0 f32[1024, 1024]"
    # t39 = prims.mul(0.001, t38)  # t39: "cuda:0 f32[1024, 1024]"
    # t40 = prims.add(t36, t39)  # t40: "cuda:0 f32[1024, 1024]"
    # t42 = prims.pow(0.9, t29)  # t42: "cuda:0 f32[]"
    # t43 = prims.sub(1.0, t42)  # t43: "cuda:0 f32[]"
    # t44 = prims.pow(0.999, t29)  # t44: "cuda:0 f32[]"
    # t45 = prims.sub(1.0, t44)  # t45: "cuda:0 f32[]"
    # tos = prims.div(0.001, t43)  # tos: "cuda:0 f32[]"
    # t47 = prims.sqrt(t45)  # t47: "cuda:0 f32[]"
    # t48 = prims.sqrt(t40)  # t48: "cuda:0 f32[1024, 1024]"
    # t49 = prims.broadcast_in_dim(t47, (1024, 1024), ())  # t49: "cuda:0 f32[1024, 1024]"
    # t50 = prims.div(t48, t49)  # t50: "cuda:0 f32[1024, 1024]"
    # t51 = prims.add(t50, 1e-08)  # t51: "cuda:0 f32[1024, 1024]"
    # t52 = prims.neg(tos)  # t52: "cuda:0 f32[]"
    # t53 = prims.div(t34, t51)  # t53: "cuda:0 f32[1024, 1024]"
    # t54 = prims.broadcast_in_dim(t52, (1024, 1024), ())  # t54: "cuda:0 f32[1024, 1024]"
    # t55 = prims.mul(t54, t53)  # t55: "cuda:0 f32[1024, 1024]"
    # t56 = prims.add(elem, t55)  # t56: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t0, step_t)
    # prims.copy_(t2, exp_avg)
    # prims.copy_(t7, exp_avg_sq)
    # prims.copy_(t27, res)
    # prims.copy_(t29, t_4_1)
    # prims.copy_(t31, t_2_1)
    # prims.copy_(t36, t_3_1)
    # prims.copy_(t56, elem)
  del elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1
  return None

However, its result is not equal to that of the following version, which calls thunder.prims.copy_ explicitly:

def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        step_t = thunder.prims.copy_(step_t + 1, step_t)

        # Decay the first and second moment running average coefficient
        exp_avg = thunder.prims.copy_(exp_avg * (0.9) + grad * (0.1), exp_avg)
        # exp_avg = exp_avg * (beta1) + grad * (1 - beta1)

        exp_avg_sq = thunder.prims.copy_((exp_avg_sq * 0.999) + (0.001) * grad * grad, exp_avg_sq)
        # exp_avg_sq = (exp_avg_sq * beta2) + (1 - beta2) * grad * grad

        step = step_t
        bias_correction1 = 1 - 0.9**step
        bias_correction2 = 1 - 0.999**step

        step_size = 0.001 / bias_correction1

        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)  # sqrt helper from torch.optim; equivalent to bias_correction2.sqrt() for tensor inputs

        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + (1e-8)

        param = thunder.prims.copy_(param + (-step_size) * exp_avg / denom, param)
trace:
def computation(res, elem, grad, t_1_1, exp_avg, t_2_1, exp_avg_sq, t_3_1, step_t, t_4_1):
  # res: "cuda:0 f32[1024, 1024]"
  # elem: "cuda:0 f32[1024, 1024]"
  # grad: "cuda:0 f32[1024, 1024]"
  # t_1_1: "cuda:0 f32[1024, 1024]"
  # exp_avg: "cuda:0 f32[1024, 1024]"
  # t_2_1: "cuda:0 f32[1024, 1024]"
  # exp_avg_sq: "cuda:0 f32[1024, 1024]"
  # t_3_1: "cuda:0 f32[1024, 1024]"
  # step_t: "cuda:0 f32[]"
  # t_4_1: "cuda:0 f32[]"
  nvFusion0(elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1)
    # result = prims.add(step_t, 1.0)  # result: "cuda:0 f32[]"
    # step = prims.copy_(result, step_t)  # step: "cuda:0 f32[]"
    # a = prims.mul(exp_avg, 0.9)  # a: "cuda:0 f32[1024, 1024]"
    # b = prims.mul(grad, 0.1)  # b: "cuda:0 f32[1024, 1024]"
    # t4 = prims.add(a, b)  # t4: "cuda:0 f32[1024, 1024]"
    # t5 = prims.copy_(t4, exp_avg)  # t5: "cuda:0 f32[1024, 1024]"
    # t6 = prims.mul(exp_avg_sq, 0.999)  # t6: "cuda:0 f32[1024, 1024]"
    # t7 = prims.mul(0.001, grad)  # t7: "cuda:0 f32[1024, 1024]"
    # t8 = prims.mul(t7, grad)  # t8: "cuda:0 f32[1024, 1024]"
    # t9 = prims.add(t6, t8)  # t9: "cuda:0 f32[1024, 1024]"
    # t10 = prims.copy_(t9, exp_avg_sq)  # t10: "cuda:0 f32[1024, 1024]"
    # t11 = prims.pow(0.9, step)  # t11: "cuda:0 f32[]"
    # bias_correction1 = prims.sub(1.0, t11)  # bias_correction1: "cuda:0 f32[]"
    # t13 = prims.pow(0.999, step)  # t13: "cuda:0 f32[]"
    # bias_correction2 = prims.sub(1.0, t13)  # bias_correction2: "cuda:0 f32[]"
    # step_size = prims.div(0.001, bias_correction1)  # step_size: "cuda:0 f32[]"
    # bias_correction2_sqrt = prims.sqrt(bias_correction2)  # bias_correction2_sqrt: "cuda:0 f32[]"
    # t17 = prims.sqrt(t10)  # t17: "cuda:0 f32[1024, 1024]"
    # t18 = prims.broadcast_in_dim(bias_correction2_sqrt, (1024, 1024), ())  # t18: "cuda:0 f32[1024, 1024]"
    # t19 = prims.div(t17, t18)  # t19: "cuda:0 f32[1024, 1024]"
    # denom = prims.add(t19, 1e-08)  # denom: "cuda:0 f32[1024, 1024]"
    # t21 = prims.neg(step_size)  # t21: "cuda:0 f32[]"
    # t22 = prims.broadcast_in_dim(t21, (1024, 1024), ())  # t22: "cuda:0 f32[1024, 1024]"
    # t23 = prims.mul(t22, t5)  # t23: "cuda:0 f32[1024, 1024]"
    # t24 = prims.div(t23, denom)  # t24: "cuda:0 f32[1024, 1024]"
    # t25 = prims.add(res, t24)  # t25: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t25, res)
    # t27 = prims.add(t_4_1, 1.0)  # t27: "cuda:0 f32[]"
    # t28 = prims.copy_(t27, t_4_1)  # t28: "cuda:0 f32[]"
    # t29 = prims.mul(t_2_1, 0.9)  # t29: "cuda:0 f32[1024, 1024]"
    # t30 = prims.mul(t_1_1, 0.1)  # t30: "cuda:0 f32[1024, 1024]"
    # t31 = prims.add(t29, t30)  # t31: "cuda:0 f32[1024, 1024]"
    # t32 = prims.copy_(t31, t_2_1)  # t32: "cuda:0 f32[1024, 1024]"
    # t33 = prims.mul(t_3_1, 0.999)  # t33: "cuda:0 f32[1024, 1024]"
    # t34 = prims.mul(0.001, t_1_1)  # t34: "cuda:0 f32[1024, 1024]"
    # t35 = prims.mul(t34, t_1_1)  # t35: "cuda:0 f32[1024, 1024]"
    # t36 = prims.add(t33, t35)  # t36: "cuda:0 f32[1024, 1024]"
    # t37 = prims.copy_(t36, t_3_1)  # t37: "cuda:0 f32[1024, 1024]"
    # t38 = prims.pow(0.9, t28)  # t38: "cuda:0 f32[]"
    # t39 = prims.sub(1.0, t38)  # t39: "cuda:0 f32[]"
    # t40 = prims.pow(0.999, t28)  # t40: "cuda:0 f32[]"
    # x = prims.sub(1.0, t40)  # x: "cuda:0 f32[]"
    # tos = prims.div(0.001, t39)  # tos: "cuda:0 f32[]"
    # t43 = prims.sqrt(x)  # t43: "cuda:0 f32[]"
    # t44 = prims.sqrt(t37)  # t44: "cuda:0 f32[1024, 1024]"
    # t45 = prims.broadcast_in_dim(t43, (1024, 1024), ())  # t45: "cuda:0 f32[1024, 1024]"
    # t46 = prims.div(t44, t45)  # t46: "cuda:0 f32[1024, 1024]"
    # t47 = prims.add(t46, 1e-08)  # t47: "cuda:0 f32[1024, 1024]"
    # t48 = prims.neg(tos)  # t48: "cuda:0 f32[]"
    # t49 = prims.broadcast_in_dim(t48, (1024, 1024), ())  # t49: "cuda:0 f32[1024, 1024]"
    # t50 = prims.mul(t49, t32)  # t50: "cuda:0 f32[1024, 1024]"
    # t51 = prims.div(t50, t47)  # t51: "cuda:0 f32[1024, 1024]"
    # t52 = prims.add(elem, t51)  # t52: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t52, elem)
  del elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1
  return None

The latter version gives the correct result (the same result as torch.optim.Adam(foreach=False), which calls the original _single_tensor_adam).
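
A rough way to check this (my own sketch, not part of the original report; it assumes a fresh optimizer state on CUDA, the hard-coded hyperparameters lr=0.001, betas=(0.9, 0.999), eps=1e-8, and that _dispatch_sqrt is available):

import torch
import thunder

with torch.device("cuda"):
    p = torch.randn(1024, 1024)
    g = torch.randn_like(p)

# reference: one step of PyTorch's single-tensor Adam
p_ref = p.clone()
p_ref.grad = g.clone()
ref_opt = torch.optim.Adam([p_ref], lr=0.001, betas=(0.9, 0.999), eps=1e-8, foreach=False)
ref_opt.step()

# one step of the copy_-based implementation above, jitted
jitted = thunder.jit(_single_tensor_adam, disable_inplace_copy_check=True)
jitted([p], [g], [torch.zeros_like(p)], [torch.zeros_like(p)], [torch.tensor(0.0, device=p.device)])

torch.testing.assert_close(p, p_ref)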

shino16 avatar Jul 11 '24 03:07 shino16

Not exactly a particular use case, but I think it might be helpful to look at test_functionalization.py from PyTorch to understand which cases are tested there and see whether they work with thunder.

kshitij12345 avatar Jul 22 '24 13:07 kshitij12345

https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208317779

Sorry for jumping on this one late.

In the comment above

this case works, but the signature of computation(res, elem, step_t, t_1_1) does not look great given that the input g's signature is params: list[torch.Tensor], steps: list[torch.Tensor].

Is there anything wrong with that? The prologue trace is supposed to flatten/clean up container objects. Since tensors are just references to buffers, flattening them shouldn't interfere with the actual buffer updates.
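
A small illustration of that point (plain PyTorch, my own example): the flattened names still reference the same storage as the elements of the original list, so in-place updates through them are visible to the caller.

import torch

params = [torch.zeros(2), torch.zeros(2)]

def flattened_computation(p0, p1):
    p0.add_(1)  # writes to params[0]'s buffer through the flattened reference
    p1.add_(2)  # writes to params[1]'s buffer

flattened_computation(*params)
print(params)  # [tensor([1., 1.]), tensor([2., 2.])] -- the original list sees both updates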

jjsjann123 avatar Aug 07 '24 19:08 jjsjann123

Is there anything wrong with that?

No, but it would look better if the variable names were something like params_0, grads_1, exp_avgs_2.

crcrpar avatar Aug 07 '24 20:08 crcrpar

I'm not pushing to support this case, but it is maybe one of the examples that could use a loud error when in-place is applied to an ambiguous view (or maybe we should just back out of trying to remove the in-place update here). See also the small reshape sketch after the snippet below.

If functionalization is skipped, the program executes correctly.

import thunder
import torch
dtype = torch.float32


def foo(a):
  b = a.reshape(6, 4)
  b.add_(1)
  return b
 
jfoo = thunder.jit(foo, skip_inplace_functionalization=False)
 
a = torch.ones(2, 3, 4, device="cuda")
a_ref = a.clone()
 
out = jfoo(a)
out_ref = foo(a_ref)
 
assert(a.allclose(a_ref))
assert(out.allclose(out_ref))
 
print("====================")
print("--- run1 ---")
print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])
print("\n\tcompute last:\n", thunder.last_traces(jfoo)[-1])
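
A small sketch (plain PyTorch, my own example) of why "reshape + in-place" is ambiguous: reshape returns a view when the memory layout allows it and a copy otherwise, so whether the in-place op reaches the original tensor depends on contiguity.

import torch

a = torch.ones(2, 3, 4)
b = a.reshape(6, 4)       # contiguous input: reshape returns a view
b.add_(1)
print(a[0, 0, 0].item())  # 2.0 -- the update is visible through `a`

c = torch.ones(3, 4).t()  # non-contiguous (transposed) input
d = c.reshape(6, 2)       # a view is impossible here, so reshape copies
d.add_(1)
print(c[0, 0].item())     # 1.0 -- `c` is unchanged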

jjsjann123 avatar Aug 07 '24 20:08 jjsjann123

@crcrpar, does the example in https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208655632 work correctly now with the nvFuser executor? If so, could you please update that comment with a link to the pull request that fixed the problem? I've tried it now on the CPU device in Colab, and here's the trace I got:

def computation(a):
  # a: "cpu f32[4]"
  t0 = torch.exp(a)  # t0: "cpu f32[4]"
    # t0 = ltorch.exp(a)  # t0: "cpu f32[4]"
      # t0 = prims.exp(a)  # t0: "cpu f32[4]"
  t2 = torch.sin(t0)  # t2: "cpu f32[4]"
    # t2 = ltorch.sin(t0)  # t2: "cpu f32[4]"
      # t2 = prims.sin(t0)  # t2: "cpu f32[4]"
  del t0
  copy_(t2, a)
  del t2
  return a

Now a is correctly updated from t2 and not t1 as recorded in https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208655632.

UPDATE: I tried on a GPU instance in Colab and am now hitting this assertion: https://github.com/Lightning-AI/lightning-thunder/blob/d6995e585e4c58d021b2af2bdd613fc7a949f3d4/thunder/executors/nvfuserex_impl.py#L724

The root cause is probably the same as in https://github.com/Lightning-AI/lightning-thunder/issues/912.

It happens even with a CPU input:
import torch
import thunder

def f(a):
    return a.exp_().sin_()

x = torch.randn(4, device="cpu", requires_grad=False)
x_ref = x.clone().detach().requires_grad_(False)

y_ref = f(x_ref)
jitted = thunder.jit(f)
y = jitted(x)

IvanYashchuk avatar Aug 12 '24 12:08 IvanYashchuk

Re: https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2274293433.

I created a separate issue to discuss reshape+in-place: https://github.com/Lightning-AI/lightning-thunder/issues/957. I invite everyone to put down their thoughts on whether we should support this pattern.

IvanYashchuk avatar Aug 12 '24 13:08 IvanYashchuk