
Add `prims.copy_to_out_`

shino16 opened this pull request · 4 comments

Fixes #1173. This adds a new primitive `prims.copy_to_out_(computed, *, out)`, which is used instead of `prims.copy_` for the update in in-place ops. Unlike `prims.copy_`, `prims.copy_to_out_(computed, *, out)` assumes that `computed` is not used by any subsequent op, so `out` can simply alias `computed`.
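As a concrete (if simplified) picture of the two contracts, here is a plain-PyTorch sketch; the function names mirror the primitives, but the bodies are only reference semantics, not Thunder's implementation:

```python
import torch

def copy_(src: torch.Tensor, *, copy_to: torch.Tensor) -> torch.Tensor:
    # General-purpose copy: `src` stays valid afterwards, so a backend must
    # materialize a real write into `copy_to`; aliasing `src` is not allowed.
    copy_to.copy_(src)
    return copy_to

def copy_to_out_(computed: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    # `computed` is assumed dead after this call, so an executor may simply
    # let `out` alias `computed`. The copy below is the reference behavior;
    # skipping it via aliasing is the intended optimization.
    out.copy_(computed)
    return out
```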

Benchmarked commits: main at 63887b3d020932d5a45ade4cac96f6376e59d602, #1193 at 2781a201044ec3671f685a493052111a18c45e53.

|  | compilation (s) | execution (ms) |
| --- | --- | --- |
| eager | 0.0 | 10.86 |
| `torch.compile(adam.step, backend=thunder)`, main | 94.0 | 11.52 |
| `torch.compile(adam.step, backend=thunder)`, #1193 | 48.0 | 6.00 |

The rule for `prims.copy_to_out_` is (link):

```python
# WARN: `computed` must be an intermediate tensor used solely for this `copy_to_out_` call,
# e.g. copy_to_out_(add(a, b), out=a). Thunder does not guarantee that `computed` retains
# the correct values after copy_to_out_ returns. For general-purpose copy, use prims.copy_ instead.
```

This rule comes from the fact that any later copy onto `out` propagates to its alias, `computed`.
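A small illustration of the hazard, assuming the executor chose to alias (plain PyTorch, hypothetical variable names):

```python
import torch

a = torch.zeros(3)
computed = a.add(1.0)   # intermediate meant only for copy_to_out_(computed, out=a)
out = computed          # suppose the executor made `out` alias `computed`
out.fill_(42.0)         # any later copy onto `out`...
assert computed[0].item() == 42.0  # ...is visible through `computed` as well
```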

To prevent users from using `prims.copy_to_out_` inappropriately, I made the sanity check on it rather conservative. When enabled, it raises an error when `computed` is

  • used as an input to other ops within the nvFuser region,
  • defined outside of the region (because it may be modified in-place), or
  • used outside of the region (because it may not have the correct value).

See tests for examples; a toy version of the check is sketched below.
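A self-contained toy version of such a check, assuming a trace is just a list of ops with explicit inputs/outputs; `Op` and `check_copy_to_out` are hypothetical names, and this is not Thunder's real trace or checker API:

```python
from typing import NamedTuple

class Op(NamedTuple):
    name: str
    inputs: tuple
    outputs: tuple

def check_copy_to_out(region: list, trace: list) -> None:
    # `region` is a contiguous slice of `trace`, e.g. one nvFuser region.
    region_ids = {id(op) for op in region}
    after = trace[trace.index(region[-1]) + 1:]
    for op in region:
        if op.name != "copy_to_out_":
            continue
        computed = op.inputs[0]
        # 1. `computed` must not feed any other op inside the region.
        if any(o is not op and computed in o.inputs for o in region):
            raise RuntimeError(f"{computed} is used by another op in the region")
        # 2. `computed` must be produced inside the region.
        if not any(id(o) in region_ids for o in trace if computed in o.outputs):
            raise RuntimeError(f"{computed} is defined outside the region")
        # 3. `computed` must not be consumed after the region.
        if any(computed in o.inputs for o in after):
            raise RuntimeError(f"{computed} is used outside the region")

# Example mirroring the failing trace later in this thread:
# t0 feeds both copy_to_out_ and a later sdpa outside the region.
ic = Op("index_copy", ("q", "idx", "src"), ("t0",))
cp = Op("copy_to_out_", ("t0",), ("t1",))
sdpa = Op("sdpa", ("t0", "t2", "v", "mask"), ("t19",))
check_copy_to_out([ic, cp], [ic, cp, sdpa])  # raises: t0 used outside the region
```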

shino16 · Sep 24 '24 10:09

As of now, the test `thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like]` does NOT pass.

Minimal reproducible example:

```python
from functools import partial

import torch
import thunder

@partial(thunder.jit, disable_inplace_copy_check=True)
def f(q, k, v, mask, idx, src):
    q.index_copy_(2, idx, src)
    k.index_copy_(2, idx, src)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

q = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
k = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
v = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
mask = torch.ones((1, 1, 2, 3), device='cuda', dtype=torch.bool)
idx = torch.arange(2).to(device='cuda')
src = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)

f(q, k, v, mask, idx, src)
```

Execution trace:

```python
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  t0 = torch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
      # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"
  t2 = torch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
      # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"
  [t1, t3] = nvFusion0(t0, q, t2, k)
    # t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"
    # t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"
  del q, k
  (t19, _, _, _) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t0, t2, v, mask, 0.0, False, None)
  del t0, t2
  return t19
```

`_inplace_copy_sanity_check` raises an error for this, because

  • t0 is passed to the sdpa operator, and
  • if nvFusion0 had a copy onto q in the form of `prims.copy_(XX, copy_to=q)`, XX would propagate to t0.

Note that, before the trace is passed to the nvFuser executor, `prims.copy_to_out_(t0, out=q)` is placed after the sdpaex operator thanks to functionalization.

Trace just before nvFuser:

```python
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t1 = index_copy_(q,2,idx,src)`
  t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t3 = index_copy_(k,2,idx,src)`
  t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60:          return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  # ['t1', 't3'] are replaced by ['t0', 't2'], respectively
  t19 = ltorch.scaled_dot_product_attention(t0, t2, v, mask, 0.0, False, scale=None)  # t19: "cuda:0 f32[1, 4, 2, 16]"
    # subsymbols omitted
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:58:          q.index_copy_(2, idx, src)
  t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:59:          k.index_copy_(2, idx, src)
  t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60:          return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  return {'output': t19, 'flat_args': [t1, t3, v, mask, idx, src]}
```

Possible solutions

  • Functionalize `prims.copy_` in a way such that there are never multiple copies onto the same tensor. Under this assumption, we can be sure that q never changes after `prims.copy_to_out_(t0, out=q)` in the previous example. (A toy sketch of this option follows the list.)
  • Preserve the order between operations involving t0 and `prims.copy_to_out_(t0, out=q)`. This order is enforced by functionalization (link), but those operations do not establish a dependency relationship in terms of outputs and inputs, so the fix would end up being exposed only for nvFuser.
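For the first option, a minimal sketch, assuming a toy trace of `(op_name, args, kwargs)` tuples rather than Thunder's real trace objects; `deduplicate_copies` is a hypothetical pass that keeps only the final copy onto each destination, so no tensor is ever the target of two copies:

```python
def deduplicate_copies(trace):
    # Index of the last copy onto each destination tensor (by name).
    last = {}
    for i, (name, args, kwargs) in enumerate(trace):
        if name == "copy_to_out_":
            last[kwargs["out"]] = i
    # Keep every non-copy op, and only the final copy per destination.
    return [
        entry
        for i, entry in enumerate(trace)
        if entry[0] != "copy_to_out_" or last[entry[2]["out"]] == i
    ]

trace = [
    ("add", ("a", "b"), {}),
    ("copy_to_out_", ("t0",), {"out": "a"}),
    ("mul", ("t0", "c"), {}),
    ("copy_to_out_", ("t2",), {"out": "a"}),  # only this copy onto `a` survives
]
assert len(deduplicate_copies(trace)) == 3
```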

shino16 · Sep 25 '24 10:09

I disabled the sanity check for `test_litgpt_variants_kvcache`.

shino16 · Sep 25 '24 11:09

> Note that, before the trace is passed to the nvFuser executor, `prims.copy_to_out_(t0, out=q)` is placed after the sdpaex operator thanks to functionalization.

Ugh.

Could it be that `inplace_copy_` is special here?

I wonder if something along the lines of @IvanYashchuk's planned primitive for dataflow healing would be useful.

t-vi · Sep 25 '24 11:09

The same happens when we use `Tensor.add_` instead of `Tensor.index_copy_`, so this is a generic issue; a sketch of such a variant follows.
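For instance, a hypothetical variant of the MRE above (reusing its imports and tensors) that should hit the same problem:

```python
@partial(thunder.jit, disable_inplace_copy_check=True)
def g(q, k, v, mask):
    q.add_(1.0)  # functionalized into a prims.copy_to_out_(..., out=q)
    k.add_(1.0)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

g(q, k, v, mask)
```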

My internship period is about to end, so I can no longer spend much time on this issue. Maybe we can close this PR for now and wait for the functionalization of `copy_` or dependency establishment to mature, merging #1177 or taking other strategies in the meantime if the performance regression in #1173 is urgent.

shino16 · Sep 25 '24 11:09

Closing for now. @crcrpar, please reopen as you see fit.

t-vi · Dec 17 '24 12:12