Can we integrate Flashinfer into gpt-fast?
Hi, in previous issues you wrote that you planned to integrate flashinfer into an inference backend such as gpt-fast. That would be very interesting! May I ask whether we can integrate Flashinfer into gpt-fast now? Thanks!
Hi, I don't see any difficulty in integrating flashinfer into gpt-fast. But we would prefer a minimal example (e.g. within 1k LoC) of continuous batching (gpt-fast doesn't support batching AFAIK); we are working with the @sgl-project team on that.
Wow, cool! Is there any example of using torch.compile with flashinfer's BatchPrefillWithPagedKVCacheWrapper or the decode wrapper? Thanks!
You can check how vLLM/sglang integrate flashinfer (they are good examples of how to use those wrapper functions in flashinfer); both of them use PyTorch's CUDA graph capturing.
I haven't tried to use torch.compile together with flashinfer (please let me know if you have related experience). The only possible issue is whether the torch compiler can deal with custom operators; we should follow PyTorch's custom operators manual for better compatibility in the future.
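For reference, a minimal sketch of the CUDA graph capture pattern mentioned above; the toy matmul below is only a stand-in for a decode step, not a flashinfer call:

import torch

# CUDA graphs replay fixed memory addresses, so inputs/outputs live in
# static buffers that are overwritten between replays.
static_x = torch.randn(8, 1024, device="cuda")
weight = torch.randn(1024, 1024, device="cuda")

def step(x):
    return x @ weight  # stand-in for one decode forward pass

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_y = step(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one step, then replay it after copying new data into the static input.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = step(static_x)

static_x.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()  # re-runs the captured kernels against the updated static_x
print(static_y.shape)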
Thanks for the information!
I have checked how to define custom operators, and I can successfully register single_prefill_with_kv_cache as shown below:
import torch
import flashinfer

# Declare the op schema so torch.compile can treat it as an opaque custom op.
torch.library.define(
    "mylib::custom_func",
    "(Tensor q, Tensor k_cache, Tensor v_cache) -> Tensor",
)

# CUDA implementation: just forward to the flashinfer kernel.
@torch.library.impl("mylib::custom_func", "cuda")
def custom_func(q, k_cache, v_cache):
    return flashinfer.single_prefill_with_kv_cache(q, k_cache, v_cache)

# Fake (meta) implementation so the compiler can infer output shapes/dtypes.
@torch.library.register_fake("mylib::custom_func")
def custom_func_abstract(q, k_cache, v_cache):
    return torch.empty_like(q)

with torch.device("cuda"):
    q = torch.randn((2, 2, 128), dtype=torch.bfloat16)
    k_cache = torch.randn((5, 2, 128), dtype=torch.bfloat16)
    v_cache = torch.randn((5, 2, 128), dtype=torch.bfloat16)
    torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(q, k_cache, v_cache)
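As a quick sanity check (assuming the q/k_cache/v_cache tensors above are still in scope), the compiled op should match the eager flashinfer call:

out_compiled = torch.compile(torch.ops.mylib.custom_func, fullgraph=True)(q, k_cache, v_cache)
out_eager = flashinfer.single_prefill_with_kv_cache(q, k_cache, v_cache)
torch.testing.assert_close(out_compiled, out_eager, rtol=1e-2, atol=1e-2)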
The problem is that when using BatchPrefillWithPagedKVCacheWrapper, we have to initialize the wrapper first and then perform the forward pass, which makes registering the forward function as a custom operator difficult.
Thank you for the info. I think we can resolve this by making those wrappers pure Python objects. I'll refactor the codebase this weekend.
Thanks! I think I solved this problem by creating the wrapper before defining the custom operator and then reusing that wrapper. But making the wrappers pure Python objects would be fine and would make things easier!
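A minimal sketch of that workaround, assuming the wrapper API with a workspace buffer plus begin_forward/forward; the shapes, page layout, and begin_forward argument order below are illustrative assumptions, so please check the flashinfer docs for the exact signatures:

import torch
import flashinfer

# Create the wrapper once, outside the custom op, and close over it.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

torch.library.define(
    "mylib::batch_prefill",
    "(Tensor q, Tensor paged_kv_cache) -> Tensor",
)

@torch.library.impl("mylib::batch_prefill", "cuda")
def batch_prefill(q, paged_kv_cache):
    # The wrapper is captured from the enclosing scope; only tensors cross
    # the custom-op boundary.
    return prefill_wrapper.forward(q, paged_kv_cache, causal=True)

@torch.library.register_fake("mylib::batch_prefill")
def batch_prefill_fake(q, paged_kv_cache):
    return torch.empty_like(q)

# Illustrative metadata for a single request of 4 tokens in one page
# (page_size=4, 2 query/kv heads, head_dim=128 are assumptions).
num_qo_heads, num_kv_heads, head_dim, page_size = 2, 2, 128, 4
qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([4], dtype=torch.int32, device="cuda")

q = torch.randn(4, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
# Assumed NHD paged layout: (num_pages, 2, page_size, num_kv_heads, head_dim).
kv_cache = torch.randn(1, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")

# Plan with the batch metadata, then run the compiled forward.
prefill_wrapper.begin_forward(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)
out = torch.compile(torch.ops.mylib.batch_prefill, fullgraph=True)(q, kv_cache)
prefill_wrapper.end_forward()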
Btw, the Llama 3.1 family has a different positional encoding (RoPE scaling with two factors) compared with Llama 2 and Llama 3. Can we support Llama 3.1 in the next flashinfer release? Thanks!
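For context, a sketch of the two-factor RoPE frequency scaling used by Llama 3.1, mirroring the open-source reference implementations; the default hyperparameters below come from the released rope_scaling config and should be treated as assumptions:

import math
import torch

def llama31_scale_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 8.0,            # overall scaling factor
    low_freq_factor: float = 1.0,   # first of the "two factors"
    high_freq_factor: float = 4.0,  # second of the "two factors"
    old_context_len: int = 8192,    # original pre-training context length
) -> torch.Tensor:
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    # Blend factor for the medium-frequency band; Llama 2/3 keep inv_freq unchanged.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        wavelen > low_freq_wavelen,
        inv_freq / factor,                                         # long wavelengths: fully scaled
        torch.where(
            wavelen < high_freq_wavelen,
            inv_freq,                                              # short wavelengths: unchanged
            (1 - smooth) * inv_freq / factor + smooth * inv_freq,  # in between: interpolate
        ),
    )

# Example: standard RoPE inverse frequencies for head_dim=128, base=500000.
head_dim, base = 128, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(llama31_scale_inv_freq(inv_freq)[:4])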
by creating the wrapper before defining the custom operator
Ideally, the begin_forward and forward functions should be registered as custom operators as well, so it's preferable to make the wrappers pure Python objects so that we can make sure all of their arguments are torch tensors.
different positional encoding
Yes it's easy to support, just stay tuned.
Cool!
@jianc99 flashinfer + torch.compile is supported in sglang and it is very fast.
You can try
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B --enable-torch-compile
It is faster than the original gpt-fast and TensorRT-LLM, and much faster than vLLM. It also supports online serving features such as dynamic batching and prefix caching.