FLUTE Integration for Fast Inference
Feature request
Hi, we are big fans of the library and the NF4 data type, so much so that we have been working on CUDA kernels to speed up inference for NF4-quantized models (and more). We'd love to explore ways we can integrate FLUTE into the bitsandbytes library.
Motivation
I think FLUTE could make two contributions to bitsandbytes:
- An inference-time kernel for NF4-quantized LLMs (and more). For example, this could help extend the current GEMV kernel into GEMM (i.e., batched inference).
- A new quantization algorithm that lightly extends bitsandbytes's NF4 into a learned version.
The kernel and algorithm still have room for improvement. For example, the kernel has limited GPU support and is specialized for specific shapes. These are hopefully not hard limits and could be relaxed. As such, we do not expect FLUTE to be the default option. That said, we thought it would be great to have FLUTE as an opt-in feature.
Your contribution
We are happy to submit PRs, though I'm sure there will be some rough edges that we will need some help with.
@HanGuo97 Thank you for reaching out! We're always interested in contributions and would be happy to help with PRs.
It should not be an issue if we take an optimized codepath for certain hardware, shapes, etc., and fall back otherwise. I'm going to need to take some time to look through your paper and code, but in the meantime, feel free to reach out if there are any questions!
Thanks! Do you have suggestions on where we could get started? I was thinking about a few potential ways for integration:
- The easiest point of integration is at the level of Linear4bit. FLUTE has its own FLUTELinear. If the user chooses FLUTE as the backend, we could simply swap Linear4bit for FLUTELinear.
- Alternatively, we could use Linear4bit (and hence Params4bit) as-is, and instead expose just FLUTE's GEMM kernel as an option in matmul_4bit.
Option 1 is easier to implement. FLUTE also has a slightly different implementation of NF beyond 4-bit (NF4, NF3, NF2, as well as their "learned" extensions), and this option makes it easier to expose these features.
Option 2 provides more seamless integration with the rest of bitsandbytes, but it would be a lot more challenging to implement due to differences in data formats, among other things.
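To make option 1 a bit more concrete, here is a minimal sketch of the module swap. The FLUTELinear import path and the from_linear4bit conversion helper are assumptions for illustration only; FLUTE's actual API may differ.

import torch.nn as nn
from bitsandbytes.nn import Linear4bit
from flute import FLUTELinear  # hypothetical import path, for illustration only

def swap_linear4bit(model: nn.Module) -> nn.Module:
    # Recursively replace every Linear4bit submodule with a FLUTE-backed module.
    for name, child in model.named_children():
        if isinstance(child, Linear4bit):
            setattr(model, name, FLUTELinear.from_linear4bit(child))  # assumed conversion helper
        else:
            swap_linear4bit(child)
    return model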
The first option seems like a good place to start, especially since it does not sound like it would prevent later implementation of option 2. That said, I would like @SunMarc and @Titus-von-Koeller to weigh in here in case there's any additional considerations that I may be missing.
I agree with @matthewdouglas, I think it is better to implement a separate Linear class for now. It will enable us to be more flexible and iterate faster! If it makes sense in the future, we can still think about integrating it into Linear4bit.
Agreed, sounds good to me as well. Thanks so much for your contribution, really happy about this collaboration!
Btw, is the code you used for your benchmarking also public somewhere? It's a topic we are looking into ourselves, might be interesting to reference :)
Thanks for the feedback, glad to hear that we have a concrete plan! I'll work on it over the coming days and let you know when I have a PR ready.
As for benchmarking, the code is not public yet, since it has some external dependencies that I haven't had the chance to document. That being said, it's essentially calling PyTorch's and Triton's benchmarking functions after preparing the data. Here's the relevant snippet.
Note
- The out operand is ignored by bitsandbytes, so very strictly speaking, this benchmark includes memory allocation time. Our early experiments suggested this has an overhead of ~2.5%.
- We use both PyTorch's and Triton's benchmarking tools, but by default we use Triton's timer as it manually flushes the L2 cache.
# Assumed imports (aliases inferred from the usage below).
from typing import Dict

import torch
import torch.utils.benchmark as torch_benchmark
import triton.testing as triton_benchmark

from bitsandbytes import matmul_4bit as bnb_matmul_4bit
from bitsandbytes.nn import Params4bit as BnBParams4bit


def prepare_bnb_data(
    m: int,
    n: int,
    k: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:
    # Quantize a random (n, k) weight to NF4; quantization happens on the .cuda() call.
    Q = BnBParams4bit(  # type: ignore
        data=torch.randn(
            (n, k),
            dtype=dtype,
            device=device),
        requires_grad=False,
        blocksize=128,
        compress_statistics=False,
        quant_type="nf4").cuda()
    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)
    # Pre-allocated output, passed via `out=` (ignored by bitsandbytes, see the note above).
    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)
    torch.cuda.synchronize()
    return {
        "A": A,
        "QT": Q.t(),
        "quant_state": Q.quant_state,
        "D": D,
    }


def benchmark_bnb(data: Dict, n: int = 100) -> Dict[str, float]:
    fn = lambda: bnb_matmul_4bit(data["A"], data["QT"], bias=None, quant_state=data["quant_state"], out=data["D"])
    # PyTorch's timer (no cache flush) and Triton's do_bench (flushes the L2 cache).
    timer = torch_benchmark.Timer(
        stmt="bnb_matmul_4bit(A, QT, bias=None, quant_state=quant_state, out=D)",
        setup="from __main__ import bnb_matmul_4bit",
        globals=data)
    torch_time = timer.timeit(n).mean
    triton_time = triton_benchmark.do_bench(fn, rep=n)
    return {"torch_time": torch_time, "triton_time": triton_time}
@HanGuo97 We've discussed this further internally, and thinking from a long-term perspective, it would be ideal for us to be able to build a GEMM kernel in-source here and prefer dispatching to it. After browsing through the flute repo a little more, I can see there are a few challenges to overcome, but I think we can work through them.
Some of the initial things I notice:
- CUTLASS 3.4 dependency is something we should be OK to pick up
- We'd like to decouple from PyTorch and instead expose a C API
- Compile time is quite long, though we'll be able to exclude 3bit, 2bit, and groupsize=32
- Need to tune for a wider variety of GPUs.
Any thoughts or concerns with this?
Thanks for the comments! I'm personally okay with having FLUTE's CUDA code included inside BNB. As for the list of things you mentioned, many of them are on our (near-term) roadmap as well.
- The compilation time is long primarily because we instantiate a lot of templates to squeeze out the last bit of performance. That said, a small subset of them should suffice for practical use cases. As a by-product, we are also working on removing the specialization on shapes.
- Adding FLUTE support for more GPUs is doable. Currently, FLUTE treats the number of SMs (streaming multiprocessors) of a GPU as a compile-time constant. This can be made a runtime variable, thereby supporting all GPUs in the Ampere and Ada generations (A-series and RTX 30xx/40xx series). Tuning will still be done on a small subset of GPUs, though (since we don't have a ton of resources 😅).
I don't have a specific timeline, unfortunately, since we are a bit bandwidth-bound... but I'm aiming for somewhere in the next few weeks.
Hello @HanGuo97, have you benchmarked the latency of Linear4bit and FLUTELinear across different test cases?
@Ther-nullptr We have some benchmarking numbers in the write-up; is that what you were looking for? (The batch size = 32 case might need to be updated, though.)
Hi, sorry for the silence. Just want to give a quick update on where we are.
- We pushed an update earlier to support converting bitsandbytes-quantized models into FLUTE's own format. Since we do not support double quantization, we just materialize the first-level scales for the moment (see the sketch after this list).
- We also added the "Learned Normal Float" quantization algorithm to the repo, a small extension of the existing NF format with some learning.
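For reference, here is a minimal sketch of how the first-level scales can be materialized, assuming bitsandbytes' QuantState layout (the nested / absmax / state2 / offset attributes), which may vary across versions.

import torch
from bitsandbytes.functional import QuantState, dequantize_blockwise

def first_level_absmax(quant_state: QuantState) -> torch.Tensor:
    # With double quantization ("compress_statistics"), the per-block absmax is itself
    # blockwise-quantized; undo that level so a backend without double-quant support
    # receives plain floating-point scales.
    if quant_state.nested:
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        return absmax + quant_state.offset
    return quant_state.absmax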
Two things we are currently working on:
- Just-in-time kernel tuning: for unseen shapes, we will run a few kernel templates and select the best one. This removes specialization on shapes at the cost of a possibly higher initialization time, but it is the most scalable approach I can think of.
- Removing specialization on GPUs by querying the relevant device properties at runtime (notably, the number of SMs).
A rough sketch of both ideas follows below.
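The candidate configurations and the kernel launcher in this sketch are placeholders standing in for FLUTE's template instantiations, not its actual internals.

from typing import Any, Callable, Sequence

import torch
import triton.testing as triton_benchmark

def num_sms(device: int = 0) -> int:
    # Query the SM count at runtime instead of baking it in as a compile-time constant.
    return torch.cuda.get_device_properties(device).multi_processor_count

def pick_config(candidates: Sequence[Any], run_with_config: Callable[[Any], None]) -> Any:
    # For an unseen shape, time each candidate kernel configuration once and keep the
    # fastest; in practice the result would be cached so the cost is paid only once,
    # at initialization time.
    timings = [triton_benchmark.do_bench(lambda cfg=cfg: run_with_config(cfg)) for cfg in candidates]
    return candidates[timings.index(min(timings))]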
Hopefully, after these changes, we will be ready to include FLUTE's source code inside BNB, per the earlier discussion.
Hi @matthewdouglas,
I think we're almost there with cleaning up the FLUTE kernel for integration into BNB. There are still a few ongoing Python-side API changes, but I don't anticipate any major blockers on the BNB side. Here's a quick summary of the updates since our last conversation:
C++/CUDA Side
We removed GPU-specific specialization, so the kernel should (in principle) run on any Ampere GPU. It should also work on later generations, although there may be room for optimization. These changes are not yet reflected in the public FLUTE repo, but I can push them whenever you'd like to take a look.
Python Side
We're adding a function for just-in-time kernel configuration tuning for shapes + GPUs we didn't tune specifically. We expect to have this ready soon.
Let me know how you'd like to proceed with the integration!
Sorry, we really messed this up. This was so very close to being integrated into bitsandbytes, and we failed on the bitsandbytes side to go the last mile. This was not because we were not interested -- we think FLUTE is great -- we were just a little disorganized and it slipped under our radar. We wanted to pick this up again, but we were overwhelmed with other responsibilities, such as the multi-backend refactor, which took priority.
At this point, we are still working through a lot of higher-priority issues. Due to personnel shortages, and due to the time cost of picking this up again and getting it working in the current version of bitsandbytes, we are not able to work on this right now.
We are thinking about picking this up in the future, but right now there is no timeline for when this will happen.
We are closing this for now but will reopen if we start working on this. Again, thank you for bringing this so far and sorry for messing this up.
Thanks for the explanation, and no worries at all! As always, we are big fans of the work you all have done and are happy to contribute in the future.