
apply_rope_inplace will cause graphbreak due to mutated inputs

jianc99 opened this issue 1 year ago • 4 comments

import torch
import flashinfer

rope = flashinfer.apply_rope_inplace

# Register the in-place RoPE kernel as a custom op; q and k are
# annotated as mutated via Tensor(a!).
torch.library.define(
    "mylib::target_rope",
    "(Tensor(a!) q, Tensor(a!) k, Tensor indptr, Tensor offsets) -> None",
)

@torch.library.impl("mylib::target_rope", "cuda")
def target_rope(q, k, indptr, offsets):
    rope(q, k, indptr, offsets, interleave=True)

@torch.library.register_fake("mylib::target_rope")
def target_rope_abstract(q, k, indptr, offsets):
    return None

q = torch.randn(4, 4, 128, dtype=torch.bfloat16).to(0)
k = torch.randn(4, 1, 128, dtype=torch.bfloat16).to(0)
indptr = torch.arange(5, dtype=torch.int32).to(0)
offsets = torch.full((4,), 1, dtype=torch.int32).to(0)

torch.compile(torch.ops.mylib.target_rope, mode="reduce-overhead", fullgraph=True)(q, k, indptr, offsets)

Running this skips CUDA graph capture and prints:

skipping cudagraphs due to mutated inputs (2 instances)

jianc99 avatar Jul 28 '24 09:07 jianc99

Mutated arguments have to be annotated: https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#creating-mutable-operators

yzh119 avatar Jul 28 '24 18:07 yzh119
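For reference, a minimal sketch of the mutation annotation the linked tutorial describes, assuming the PyTorch >= 2.4 torch.library.custom_op API; the op name target_rope_v2 is a hypothetical variant chosen to avoid clashing with the mylib::target_rope already defined in the reproduction above:

import torch
import flashinfer

# Declare which arguments the op mutates; torch.compile then knows
# q and k are written in place by the kernel.
@torch.library.custom_op("mylib::target_rope_v2", mutates_args=("q", "k"), device_types="cuda")
def target_rope_v2(q: torch.Tensor, k: torch.Tensor,
                   indptr: torch.Tensor, offsets: torch.Tensor) -> None:
    flashinfer.apply_rope_inplace(q, k, indptr, offsets, interleave=True)

@target_rope_v2.register_fake
def _(q, k, indptr, offsets) -> None:
    return None

Note that even with the mutation correctly declared, cudagraph capture under mode="reduce-overhead" can still be skipped for ops that mutate their (non-static) inputs, which is what the warning above reports.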

I noticed that you already annotated the mutated inputs. I think it's okay to expose another set of apply_rope and apply_llama31_rope that are not in-place operations, for PyTorch compile.

yzh119 avatar Jul 29 '24 03:07 yzh119

Yeah, I have annotated that but it still doesn't work. Exposing a non-in-place rope would be very helpful, thanks!

jianc99 avatar Jul 29 '24 04:07 jianc99

Done in #405 .

yzh119 avatar Jul 29 '24 04:07 yzh119
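As an illustration, a minimal sketch of how such an out-of-place variant can be wrapped as a functional custom op that torch.compile can capture without the mutated-inputs warning, assuming the flashinfer.apply_rope added for this purpose takes the same arguments as the in-place version and returns new (q_rope, k_rope) tensors:

import torch
import flashinfer

# Functional schema: no Tensor(a!) annotations, two fresh outputs.
torch.library.define(
    "mylib::functional_rope",
    "(Tensor q, Tensor k, Tensor indptr, Tensor offsets) -> (Tensor, Tensor)",
)

@torch.library.impl("mylib::functional_rope", "cuda")
def functional_rope(q, k, indptr, offsets):
    return flashinfer.apply_rope(q, k, indptr, offsets, interleave=True)

@torch.library.register_fake("mylib::functional_rope")
def functional_rope_abstract(q, k, indptr, offsets):
    # Fake outputs only need matching shapes/dtypes for tracing.
    return torch.empty_like(q), torch.empty_like(k)

q = torch.randn(4, 4, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(4, 1, 128, dtype=torch.bfloat16, device="cuda")
indptr = torch.arange(5, dtype=torch.int32, device="cuda")
offsets = torch.full((4,), 1, dtype=torch.int32, device="cuda")

q_rope, k_rope = torch.compile(
    torch.ops.mylib.functional_rope, mode="reduce-overhead", fullgraph=True
)(q, k, indptr, offsets)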

> Done in #405 .

Closing.

sricketts avatar Sep 30 '25 01:09 sricketts