Fuse q,k,v LoRAs

Open tgaddair opened this issue 1 year ago • 4 comments

Currently, we treat each of the Q, K, V LoRAs as distinct tensors, meaning we do 3 SGMV calls per layer instead of 1. We should fuse them to improve batching.

tgaddair avatar Jan 04 '24 17:01 tgaddair

Quick sanity check on the math:

import torch

# Hidden sizes: h1 is the input dim; h2, h3, h4 are the q, k, v output dims.
h1 = 32
h2 = 32
h3 = 8
h4 = 8
r = 8  # LoRA rank
b = 4  # batch size

x = torch.randn((b, h1))

# Separate LoRA A/B pairs for q, k, v.
qA = torch.randn((h1, r))
qB = torch.randn((r, h2))

kA = torch.randn((h1, r))
kB = torch.randn((r, h3))

vA = torch.randn((h1, r))
vB = torch.randn((r, h4))

# Reference: three separate low-rank matmuls, concatenated.
y_q = (x @ qA) @ qB
y_k = (x @ kA) @ kB
y_v = (x @ vA) @ vB
y = torch.cat([y_q, y_k, y_v], dim=1)
print(y, y.shape)

# Fused: stack the A matrices side by side; B becomes block-diagonal,
# so its off-diagonal blocks are zero padding.
A = torch.zeros((h1, r * 3))
B = torch.zeros((r * 3, h2 + h3 + h4))

A[:, 0:r] = qA
A[:, r:r*2] = kA
A[:, r*2:r*3] = vA

B[0:r, 0:h2] = qB
B[r:r*2, h2:h2+h3] = kB
B[r*2:r*3, h2+h3:h2+h3+h4] = vB

print(A.shape, B.shape)

# Single fused pair of matmuls.
y2 = (x @ A) @ B
print(y2, y2.shape)

print(torch.allclose(y, y2))  # True
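
As a side note, the same fused tensors can be built without manual slicing, since B is block-diagonal (a minimal sketch continuing the script above):

A2 = torch.cat([qA, kA, vA], dim=1)  # (h1, 3r); A needs no zero padding
B2 = torch.block_diag(qB, kB, vB)    # (3r, h2 + h3 + h4); the off-diagonal zeros are the padding
assert torch.equal(A2, A) and torch.equal(B2, B)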

Everything looks good. There is some increased memory overhead from padding the B tensor with zeros; the off-diagonal blocks triple B's element count, while A is unchanged:

elems1 = sum(v.numel() for v in [qA, qB, kA, kB, vA, vB])  # separate tensors: 1152
elems2 = sum(v.numel() for v in [A, B])                    # fused tensors: 1920
print(elems1, elems2, elems2 / elems1, elems1 / elems2)

We get about a 67% increase in adapter memory, so we may want to make this optional.

Performance difference with Mistral-7B:

64 tokens generated, 1x A100

rank 8 (q, v):

  • baseline: 1.122s
  • fused: 1.073s (latency reduction: ~5%)

rank 16 (q, k, v):

  • baseline: 1.233s
  • fused: 1.132s (latency reduction: ~8%)

The latency reduction, particularly when using all three of q, k, v, is meaningful, but it's not clear it's worth the ~67% increase in adapter memory usage.

Also, it looks like there are numerical issues with the SGMV kernel that cause some corruption at these fused ranks. We'll need to resolve those to get this working.

tgaddair avatar Jan 05 '24 18:01 tgaddair

Branch: https://github.com/predibase/lorax/tree/fuse-qkv

tgaddair avatar Jan 05 '24 18:01 tgaddair

Due to the numerical issues, we could revisit this after tackling #160, which will allow us to pad SGMV ops to a particular (supported) rank.
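
For reference, padding a fused op up to a supported rank is mathematically a no-op, since the appended zero columns of A only ever multiply the appended zero rows of B. A minimal sketch continuing the script above, assuming a hypothetical supported rank of 32:

import torch.nn.functional as F

supported_rank = 32                # hypothetical kernel-supported rank
pad = supported_rank - A.shape[1]  # fused rank here is 3 * r = 24
A_pad = F.pad(A, (0, pad))         # append zero columns -> (h1, 32)
B_pad = F.pad(B, (0, 0, 0, pad))   # append zero rows -> (32, h2 + h3 + h4)
assert torch.allclose((x @ A_pad) @ B_pad, y2)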

tgaddair avatar Jan 05 '24 18:01 tgaddair

It seems that AWQ-quantized models already support fused QKV; we might consider that here as well.
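
For context, a fused QKV base layer computes a single projection and splits the output. A minimal generic sketch (not the AWQ implementation; the weights below are hypothetical and reuse the shapes from the script above):

Wq = torch.randn((h2, h1))
Wk = torch.randn((h3, h1))
Wv = torch.randn((h4, h1))
W_qkv = torch.cat([Wq, Wk, Wv], dim=0)              # one fused projection weight
q, k, v = (x @ W_qkv.T).split([h2, h3, h4], dim=1)  # one matmul, then split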

thincal avatar Apr 08 '24 19:04 thincal