Add `prims.copy_to_out_`
Fixes #1173. This adds a new primitive `prims.copy_to_out_(computed, *, out)`, which is used instead of `prims.copy_` for the write-back in in-place ops. Unlike `prims.copy_`, `prims.copy_to_out_(computed, *, out)` assumes that `computed` is not used by subsequent ops, so `out` can simply alias `computed`.
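As a quick way to see the difference (a minimal sketch, assuming a CUDA device is available; `f`, `x`, and `y` below are illustrative names, not code from this PR), one can jit a small in-place update and look at the write-back primitive in the execution trace:

```python
import torch
import thunder

def f(x, y):
    x.add_(y)  # in-place update whose write-back is what we want to inspect
    return x

jf = thunder.jit(f)
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
jf(x, y)

# With this PR, the write-back in the execution trace is expected to use
# prims.copy_to_out_(..., out=x) instead of prims.copy_.
print(thunder.last_traces(jf)[-1])
```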
Benchmarked commits:
- main: 63887b3d020932d5a45ade4cac96f6376e59d602
- #1193: 2781a201044ec3671f685a493052111a18c45e53
| | compilation (s) | execution (ms) |
|---|---|---|
| eager | 0.0 | 10.86 |
| `torch.compile(adam.step, backend=thunder)`, main | 94.0 | 11.52 |
| `torch.compile(adam.step, backend=thunder)`, #1193 | 48.0 | 6.00 |
The rule for `prims.copy_to_out_` is (link):
# WARN: `computed` must be an intermediate tensor used solely for this `copy_to_out_` call,
# e.g. copy_to_out_(add(a, b), out=a). Thunder does not guarantee that `computed` retains
# the correct values after copy_to_out_ returns. For a general-purpose copy, use prims.copy_ instead.
This rule comes from the fact that any copy onto `out` is propagated to its alias, `computed`.
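As a plain-PyTorch analogy of this hazard (illustrative only, not thunder internals): if `out` merely aliases `computed`, a later write onto `out` is also visible through `computed`.

```python
import torch

a = torch.zeros(3)
b = torch.ones(3)

computed = a + b   # intermediate result of the in-place op
out = computed     # copy_to_out_ may let `out` alias `computed` instead of copying
out.fill_(42.0)    # a subsequent copy onto `out` ...
print(computed)    # ... is observed through `computed` too: tensor([42., 42., 42.])
```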
To prevent inappropriate uses of `prims.copy_to_out_`, I made its sanity check rather conservative. When enabled, it raises an error when `computed` is
- used as an input to another op within the nvFuser region,
- defined outside of the region (because it may be modified in place), or
- used outside of the region (because it may not have the correct value).
See tests for examples; a schematic illustration of each rejected pattern follows.
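These fragments are hypothetical, written in the style of the traces further below (not copied from the tests):

```python
# (1) `computed` feeds another op inside the same nvFuser region -> rejected
#   [t1, t2] = nvFusion0(a, b, c)
#     # t0 = prims.add(a, b)
#     # t1 = prims.copy_to_out_(t0, out=a)
#     # t2 = prims.mul(t0, c)          # t0 is read again after the copy

# (2) `computed` is defined outside the region -> rejected
#   t0 = prims.add(a, b)
#   [t1] = nvFusion0(t0, a)            # t0 comes from outside and may be mutated in place
#     # t1 = prims.copy_to_out_(t0, out=a)

# (3) `computed` is used outside the region -> rejected
#   [t0, t1] = nvFusion0(a, b)
#     # t0 = prims.add(a, b)
#     # t1 = prims.copy_to_out_(t0, out=a)
#   t2 = prims.mul(t0, c)              # t0 may not have the correct value here
```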
As of now, the test thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like] does NOT pass.
Minimal reproducible example:
import torch
import thunder
from functools import partial

@partial(thunder.jit, disable_inplace_copy_check=True)
def f(q, k, v, mask, idx, src):
    q.index_copy_(2, idx, src)
    k.index_copy_(2, idx, src)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

q = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
k = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
v = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
mask = torch.ones((1, 1, 2, 3), device='cuda', dtype=torch.bool)
idx = torch.arange(2).to(device='cuda')
src = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
f(q, k, v, mask, idx, src)
Execution trace:
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  t0 = torch.index_copy(q, 2, idx, src) # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = ltorch.index_copy(q, 2, idx, src) # t0: "cuda:0 f32[1, 4, 2, 16]"
      # t0 = prims.index_copy(q, idx, src, 2) # t0: "cuda:0 f32[1, 4, 2, 16]"
  t2 = torch.index_copy(k, 2, idx, src) # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = ltorch.index_copy(k, 2, idx, src) # t2: "cuda:0 f32[1, 4, 3, 16]"
      # t2 = prims.index_copy(k, idx, src, 2) # t2: "cuda:0 f32[1, 4, 3, 16]"
  [t1, t3] = nvFusion0(t0, q, t2, k)
    # t1 = prims.copy_to_out_(t0, out=q) # t1: "cuda:0 f32[1, 4, 2, 16]"
    # t3 = prims.copy_to_out_(t2, out=k) # t3: "cuda:0 f32[1, 4, 3, 16]"
  del q, k
  (t19, _, _, _) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t0, t2, v, mask, 0.0, False, None)
  del t0, t2
  return t19
`_inplace_copy_sanity_check` raises an error for this, because
- `t0` is passed to the sdpa operator, and
- if `nvFusion0` had a copy onto `q` in the form of `prims.copy_(XX, copy_to=q)`, `XX` is propagated to `t0`.
Note that, before passing the trace to the nvFuser executor, prims.copy_to_out_(t0, out=q) is put after the sdpaex operator thanks to functionalization.
Trace just before nvFuser
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t1 = index_copy_(q,2,idx,src)`
  t0 = ltorch.index_copy(q, 2, idx, src) # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = prims.index_copy(q, idx, src, 2) # t0: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t3 = index_copy_(k,2,idx,src)`
  t2 = ltorch.index_copy(k, 2, idx, src) # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = prims.index_copy(k, idx, src, 2) # t2: "cuda:0 f32[1, 4, 3, 16]"
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60: return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  # ['t1', 't3'] are replaced by ['t0', 't2'], respectively
  t19 = ltorch.scaled_dot_product_attention(t0, t2, v, mask, 0.0, False, scale=None) # t19: "cuda:0 f32[1, 4, 2, 16]"
    # subsymbols omitted
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:58: q.index_copy_(2, idx, src)
  t1 = prims.copy_to_out_(t0, out=q) # t1: "cuda:0 f32[1, 4, 2, 16]"
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:59: k.index_copy_(2, idx, src)
  t3 = prims.copy_to_out_(t2, out=k) # t3: "cuda:0 f32[1, 4, 3, 16]"
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60: return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  return {'output': t19, 'flat_args': [t1, t3, v, mask, idx, src]}
Possible solutions:
- Functionalize `prims.copy_` in a way such that there will never be multiple copies onto the same tensor. If we assume this, we can be sure that `q` never changes after `prims.copy_to_out_(t0, out=q)` in the previous example.
- Preserve the order between operations involving `t0` and `prims.copy_to_out_(t0, out=q)`. This order is enforced by functionalization (link), but the ops do not establish a dependency relationship in terms of outputs and inputs, so this would mean exposing fixes only for nvFuser (see the sketch after this list).
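As a toy sketch of the invariant the second option would have to guarantee (a hypothetical helper, not thunder code; `Op` and `check_copy_to_out_is_last_use` are made up for illustration): for every `copy_to_out_`, its `computed` argument must not be read by any later op. Run on a simplified encoding of the problematic execution trace above, it flags the sdpa uses of `t0` and `t2`.

```python
from typing import NamedTuple, Sequence

class Op(NamedTuple):
    outputs: tuple
    name: str
    inputs: tuple

def check_copy_to_out_is_last_use(ops: Sequence[Op]) -> list:
    """Return a message for every `computed` that is read after its copy_to_out_."""
    errors = []
    for i, op in enumerate(ops):
        if op.name != "copy_to_out_":
            continue
        computed = op.inputs[0]  # by convention here, the first input is `computed`
        for later in ops[i + 1:]:
            if computed in later.inputs:
                errors.append(f"{computed} is read by {later.name} after copy_to_out_")
    return errors

# Simplified encoding of the problematic execution trace shown earlier.
ops = [
    Op(("t0",), "index_copy", ("q", "idx", "src")),
    Op(("t2",), "index_copy", ("k", "idx", "src")),
    Op(("t1",), "copy_to_out_", ("t0", "q")),   # inside nvFusion0
    Op(("t3",), "copy_to_out_", ("t2", "k")),   # inside nvFusion0
    Op(("t19",), "sdpa", ("t0", "t2", "v", "mask")),
]
print(check_copy_to_out_is_last_use(ops))
# ['t0 is read by sdpa after copy_to_out_', 't2 is read by sdpa after copy_to_out_']
```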
I disabled the sanity check for test_litgpt_variants_kvcache.
> Note that, before passing the trace to the nvFuser executor, prims.copy_to_out_(t0, out=q) is put after the sdpaex operator thanks to functionalization.
Ugh.
Could it be that inplace_copy_ is particular here?
I wonder if something along the lines of @IvanYashchuk 's planned primitive for dataflow healing would be useful.
The same happens when we use `Tensor.add_` instead of `Tensor.index_copy_`, so this is a generic issue.
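For reference, a hedged variant of the minimal repro above with `Tensor.add_` substituted for `Tensor.index_copy_` (same setup as before; expected to lower to the analogous trace):

```python
import torch
import thunder
from functools import partial

@partial(thunder.jit, disable_inplace_copy_check=True)
def g(q, k, v, mask):
    q.add_(1.0)  # in-place update on an input, analogous to index_copy_ above
    k.add_(1.0)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

q = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
k = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
v = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
mask = torch.ones((1, 1, 2, 3), device='cuda', dtype=torch.bool)
g(q, k, v, mask)
```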
My internship period is about to end, so I can no longer spend much time on this issue. Maybe we can close this PR for now and wait for functionalization of copy_ or dependency establishment to mature, and merge #1177 or take other strategies in the meantime if the performance regression in #1173 is urgent.
Closing for now. @crcrpar, please reopen as you see fit.