
[Question/Feature] Fused attention/mlp/norm for MPT

Open casper-hansen opened this issue 2 years ago • 3 comments

I have had the great pleasure of testing out TinyChat today - it's blazing fast.

In particular, I was able to get 102 tokens/s (9.8 ms/token) on a 4090 with the fused operations on LLaMa-2 7B, roughly a 100% speed boost over the non-fused operations, which ran at about 45-50 tokens/s.

How can we extend these fused operations to the MPT model series, i.e. fusing the torch implementation of multi-head attention together with MPT's ALiBi implementation?

The main reason I want to use MPT models over LLaMa is licensing, but also that MPT offers 7B models trained for an 8k context.
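
For context, a minimal unfused sketch of MPT-style ALiBi attention in plain PyTorch might look like the following. It assumes the standard ALiBi slope formula for a power-of-two head count and a simple causal prefill without a KV cache; a fused TinyChat kernel would need to fold the ALiBi bias and causal mask into the attention kernel itself.

import torch
import torch.nn.functional as F


def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two head count: 2^(-8/n), 2^(-16/n), ...
    return torch.tensor([2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])


def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # The bias depends only on the key position; softmax is shift-invariant per
    # query row, so this matches the usual -slope * distance formulation once
    # the causal mask is applied.
    rel_pos = torch.arange(1 - seq_len, 1).view(1, 1, 1, seq_len)
    return alibi_slopes(n_heads).view(1, n_heads, 1, 1) * rel_pos


def mpt_style_attention(q, k, v):
    # q, k, v: [batch, n_heads, seq_len, head_dim] (prefill only, no KV cache)
    n_heads, seq_len = q.shape[1], q.shape[2]
    mask = alibi_bias(n_heads, seq_len).to(dtype=q.dtype, device=q.device)
    mask = mask + torch.full((seq_len, seq_len), float("-inf"),
                             dtype=q.dtype, device=q.device).triu(1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)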

casper-hansen avatar Jul 26 '23 21:07 casper-hansen

It seems TinyChat is currently very CPU-bound for all models other than LLaMa. On an A6000, 3090, and 4090 paired with an AMD EPYC 7-series CPU, performance is largely the same because of the CPU's low single-threaded performance. However, if the CPU is upgraded to an i9-13900K (roughly double the single-threaded performance of the EPYC), performance also gets a roughly 100% boost.

@Sakits Any plans to add further speedups to TinyChat to make it less CPU-bound? I can see the fused/optimized layers for LLaMa-2 helped utilize the GPU more.

Rough expectations for speedup (time saved per token):

  • MLP: 0.5-1.0ms
  • LayerNorm: ~3ms
  • Attention: ~7ms

If all of these parts are optimized, we should see below 10 ms per token even on slower CPUs, and could even get close to 5-6 ms on better GPUs if TinyChat were optimized further.
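
As a sanity check, per-token latency can be measured directly with CUDA events. This is a rough sketch, not TinyChat's benchmark code, and it assumes a Hugging Face-style generate() API on `model`:

import torch


@torch.inference_mode()
def ms_per_token(model, input_ids, n_tokens=100):
    # Average decode latency in milliseconds per generated token.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    model.generate(input_ids, max_new_tokens=n_tokens, do_sample=False)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_tokens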

I got about a 0.5-1.0 ms speedup (2.7-5.5%) by replacing MPT's linear layers. You can see my fork/branch here.

import torch
import torch.nn as nn

import awq_inference_engine  # AWQ CUDA kernels


class QuantMPTMLP(nn.Module):
    def __init__(self, up_proj, act, down_proj):
        super().__init__()
        # Register the quantized up_proj tensors as buffers so the CUDA GEMM can be
        # called directly, bypassing the WQLinear.forward Python overhead.
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)

        # up_proj is kept for its out_features metadata; act and down_proj run as-is.
        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj

    def forward(self, x: torch.Tensor):
        # Remember the original (batch, seq, ...) layout so the output matches it.
        out_shape = x.shape[:-1] + (self.up_proj.out_features,)
        x = x.reshape(-1, x.shape[-1])
        # The last argument (8) is the kernel's split-k setting, matching the fused LLaMa path.
        x = awq_inference_engine.gemm_forward_cuda(
            x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
        )
        x = x.reshape(out_shape)

        return self.down_proj(self.act(x))
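
To actually apply it, each MPT block's feed-forward module has to be swapped out. Below is a sketch of that patch step; the attribute names model.transformer.blocks, block.ffn, and up_proj/act/down_proj follow the MPT module layout but should be verified against the concrete checkpoint.

def fuse_mpt_mlp(model):
    # Replace each MPT feed-forward module with the fused QuantMPTMLP above.
    for block in model.transformer.blocks:
        ffn = block.ffn
        block.ffn = QuantMPTMLP(ffn.up_proj, ffn.act, ffn.down_proj)
    return model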

casper-hansen avatar Jul 31 '23 13:07 casper-hansen

Hi @casperbh96,

Thank you for your suggestions and contributions!

The current version of the AWQ library mainly focuses on usability, so it hasn't been fully optimized for speed. However, we're planning a reimplementation based on a more efficient baseline (e.g. TGI). Please stay tuned for future updates! :)

Sakits avatar Jul 31 '23 14:07 Sakits

TGI

That sounds great! :)

The only thing to keep in mind is that TGI has recently switched its license, so be careful if you plan to use their code.

Edit: Looks like you can still use TGI commercially for 90%+ of use-cases, so it might still be a good idea to build on TGI. https://github.com/huggingface/text-generation-inference/issues/744

casper-hansen avatar Jul 31 '23 14:07 casper-hansen