llm-awq
[Question/Feature] Fused attention/mlp/norm for MPT
I have had the great pleasure of testing out TinyChat today - it's blazing fast.
In particular, I was able to get 102 tokens/s (9.8ms/token) on a 4090 with the fused operations on LLaMa-2 7B, a 100% speed boost over the non-fused operations, which ran at about 45-50 tokens/s.
How can we extend these fused operations to the MPT model series? i.e., fusing the torch implementation of multi-head attention together with MPT's ALiBi implementation.
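For context, a fused MPT attention kernel would also need to apply the ALiBi bias to the attention scores before the softmax. Here is a minimal sketch of that bias in plain PyTorch (illustrative names only, not TinyChat's kernel API, and assuming a power-of-two head count):

```python
import torch

def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two number of heads: 2^(-8/n), 2^(-16/n), ...
    start = 2.0 ** (-8.0 / n_heads)
    return torch.tensor([start ** (i + 1) for i in range(n_heads)])

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # (n_heads, 1, seq_len): a per-head linear penalty on the distance to each key,
    # added to the attention scores before softmax (0 for the most recent key).
    positions = torch.arange(seq_len) - (seq_len - 1)
    return alibi_slopes(n_heads).view(n_heads, 1, 1) * positions.view(1, 1, seq_len)
```

In a fused kernel this bias would be folded into the score computation rather than materialized, but the math is the same.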
The main reason I want to use MPT models over LLaMa is licensing issues, but also that MPT has 7B models trained for 8k context.
It seems TinyChat is currently very CPU-bound for all models other than LLaMa. On an A6000, 3090, and 4090 with an AMD EPYC 7-series CPU, performance is largely the same due to the low single-threaded performance of the CPU. However, if the CPU is upgraded to an i9-13900K (roughly double the performance of the AMD CPU), performance also gets a 100% boost.
@Sakits Any plans for adding further speedups for TinyChat to make it less CPU-bound? I see the fused/optimized layers for LLaMa-2 helped with utilizing the GPU more.
Rough expected per-token savings:
- MLP: 0.5-1.0ms
- LayerNorm: ~3ms
- Attention: ~7ms
If all parts are optimized, we should see below 10ms inference per token, even on slower CPUs. We could even get close to 5-6ms on better GPUs if TinyChat were optimized further.
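Back-of-the-envelope math behind the sub-10ms figure (the ~18 ms/token MPT baseline is implied by the 0.5-1.0ms ≈ 2.7-5.5% numbers below; all values are rough estimates):

```python
baseline_ms = 18.0               # implied current MPT latency per token
savings_ms = 0.75 + 3.0 + 7.0    # MLP + LayerNorm + Attention (midpoint estimates)
print(baseline_ms - savings_ms)  # ~7 ms/token, comfortably below 10 ms
```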
I got about a 0.5-1.0ms (2.7-5.5%) speedup by replacing the linear layers of MPT. You can see my fork/branch here.
```python
import torch
import torch.nn as nn

import awq_inference_engine  # AWQ 4-bit GEMM CUDA kernels


class QuantMPTMLP(nn.Module):
    def __init__(self, up_proj, act, down_proj):
        super().__init__()
        # Keep the quantized up_proj tensors as buffers so the CUDA kernel
        # can be called directly instead of going through up_proj.forward().
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj

    def forward(self, x: torch.Tensor):
        out_shape = x.shape[:-1] + (self.up_proj.out_features,)
        x = x.reshape(-1, x.shape[-1])
        # 4-bit GEMM for the up projection (split_k_iters=8), then activation and down projection.
        x = awq_inference_engine.gemm_forward_cuda(
            x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
        )
        # Restore the original leading dimensions before the down projection.
        return self.down_proj(self.act(x.reshape(out_shape)))
```
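For anyone wanting to try this, here is a hedged usage sketch (not taken from the fork) of how the wrapper could be swapped into each transformer block, assuming the usual MPT module layout (`model.transformer.blocks[i].ffn` with `up_proj` / `act` / `down_proj`) and that the linears have already been converted to AWQ `WQLinear` layers:

```python
# Illustrative only: replace each MPT block's FFN with the QuantMPTMLP wrapper.
for block in model.transformer.blocks:
    ffn = block.ffn
    block.ffn = QuantMPTMLP(ffn.up_proj, ffn.act, ffn.down_proj)
```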
Hi @casperbh96,
Thank you for your suggestions and contributions!
The current version of the AWQ library mainly focuses on usability, so it hasn't been fully optimized for speed. However, we're planning a reimplementation based on a more efficient baseline (e.g., TGI). Please stay tuned for future updates! :)
> TGI
That sounds great! :)
The only thing to keep in mind is that TGI recently switched its license, so be careful if you plan to use their code.
Edit: It looks like you can still use TGI commercially for 90%+ of use cases, so it might still be a good idea to go with TGI. https://github.com/huggingface/text-generation-inference/issues/744