Desired inplace idioms (add yours here)
We do support some in-place operations thanks to @crcrpar 's great patches. This has enabled running additional models already, but support is limited to a few cases of particular interest.
However, we want to see and prioritize other important in-place idioms, so we would appreciate it if you chimed in with the needs that you or your favourite model have.
Please do try to post:
- a minimal code snippet,
- the name of a particular model that needs this, if you are aware of one.
We don't want to support every corner case of in-place (e.g. "does reshape produce a view or not?"), but we do care a lot about enabling users and models. Thank you!
P.S.: Also "don't need to support but should error" is a thing.
If you have traces as well as an idiom, could you please also paste thunder.last_traces(jitted)[1:3], which are the traces generated by functionalization?
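For example, a report could look roughly like this (a minimal sketch; my_fn and the inputs are just placeholders for your idiom):

import torch
import thunder

def my_fn(a, b):  # placeholder for the in-place idiom you care about
    a.add_(b)
    return a

jitted = thunder.jit(my_fn)
jitted(torch.randn(4), torch.randn(4))

# the traces produced by functionalization
for trace in thunder.last_traces(jitted)[1:3]:
    print(trace)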
thunder.jit'ing the following function with the nvfuserex executor fails with the message below.
By moving the copy for a += b to the end of the trace and replacing a += b with t0 = a + b, I would expect it to work.
import torch

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    a += b
    c = torch.exp(a)
    d = torch.tanh(b)
    e = c.view(-1)
    e.add_(d.flatten())
    d.div_(a)
    return c, d, e / 2.0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 55, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 24, in main
    c, d, e = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 676, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 615, in get_computation_and_inputs
    thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 86, in _inplace_copy_sanity_check
    check(copy_to_out, "output")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 80, in check
    raise NotImplementedError(
NotImplementedError: t8 = prims.div(d, t1) # t8: "cuda:0 f32[2, 2]" trying to use t1 (the output of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
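For reference, here is a hand-functionalized sketch of the rewrite described above: each in-place op is replaced by its out-of-place counterpart and the copy back into a is deferred to the end (the intermediate names are made up, and only values are reproduced, not the aliasing between the returned c and e):

import torch

def f_functionalized(a: torch.Tensor, b: torch.Tensor):
    t0 = a + b                        # stands in for `a += b`
    c0 = torch.exp(t0)
    d0 = torch.tanh(b)
    e0 = c0.view(-1) + d0.flatten()   # stands in for `e.add_(d.flatten())`
    c1 = e0.view(c0.shape)            # `e` aliases `c`, so `c` must see the update
    d1 = d0 / t0                      # stands in for `d.div_(a)` on the updated `a`
    a.copy_(t0)                       # the deferred copy for `a += b`
    return c1, d1, e0 / 2.0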
Failing loudly (and not obscurely), as in the error above, is super important. Let's keep this issue open to keep track of that aspect.
I would prioritize enumerating corner cases as much as we can, and creating good unhappy paths for them (with sane error messages). Then we address them and lift limitations where it makes sense. WDYT @crcrpar ?
Optimizers' in-place updates might be of concern when designing a scheme, as Adam already carries three copies of the model state algorithmically.
This case works, but the signature of computation(res, elem, step_t, t_1_1) does not look great, given that g's input signature is params: list[torch.Tensor], steps: list[torch.Tensor].
import torch
import thunder

@thunder.jit
def g(params, steps):
    for i, param in enumerate(params):
        step_t = steps[i]
        step_t.add_(1)
        param.add_(step_t)
    return params

with torch.device('cuda'):
    params = [torch.tensor(0), torch.tensor(0)]
    steps = [torch.tensor(0), torch.tensor(0)]

g(params, steps)
print(params, steps)
print(thunder.last_traces(g)[-1])
([tensor(1, device='cuda:0'), tensor(1, device='cuda:0')],
[tensor(1, device='cuda:0'), tensor(1, device='cuda:0')])
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(res, elem, step_t, t_1_1):
  # res: "cuda:0 i64[]"
  # elem: "cuda:0 i64[]"
  # step_t: "cuda:0 i64[]"
  # t_1_1: "cuda:0 i64[]"
  nvFusion0(elem, res, step_t, t_1_1)
    # t0 = prims.add(step_t, 1) # t0: "cuda:0 i64[]"
    # t2 = prims.add(res, t0) # t2: "cuda:0 i64[]"
    # t4 = prims.add(t_1_1, 1) # t4: "cuda:0 i64[]"
    # t6 = prims.add(elem, t4) # t6: "cuda:0 i64[]"
    # prims.copy_(t0, step_t)
    # prims.copy_(t2, res)
    # prims.copy_(t4, t_1_1)
    # prims.copy_(t6, elem)
  del elem, res, step_t, t_1_1
  return [_torch_Tensor_0, _torch_Tensor_1]
Multiple in-place ops whose operand is the function's argument are not handled appropriately.
import torch
import thunder

def f(a):
    return a.exp_().sin_()

if __name__ == "__main__":
    x = torch.randn(4, device="cuda", requires_grad=False)
    x_ref = x.clone().detach().requires_grad_(False)
    y_ref = f(x_ref)

    jitted = thunder.jit(f)
    y = jitted(x)
    print(thunder.last_traces(jitted)[-1])
gives
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[4]"
  [t2] = nvFusion0(a)
    # t0 = prims.exp(a) # t0: "cuda:0 f32[4]"
    # t2 = prims.sin(t0) # t2: "cuda:0 f32[4]"
    # prims.copy_(t0, a)
  del a
  return t2
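Note that the trace above copies t0 (the result of exp) back into a, while the chained in-place ops should leave sin(exp(a)) there. A minimal out-of-place reference for what the chain should compute, as a plain-PyTorch sanity-check sketch:

import torch

def f_ref(a: torch.Tensor) -> torch.Tensor:
    # exp_ followed by sin_ should leave sin(exp(a)) in `a`,
    # so only the final result needs to be copied back.
    result = torch.sin(torch.exp(a))
    a.copy_(result)
    return a

x = torch.randn(4)
expected = x.clone().exp_().sin_()
assert torch.allclose(f_ref(x), expected)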
To make a jittable implementation of Adam, I want this to be sound.
import torch
from torch import Tensor

import thunder


def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        step_t.add_(1)
        exp_avg.mul_(0.9).add_(grad * 0.1)
        exp_avg_sq.mul_(0.999).addcmul_(grad, grad, value=0.001)

        bias_correction1 = 1 - 0.9**step_t
        bias_correction2 = 1 - 0.999**step_t
        step_size = 0.001 / bias_correction1
        bias_correction2_sqrt = bias_correction2.sqrt()
        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + 1e-8
        param.addcdiv_(exp_avg, denom, value=-step_size)


dim = 1024
n_param = 2
with torch.device('cuda'):
    params = [torch.randn(dim, dim) for _ in range(n_param)]
    grads = [torch.randn_like(param) for param in params]
    exp_avgs = [torch.zeros_like(param) for param in params]
    exp_avg_sqs = [torch.zeros_like(param) for param in params]
    steps = [torch.tensor(0.0) for param in params]

jitted_step = thunder.jit(_single_tensor_adam, disable_inplace_copy_check=True)
jitted_step(params, grads, exp_avgs, exp_avg_sqs, steps)
print(thunder.last_traces(jitted_step)[-1])
trace:
def computation(res, elem, grad, t_1_1, exp_avg, t_2_1, exp_avg_sq, t_3_1, step_t, t_4_1):
  # res: "cuda:0 f32[1024, 1024]"
  # elem: "cuda:0 f32[1024, 1024]"
  # grad: "cuda:0 f32[1024, 1024]"
  # t_1_1: "cuda:0 f32[1024, 1024]"
  # exp_avg: "cuda:0 f32[1024, 1024]"
  # t_2_1: "cuda:0 f32[1024, 1024]"
  # exp_avg_sq: "cuda:0 f32[1024, 1024]"
  # t_3_1: "cuda:0 f32[1024, 1024]"
  # step_t: "cuda:0 f32[]"
  # t_4_1: "cuda:0 f32[]"
  nvFusion0(elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1)
    # t0 = prims.add(step_t, 1.0) # t0: "cuda:0 f32[]"
    # t2 = prims.mul(exp_avg, 0.9) # t2: "cuda:0 f32[1024, 1024]"
    # result = prims.mul(grad, 0.1) # result: "cuda:0 f32[1024, 1024]"
    # t5 = prims.add(t2, result) # t5: "cuda:0 f32[1024, 1024]"
    # t7 = prims.mul(exp_avg_sq, 0.999) # t7: "cuda:0 f32[1024, 1024]"
    # t9 = prims.mul(grad, grad) # t9: "cuda:0 f32[1024, 1024]"
    # t10 = prims.mul(0.001, t9) # t10: "cuda:0 f32[1024, 1024]"
    # t11 = prims.add(t7, t10) # t11: "cuda:0 f32[1024, 1024]"
    # b = prims.pow(0.9, t0) # b: "cuda:0 f32[]"
    # bias_correction1 = prims.sub(1.0, b) # bias_correction1: "cuda:0 f32[]"
    # t15 = prims.pow(0.999, t0) # t15: "cuda:0 f32[]"
    # bias_correction2 = prims.sub(1.0, t15) # bias_correction2: "cuda:0 f32[]"
    # step_size = prims.div(0.001, bias_correction1) # step_size: "cuda:0 f32[]"
    # bias_correction2_sqrt = prims.sqrt(bias_correction2) # bias_correction2_sqrt: "cuda:0 f32[]"
    # a = prims.sqrt(t11) # a: "cuda:0 f32[1024, 1024]"
    # t20 = prims.broadcast_in_dim(bias_correction2_sqrt, (1024, 1024), ()) # t20: "cuda:0 f32[1024, 1024]"
    # t21 = prims.div(a, t20) # t21: "cuda:0 f32[1024, 1024]"
    # denom = prims.add(t21, 1e-08) # denom: "cuda:0 f32[1024, 1024]"
    # tos2 = prims.neg(step_size) # tos2: "cuda:0 f32[]"
    # t24 = prims.div(t5, denom) # t24: "cuda:0 f32[1024, 1024]"
    # t25 = prims.broadcast_in_dim(tos2, (1024, 1024), ()) # t25: "cuda:0 f32[1024, 1024]"
    # t26 = prims.mul(t25, t24) # t26: "cuda:0 f32[1024, 1024]"
    # t27 = prims.add(res, t26) # t27: "cuda:0 f32[1024, 1024]"
    # t29 = prims.add(t_4_1, 1.0) # t29: "cuda:0 f32[]"
    # t31 = prims.mul(t_2_1, 0.9) # t31: "cuda:0 f32[1024, 1024]"
    # t33 = prims.mul(t_1_1, 0.1) # t33: "cuda:0 f32[1024, 1024]"
    # t34 = prims.add(t31, t33) # t34: "cuda:0 f32[1024, 1024]"
    # t36 = prims.mul(t_3_1, 0.999) # t36: "cuda:0 f32[1024, 1024]"
    # t38 = prims.mul(t_1_1, t_1_1) # t38: "cuda:0 f32[1024, 1024]"
    # t39 = prims.mul(0.001, t38) # t39: "cuda:0 f32[1024, 1024]"
    # t40 = prims.add(t36, t39) # t40: "cuda:0 f32[1024, 1024]"
    # t42 = prims.pow(0.9, t29) # t42: "cuda:0 f32[]"
    # t43 = prims.sub(1.0, t42) # t43: "cuda:0 f32[]"
    # t44 = prims.pow(0.999, t29) # t44: "cuda:0 f32[]"
    # t45 = prims.sub(1.0, t44) # t45: "cuda:0 f32[]"
    # tos = prims.div(0.001, t43) # tos: "cuda:0 f32[]"
    # t47 = prims.sqrt(t45) # t47: "cuda:0 f32[]"
    # t48 = prims.sqrt(t40) # t48: "cuda:0 f32[1024, 1024]"
    # t49 = prims.broadcast_in_dim(t47, (1024, 1024), ()) # t49: "cuda:0 f32[1024, 1024]"
    # t50 = prims.div(t48, t49) # t50: "cuda:0 f32[1024, 1024]"
    # t51 = prims.add(t50, 1e-08) # t51: "cuda:0 f32[1024, 1024]"
    # t52 = prims.neg(tos) # t52: "cuda:0 f32[]"
    # t53 = prims.div(t34, t51) # t53: "cuda:0 f32[1024, 1024]"
    # t54 = prims.broadcast_in_dim(t52, (1024, 1024), ()) # t54: "cuda:0 f32[1024, 1024]"
    # t55 = prims.mul(t54, t53) # t55: "cuda:0 f32[1024, 1024]"
    # t56 = prims.add(elem, t55) # t56: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t0, step_t)
    # prims.copy_(t2, exp_avg)
    # prims.copy_(t7, exp_avg_sq)
    # prims.copy_(t27, res)
    # prims.copy_(t29, t_4_1)
    # prims.copy_(t31, t_2_1)
    # prims.copy_(t36, t_3_1)
    # prims.copy_(t56, elem)
  del elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1
  return None
However, its result is not equal to that of the following version, which uses explicit thunder.prims.copy_ calls:
def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        step_t = thunder.prims.copy_(step_t + 1, step_t)

        # Decay the first and second moment running average coefficient
        exp_avg = thunder.prims.copy_(exp_avg * (0.9) + grad * (0.1), exp_avg)
        # exp_avg = exp_avg * (beta1) + grad * (1 - beta1)
        exp_avg_sq = thunder.prims.copy_((exp_avg_sq * 0.999) + (0.001) * grad * grad, exp_avg_sq)
        # exp_avg_sq = (exp_avg_sq * beta2) + (1 - beta2) * grad * grad

        step = step_t
        bias_correction1 = 1 - 0.9**step
        bias_correction2 = 1 - 0.999**step
        step_size = 0.001 / bias_correction1
        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + (1e-8)
        param = thunder.prims.copy_(param + (-step_size) * exp_avg / denom, param)
trace:
def computation(res, elem, grad, t_1_1, exp_avg, t_2_1, exp_avg_sq, t_3_1, step_t, t_4_1):
  # res: "cuda:0 f32[1024, 1024]"
  # elem: "cuda:0 f32[1024, 1024]"
  # grad: "cuda:0 f32[1024, 1024]"
  # t_1_1: "cuda:0 f32[1024, 1024]"
  # exp_avg: "cuda:0 f32[1024, 1024]"
  # t_2_1: "cuda:0 f32[1024, 1024]"
  # exp_avg_sq: "cuda:0 f32[1024, 1024]"
  # t_3_1: "cuda:0 f32[1024, 1024]"
  # step_t: "cuda:0 f32[]"
  # t_4_1: "cuda:0 f32[]"
  nvFusion0(elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1)
    # result = prims.add(step_t, 1.0) # result: "cuda:0 f32[]"
    # step = prims.copy_(result, step_t) # step: "cuda:0 f32[]"
    # a = prims.mul(exp_avg, 0.9) # a: "cuda:0 f32[1024, 1024]"
    # b = prims.mul(grad, 0.1) # b: "cuda:0 f32[1024, 1024]"
    # t4 = prims.add(a, b) # t4: "cuda:0 f32[1024, 1024]"
    # t5 = prims.copy_(t4, exp_avg) # t5: "cuda:0 f32[1024, 1024]"
    # t6 = prims.mul(exp_avg_sq, 0.999) # t6: "cuda:0 f32[1024, 1024]"
    # t7 = prims.mul(0.001, grad) # t7: "cuda:0 f32[1024, 1024]"
    # t8 = prims.mul(t7, grad) # t8: "cuda:0 f32[1024, 1024]"
    # t9 = prims.add(t6, t8) # t9: "cuda:0 f32[1024, 1024]"
    # t10 = prims.copy_(t9, exp_avg_sq) # t10: "cuda:0 f32[1024, 1024]"
    # t11 = prims.pow(0.9, step) # t11: "cuda:0 f32[]"
    # bias_correction1 = prims.sub(1.0, t11) # bias_correction1: "cuda:0 f32[]"
    # t13 = prims.pow(0.999, step) # t13: "cuda:0 f32[]"
    # bias_correction2 = prims.sub(1.0, t13) # bias_correction2: "cuda:0 f32[]"
    # step_size = prims.div(0.001, bias_correction1) # step_size: "cuda:0 f32[]"
    # bias_correction2_sqrt = prims.sqrt(bias_correction2) # bias_correction2_sqrt: "cuda:0 f32[]"
    # t17 = prims.sqrt(t10) # t17: "cuda:0 f32[1024, 1024]"
    # t18 = prims.broadcast_in_dim(bias_correction2_sqrt, (1024, 1024), ()) # t18: "cuda:0 f32[1024, 1024]"
    # t19 = prims.div(t17, t18) # t19: "cuda:0 f32[1024, 1024]"
    # denom = prims.add(t19, 1e-08) # denom: "cuda:0 f32[1024, 1024]"
    # t21 = prims.neg(step_size) # t21: "cuda:0 f32[]"
    # t22 = prims.broadcast_in_dim(t21, (1024, 1024), ()) # t22: "cuda:0 f32[1024, 1024]"
    # t23 = prims.mul(t22, t5) # t23: "cuda:0 f32[1024, 1024]"
    # t24 = prims.div(t23, denom) # t24: "cuda:0 f32[1024, 1024]"
    # t25 = prims.add(res, t24) # t25: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t25, res)
    # t27 = prims.add(t_4_1, 1.0) # t27: "cuda:0 f32[]"
    # t28 = prims.copy_(t27, t_4_1) # t28: "cuda:0 f32[]"
    # t29 = prims.mul(t_2_1, 0.9) # t29: "cuda:0 f32[1024, 1024]"
    # t30 = prims.mul(t_1_1, 0.1) # t30: "cuda:0 f32[1024, 1024]"
    # t31 = prims.add(t29, t30) # t31: "cuda:0 f32[1024, 1024]"
    # t32 = prims.copy_(t31, t_2_1) # t32: "cuda:0 f32[1024, 1024]"
    # t33 = prims.mul(t_3_1, 0.999) # t33: "cuda:0 f32[1024, 1024]"
    # t34 = prims.mul(0.001, t_1_1) # t34: "cuda:0 f32[1024, 1024]"
    # t35 = prims.mul(t34, t_1_1) # t35: "cuda:0 f32[1024, 1024]"
    # t36 = prims.add(t33, t35) # t36: "cuda:0 f32[1024, 1024]"
    # t37 = prims.copy_(t36, t_3_1) # t37: "cuda:0 f32[1024, 1024]"
    # t38 = prims.pow(0.9, t28) # t38: "cuda:0 f32[]"
    # t39 = prims.sub(1.0, t38) # t39: "cuda:0 f32[]"
    # t40 = prims.pow(0.999, t28) # t40: "cuda:0 f32[]"
    # x = prims.sub(1.0, t40) # x: "cuda:0 f32[]"
    # tos = prims.div(0.001, t39) # tos: "cuda:0 f32[]"
    # t43 = prims.sqrt(x) # t43: "cuda:0 f32[]"
    # t44 = prims.sqrt(t37) # t44: "cuda:0 f32[1024, 1024]"
    # t45 = prims.broadcast_in_dim(t43, (1024, 1024), ()) # t45: "cuda:0 f32[1024, 1024]"
    # t46 = prims.div(t44, t45) # t46: "cuda:0 f32[1024, 1024]"
    # t47 = prims.add(t46, 1e-08) # t47: "cuda:0 f32[1024, 1024]"
    # t48 = prims.neg(tos) # t48: "cuda:0 f32[]"
    # t49 = prims.broadcast_in_dim(t48, (1024, 1024), ()) # t49: "cuda:0 f32[1024, 1024]"
    # t50 = prims.mul(t49, t32) # t50: "cuda:0 f32[1024, 1024]"
    # t51 = prims.div(t50, t47) # t51: "cuda:0 f32[1024, 1024]"
    # t52 = prims.add(elem, t51) # t52: "cuda:0 f32[1024, 1024]"
    # prims.copy_(t52, elem)
  del elem, exp_avg, exp_avg_sq, grad, res, step_t, t_1_1, t_2_1, t_3_1, t_4_1
  return None
The latter version gives the correct result (the same result as torch.optim.Adam(foreach=False) which calls the original _single_tensor_adam).
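One way to check that claim is to clone the freshly initialized states before taking a step, run torch.optim.Adam(foreach=False) with the same hardcoded hyperparameters on the clones, and compare. A sketch (params_init is a hypothetical name for the pre-step clones of params):

import torch

# hypothetical clones taken from the freshly initialized states,
# *before* jitted_step was called
params_ref = [p.clone().requires_grad_(True) for p in params_init]
for p, g in zip(params_ref, grads):
    p.grad = g.clone()

ref_opt = torch.optim.Adam(params_ref, lr=0.001, betas=(0.9, 0.999),
                           eps=1e-8, foreach=False)
ref_opt.step()

# `params` are the tensors that jitted_step updated in place
for p, p_ref in zip(params, params_ref):
    torch.testing.assert_close(p, p_ref)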
Not exactly a particular use-case, but I think it might be helpful to look at test_functionalization.py from PyTorch to understand the cases that are tested and see whether they work with thunder.
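For instance, the kind of pattern exercised there typically mixes views and mutation; a representative sketch (not an actual test copied from that file) to try under thunder.jit could be:

import torch
import thunder

def view_mutation(x):
    y = x.view(-1)
    y.add_(1)  # the mutation through the view should also be visible in x
    return x, y

jitted = thunder.jit(view_mutation)

a = torch.zeros(2, 3)
a_ref = a.clone()
out = jitted(a)
out_ref = view_mutation(a_ref)
torch.testing.assert_close(a, a_ref)  # checks the input buffer was updated the same way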
https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208317779
Sorry for jumping on this one late.
In the comment above
this case works, but the signature of computation(res, elem, step_t, t_1_1) does not look great given that the input g's signature is params: list[torch.Tensor], steps: list[torch.Tensor].
Is there anything wrong with that? The prologue trace is supposed to flatten/clean up container objects. Since tensors are just references to buffers, flattening them shouldn't interfere with the actual buffer update.
Is there anything wrong with that?
No, but it would look better if the variable names were like params_0, grads_1, exp_avgs_2.
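To see where the current names come from, one can inspect the prologue trace, where the containers get unpacked. A sketch, using the jitted g from the earlier comment:

# after g has been called at least once
print(thunder.last_prologue_traces(g)[-1])  # shows how the params/steps lists are flattened
print(thunder.last_traces(g)[-1])           # the computation trace consuming the flattened tensors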
I'm not pushing to support this case, but it is maybe one of the examples that could use a loud error when in-place is applied to an ambiguous view (or maybe we should just back out of trying to remove the in-place update here).
If functionalization is skipped, the program executes correctly.
import thunder
import torch

dtype = torch.float32

def foo(a):
    b = a.reshape(6, 4)
    b.add_(1)
    return b

jfoo = thunder.jit(foo, skip_inplace_functionalization=False)

a = torch.ones(2, 3, 4, device="cuda")
a_ref = a.clone()

out = jfoo(a)
out_ref = foo(a_ref)

assert(a.allclose(a_ref))
assert(out.allclose(out_ref))

print("====================")
print("--- run1 ---")
print("\n\tprologue:\n", thunder.last_prologue_traces(jfoo)[0])
print("\n\tcompute:\n", thunder.last_traces(jfoo)[0])
print("\n\tcompute last:\n", thunder.last_traces(jfoo)[-1])
@crcrpar, does the example in https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208655632 now work correctly with the nvFuser executor? If so, can you please update that comment with a link to the pull request that fixed the problem? I've tried it now on the CPU device in Colab, and here's the trace that I got:
def computation(a):
  # a: "cpu f32[4]"
  t0 = torch.exp(a) # t0: "cpu f32[4]"
    # t0 = ltorch.exp(a) # t0: "cpu f32[4]"
    # t0 = prims.exp(a) # t0: "cpu f32[4]"
  t2 = torch.sin(t0) # t2: "cpu f32[4]"
    # t2 = ltorch.sin(t0) # t2: "cpu f32[4]"
    # t2 = prims.sin(t0) # t2: "cpu f32[4]"
  del t0
  copy_(t2, a)
  del t2
  return a
Now a is correctly updated from t2, and not from t0 as recorded in https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2208655632.
UPDATE: I tried on a GPU instance in Colab and am now hitting this assertion: https://github.com/Lightning-AI/lightning-thunder/blob/d6995e585e4c58d021b2af2bdd613fc7a949f3d4/thunder/executors/nvfuserex_impl.py#L724
The root cause is probably the same as in https://github.com/Lightning-AI/lightning-thunder/issues/912.
It happens even with a CPU input:
import torch
import thunder

def f(a):
    return a.exp_().sin_()

x = torch.randn(4, device="cpu", requires_grad=False)
x_ref = x.clone().detach().requires_grad_(False)
y_ref = f(x_ref)

jitted = thunder.jit(f)
y = jitted(x)
Re: https://github.com/Lightning-AI/lightning-thunder/issues/657#issuecomment-2274293433.
I created a separate issue to discuss reshape+in-place: https://github.com/Lightning-AI/lightning-thunder/issues/957. I invite everyone to put down their thoughts on whether we should support this pattern.