
FLUTE Integration for Fast Inference

Open HanGuo97 opened this issue 1 year ago • 12 comments

Feature request

Hi, we are big fans of the library and the NF4 data-type, so much so that we have been working on CUDA kernels to speed-up inference for NF4-quantized models (and more). We'd love to explore ways we can integrate FLUTE into the bitsandbytes library.

Motivation

I think FLUTE could make two contributions to bitsandbytes:

  • An inference-time kernel for NF4-quantized LLMs (and more). For example, this could help extend the current GEMV kernel into GEMM (i.e., batched inference).
  • A new quantization algorithm that lightly extends bitsandbytes's NF4 into a learned version.

The kernel and algorithm still have room for improvement. For example, the kernel has limited GPU support and is specialized for specific shapes. These are hopefully not hard limits and could be relaxed. As such, we do not expect FLUTE to be the default option. That being said, we thought it would be great to have FLUTE as an opt-in feature.

Your contribution

We are happy to submit PRs, though I'm sure there will be some rough edges that we will need some help with.

HanGuo97 avatar Jul 25 '24 14:07 HanGuo97

@HanGuo97 Thank you for reaching out! We're always interested in contributions and would be happy to help with PRs.

It should not be an issue if we take an optimized codepath for certain hardware, shapes, etc. and fall back otherwise. I'm going to need to take some time to look through your paper and code, but in the meantime, feel free to reach out if there are any questions!

matthewdouglas avatar Jul 31 '24 22:07 matthewdouglas

Thanks! Do you have suggestions on where we could get started? I was thinking about a few potential ways for integration:

  1. The easiest point of integration is at the level of Linear4bit. FLUTE has its own FLUTELinear. If the user chooses FLUTE as the backend, we could simply swap Linear4bit out for FLUTELinear.
  2. Alternatively, we could use Linear4bit (and hence Params4bit) as it is, and instead expose just FLUTE's GEMM kernel as an option in matmul_4bit.

Option 1 is easier to implement. FLUTE also has a slightly different implementation of NF beyond 4-bit (NF4, NF3, NF2, as well as their "learned" extensions), and this option makes it easier to expose these features.

Option 2 provides more seamless integration with the rest of bitsandbytes, but it would be much more challenging to implement due to differences in data formats, among other things.
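
To make Option 1 concrete, here is a minimal sketch of what the swap could look like. The import path for FLUTELinear and the from_linear4bit-style constructor are placeholders for whatever conversion API FLUTE ends up exposing, not the actual API:

import torch.nn as nn

import bitsandbytes as bnb

# Hypothetical import; FLUTE's actual module location and constructor may differ.
from flute.nn import FLUTELinear


def swap_linear4bit_for_flute(model: nn.Module) -> nn.Module:
    # Recursively replace every bitsandbytes Linear4bit with a FLUTELinear.
    for name, child in model.named_children():
        if isinstance(child, bnb.nn.Linear4bit):
            # `from_linear4bit` is a placeholder for a converter that repacks
            # the NF4 weights into FLUTE's format.
            setattr(model, name, FLUTELinear.from_linear4bit(child))
        else:
            swap_linear4bit_for_flute(child)
    return model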

HanGuo97 avatar Aug 03 '24 02:08 HanGuo97

The first option seems like a good place to start, especially since it does not sound like it would prevent a later implementation of option 2. That said, I would like @SunMarc and @Titus-von-Koeller to weigh in here in case there are any additional considerations that I may be missing.

matthewdouglas avatar Aug 12 '24 14:08 matthewdouglas

I agree with @matthewdouglas; I think it is better to implement a separate Linear class for now. It will let us be more flexible and iterate faster! If it makes sense in the future, we can still think about integrating it into Linear4bit.

SunMarc avatar Aug 12 '24 14:08 SunMarc

Agreed, sounds good to me as well. Thanks so much for your contribution, really happy about this collaboration!

Btw, is the code you used for your benchmarking also public somewhere? It's a topic we are looking into ourselves, so it might be interesting to reference :)

Titus-von-Koeller avatar Aug 14 '24 17:08 Titus-von-Koeller

Thanks for the feedback, glad to hear that we have a concrete plan! I'm going to work on it over the next few days and will let you know when I have a PR ready.

As for benchmarking, the code is not public yet, since it has some external dependencies that I haven't had the chance to document. That being said, it essentially calls PyTorch's and Triton's benchmarking functions after preparing the data. Here's the relevant snippet.

Note

  • The out operand is ignored by bitsandbytes, so strictly speaking this benchmark includes memory allocation time. Our early experiments suggest this adds an overhead of ~2.5%.
  • We use both PyTorch's and Triton's benchmarking tools, but by default we report Triton's timing, as it manually flushes the L2 cache.

# Imports assumed for the aliases used below; exact paths may differ across
# bitsandbytes/triton versions.
from typing import Dict

import torch
import torch.utils.benchmark as torch_benchmark
import triton.testing as triton_benchmark
from bitsandbytes import matmul_4bit as bnb_matmul_4bit
from bitsandbytes.nn import Params4bit as BnBParams4bit


def prepare_bnb_data(
    m: int,
    n: int,
    k: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Dict:

    # Quantize a random (n, k) weight to NF4; .cuda() triggers the quantization.
    Q = BnBParams4bit(  # type: ignore
        data=torch.randn(
            (n, k),
            dtype=dtype,
            device=device),
        requires_grad=False,
        blocksize=128,
        compress_statistics=False,
        quant_type="nf4").cuda()

    # Activation input of shape (m, k).
    A = torch.randn(
        (m, k),
        dtype=dtype,
        device=device)

    # Preallocated output buffer (bitsandbytes currently ignores `out`).
    D = torch.empty(
        (m, n),
        dtype=dtype,
        device=device)

    torch.cuda.synchronize()

    return {
        "A": A,
        "QT": Q.t(),
        "quant_state": Q.quant_state,
        "D": D,
    }


def benchmark_bnb(data: Dict, n: int = 100) -> Dict[str, float]:

    # Time the same matmul with both PyTorch's and Triton's benchmarking utilities.
    fn = lambda: bnb_matmul_4bit(data["A"], data["QT"], bias=None, quant_state=data["quant_state"], out=data["D"])
    timer = torch_benchmark.Timer(
        stmt="bnb_matmul_4bit(A, QT, bias=None, quant_state=quant_state, out=D)",
        setup="from __main__ import bnb_matmul_4bit",
        globals=data)

    torch_time = timer.timeit(n).mean
    triton_time = triton_benchmark.do_bench(fn, rep=n)
    return {"torch_time": torch_time, "triton_time": triton_time}

HanGuo97 avatar Aug 14 '24 18:08 HanGuo97

@HanGuo97 We've discussed this further internally, and thinking from a long-term perspective, it would be ideal for us to build a GEMM kernel in-source here and prefer dispatching to it. After browsing through the flute repo a little more, I can see there are a few challenges to overcome, but I think we can work through them.

Some of the initial things I notice:

  • CUTLASS 3.4 dependency is something we should be OK to pick up
  • We'd like to decouple from PyTorch and instead expose a C API
  • Compile time is quite long, though we'll be able to exclude 3bit, 2bit, and groupsize=32
  • Need to tune for a wider variety of GPUs.

Any thoughts or concerns with this?

matthewdouglas avatar Aug 16 '24 21:08 matthewdouglas

Thanks for the comments! I'm personally okay with having FLUTE's CUDA code included inside BNB. As for the list of things you mentioned, many of them are on our (near-term) roadmap as well.

  • The compilation time is long primarily because we instantiate a lot of templates in order to squeeze out the last bit of performance. That being said, I think a small subset of them should suffice for practical use cases. As a by-product, we are also working on removing the specialization on shapes.

  • Adding FLUTE to more GPUs is doable. Currently, FLUTE treats the number of SMs (streaming multiprocessors) of a GPU as a compile-time constant. This can be made as a runtime variable, hence supporting all GPUs in the Ampere and Ada generation GPUs (A-series, and RTX30xx/40xx series). The tuning will still be done on the small subset of GPUs though (since we don't have a ton of resources 😅).
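
As a small illustration of that change, the SM count can be queried at runtime from Python via PyTorch; the kernel-side mechanism in FLUTE may of course look different:

import torch

# Query device properties at runtime instead of baking them in at compile time.
props = torch.cuda.get_device_properties(torch.device("cuda"))
num_sms = props.multi_processor_count  # number of streaming multiprocessors
print(f"{props.name}: {num_sms} SMs")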

I don't have a specific timeline, unfortunately, since we are a bit bandwidth-bound... but I'm aiming for somewhere in the next few weeks.

HanGuo97 avatar Aug 17 '24 01:08 HanGuo97

Hello @HanGuo97, have you benchmarked the latency of Linear4bit and FLUTELinear in different test cases?

Ther-nullptr avatar Aug 17 '24 02:08 Ther-nullptr

@Ther-nullptr we have some benchmarking numbers in the writeup; is that what you were looking for? (The batch size = 32 case might need to be updated, though.)

HanGuo97 avatar Aug 17 '24 18:08 HanGuo97

Hi, sorry for the silence. Just want to give a quick update on where we are.

  1. We pushed an update earlier to support converting a bitsandbytes-quantized model into FLUTE's own format. Since we do not support double quantization, we simply materialize the first-level scales for the moment (a rough sketch of this step is shown after this list).
  2. We also added the "Learned Normal Float" quantization algorithm to the repo, a small extension of the existing NF format with some learning.
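
For reference, here is a rough sketch of the first-level scale materialization on the bitsandbytes side. It assumes the current QuantState layout (absmax, offset, state2), which may change across bitsandbytes versions:

import bitsandbytes.functional as F


def materialize_first_level_absmax(quant_state):
    # When the checkpoint was saved with double quantization
    # (compress_statistics=True), the per-block absmax values are themselves
    # blockwise-quantized in `state2`. Dequantize them and re-apply the offset
    # so that only plain first-level scales remain.
    if quant_state.nested:
        absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
        return absmax
    return quant_state.absmax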

Two things we are working on these days:

  1. We will add support for just-in-time kernel tuning. That is, for unseen shapes, we will run a few kernel templates and select the fastest one. This can remove the specialization on shapes at the cost of possibly high initialization time, but it is the most scalable approach I can think of (a rough sketch follows after this list).
  2. We will remove specialization on GPUs by querying relevant device properties at runtime (notably, the number of SMs).
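
As a rough illustration of the first item, the tuner could look something like the sketch below; candidate_configs and run_kernel are hypothetical stand-ins for FLUTE's template configurations and kernel launcher:

import torch
import triton.testing as triton_benchmark

# Cache of the best-known configuration per (shape, device) key.
_CONFIG_CACHE = {}


def select_config(m, n, k, candidate_configs, run_kernel):
    # `run_kernel(config, m, n, k)` is a hypothetical callable that launches the
    # templated kernel with the given configuration.
    key = (m, n, k, torch.cuda.get_device_name())
    if key not in _CONFIG_CACHE:
        # Benchmark each candidate once for the unseen shape and keep the fastest.
        timings = {
            config: triton_benchmark.do_bench(lambda: run_kernel(config, m, n, k))
            for config in candidate_configs
        }
        _CONFIG_CACHE[key] = min(timings, key=timings.get)
    return _CONFIG_CACHE[key]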

Hopefully after these, we will be ready to include FLUTE's source code inside BNB per the earlier discussion.

HanGuo97 avatar Aug 31 '24 21:08 HanGuo97

Hi @matthewdouglas,

I think we're almost there with cleaning up the FLUTE kernel for integration into BNB. There are still a few ongoing Python-side API changes, but I don't anticipate any major blockers on the BNB side. Here's a quick summary of the updates since our last conversation:

C++/CUDA Side

We removed GPU-specific specialization, so the kernel should (in principle) run on any Ampere GPU. It should also work on later generations, although there may be room for optimization. These changes are not yet reflected in the public FLUTE repo, but I can push them whenever you'd like to take a look.

Python Side

We're adding a function for just-in-time kernel configuration tuning for shapes and GPUs we didn't specifically tune for. We expect to have this ready soon.

Let me know how you'd like to proceed with the integration!

HanGuo97 avatar Oct 05 '24 15:10 HanGuo97

Sorry, we really messed this up. This was so very close to being integrated into bitsandbytes, and we failed on the bitsandbytes side to go the last mile. This was not because we weren't interested -- we think FLUTE is great -- we were just a little disorganized and it slipped off our radar. We wanted to pick this up again, but we were overwhelmed with other responsibilities, such as the multi-backend refactor, which took priority.

At this point, we are still working through a lot of higher-priority issues, and due to personnel shortages and the time cost of picking this up again and getting it working in the current version of bitsandbytes, we are not able to work on this right now.

We are thinking about picking this up in the future, but right now there is no timeline for when this will happen.

TimDettmers avatar Feb 28 '25 15:02 TimDettmers

We are closing this for now but will reopen it if we start working on this. Again, thank you for bringing this so far, and sorry for messing this up.

TimDettmers avatar Feb 28 '25 15:02 TimDettmers

Thanks for the explanation, and no worries at all! As always, we are big fans of the work you all have done and are happy to contribute in the future.

HanGuo97 avatar Feb 28 '25 16:02 HanGuo97