apply_rope_inplace will cause a graph break due to mutated inputs
import torch
import flashinfer

rope = flashinfer.apply_rope_inplace

# Register the in-place RoPE kernel as a custom op; (a!) marks q and k as mutated.
torch.library.define(
    "mylib::target_rope",
    "(Tensor(a!) q, Tensor(a!) k, Tensor indptr, Tensor offsets) -> None",
)

@torch.library.impl("mylib::target_rope", "cuda")
def target_rope(q, k, indptr, offsets):
    rope(q, k, indptr, offsets, interleave=True)

@torch.library.register_fake("mylib::target_rope")
def target_rope_abstract(q, k, indptr, offsets):
    return None

q = torch.randn(4, 4, 128, dtype=torch.bfloat16).to(0)
k = torch.randn(4, 1, 128, dtype=torch.bfloat16).to(0)
indptr = torch.arange(5, dtype=torch.int32).to(0)
offsets = torch.full((4,), 1, dtype=torch.int32).to(0)

torch.compile(torch.ops.mylib.target_rope, mode="reduce-overhead", fullgraph=True)(q, k, indptr, offsets)
Running this prints the following warning and CUDA graphs are skipped:

skipping cudagraphs due to mutated inputs (2 instances)

Mutated arguments have to be annotated: https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#creating-mutable-operators
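One way to work around this today is to register a functional wrapper instead: clone q and k inside the op, run the in-place kernel on the clones, and return them, so the compiled region never mutates its inputs. A minimal sketch; the op name mylib::functional_rope and the wrapper itself are illustrative, not part of flashinfer:

torch.library.define(
    "mylib::functional_rope",
    "(Tensor q, Tensor k, Tensor indptr, Tensor offsets) -> (Tensor, Tensor)",
)

@torch.library.impl("mylib::functional_rope", "cuda")
def functional_rope(q, k, indptr, offsets):
    # Clone first, then let the in-place kernel mutate only the clones.
    q_out, k_out = q.clone(), k.clone()
    flashinfer.apply_rope_inplace(q_out, k_out, indptr, offsets, interleave=True)
    return q_out, k_out

@torch.library.register_fake("mylib::functional_rope")
def functional_rope_abstract(q, k, indptr, offsets):
    return torch.empty_like(q), torch.empty_like(k)

q_rope, k_rope = torch.compile(
    torch.ops.mylib.functional_rope, mode="reduce-overhead", fullgraph=True
)(q, k, indptr, offsets)

Since the schema has no (a!) annotations, inductor treats the op as functional and should no longer skip CUDA graphs.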
I noticed that you already annotated the mutated inputs.
I think it's okay to expose another set of apply_rope and apply_llama31_rope that are not in-place operations, for use with torch.compile.
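If such functional variants land, the clone-based wrapper above reduces to a direct call. A hedged sketch, assuming the non-in-place apply_rope mirrors the in-place signature but returns new (q, k) tensors instead of mutating its arguments:

# Assumption: flashinfer.apply_rope takes the same arguments as
# apply_rope_inplace but returns (q_rope, k_rope) rather than mutating q and k.
# This body would replace the clone-based implementation sketched earlier.
def functional_rope_impl(q, k, indptr, offsets):
    return flashinfer.apply_rope(q, k, indptr, offsets, interleave=True)

Depending on how flashinfer exposes it, the call may still need to live inside a registered custom op to keep fullgraph=True happy.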
Yeah, I have annotated them, but it still doesn't work. Exposing a non-in-place RoPE would be very helpful, thanks!
Done in #405.