
Marlin slower than fp16 on larger batches

Open mobicham opened this issue 1 year ago • 2 comments

I have been running some benchmarks with Marlin, but the speed-up is far from what is reported; in fact, it's actually slower than fp16 for the larger input.

GPU: A6000 Ada

matrix_shape:  [11008, 4096]

input_shape: [1, 1024, 11008]
time (fp16): 0.0007191438674926758
time (marlin): 0.0006200861930847168 (1.16x)

input_shape: [16, 1024, 11008]
time (fp16): 0.010448209762573242
time (marlin): 0.01280400848388672 (0.82x)

Code below:

import torch, marlin  # marlin: https://github.com/IST-DASLab/marlin

def forward_marlin(marlin_layer, x):
    # The scales tensor s has shape (..., out_features), so s.shape[1] is the output width.
    y = torch.empty(x.shape[:-1] + (marlin_layer.s.shape[1],), dtype=x.dtype, device=x.device)
    # marlin.mul computes C = A @ dequant(B) on 2D views of the inputs and writes into y.
    marlin.mul(x.view((-1, x.shape[-1])), marlin_layer.B, y.view((-1, y.shape[-1])), marlin_layer.s, marlin_layer.workspace_fp)
    return y

print(time_it(lambda: torch.matmul(x, ref) ))
print(time_it(lambda: forward_marlin(marlin_layer, x)))
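
(time_it is not shown above; for reference, a sketch of how such a helper can be written with CUDA events, warmup and synchronization, not necessarily the exact one used here:)

def time_it(fn, iters=100, warmup=10):
    # Warm up first, then average over several iterations; CUDA events plus
    # a final synchronize make sure the GPU work is actually included in the timing.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # average seconds per call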

What could be the issue? Thanks in advance!

mobicham commented on Apr 09 '24

Hi, Marlin is primarily optimized for generative inference (a few tokens at a time), which is memory-bound and can hence be sped up via weight quantization; e.g. input shapes of (16, 1, 11008). Note that for batch sizes > 128 (meaning the overall number of tokens, in your case 16 * 1024), inference stops being memory-bound and weight-only quantization generally cannot be faster (though Marlin sometimes is a bit faster for not-too-large batch sizes due to slightly better partitioning than the default torch kernels).
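
As a rough back-of-the-envelope sketch of where that crossover comes from (the peak-compute and bandwidth figures below are placeholder assumptions, not measured values):

def crossover_tokens(peak_tflops, bandwidth_gbs, bytes_per_weight):
    # Each weight element read once from memory contributes 2 * m FLOPs
    # (one multiply-add per token), so the matmul stays bound on weight
    # traffic roughly while 2 * m / bytes_per_weight < peak_flops / bandwidth.
    machine_balance = (peak_tflops * 1e12) / (bandwidth_gbs * 1e9)  # FLOPs per byte
    return machine_balance * bytes_per_weight / 2

# Illustrative (assumed) numbers for an Ada-class workstation GPU:
print(crossover_tokens(peak_tflops=180, bandwidth_gbs=960, bytes_per_weight=2.0))  # ~190 tokens for fp16 weights
print(crossover_tokens(peak_tflops=180, bandwidth_gbs=960, bytes_per_weight=0.5))  # ~47 tokens for int4 weights

With numbers in this ballpark the crossover lands in the low hundreds of tokens, which is roughly consistent with the > 128 rule of thumb above.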

efrantar commented on Apr 09 '24

Thanks for your answer @efrantar, understood. I am trying to integrate it with our quantization method; below are the benchmarks for the forward pass on a 3090, Llama2-7B, batch-size=1, context-size=2048:

fp16:                 0.4604 + model compile: 0.4218
int4 (torch compile): 0.4554
Marlin (int4):        0.4221 + model compile: 0.3841

It is about 10% faster than fp16 with this setup, though the LLM eval score drops a bit (51.57 -> 51.27).

Is there a way to dequantize the weights without calling the matmul with the identity matrix?
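
(For context, this is roughly the identity-matmul workaround I mean, sketched along the lines of forward_marlin above; dequantize_via_identity is just an illustrative name.)

def dequantize_via_identity(marlin_layer, in_features, device="cuda"):
    # Sketch: pushing an fp16 identity matrix through the kernel yields
    # C = I @ dequant(B), i.e. the dequantized weight of shape
    # (in_features, out_features); a direct unpacking routine would avoid
    # paying for this full matmul.
    eye = torch.eye(in_features, dtype=torch.half, device=device)
    w = torch.empty((in_features, marlin_layer.s.shape[1]), dtype=torch.half, device=device)
    marlin.mul(eye, marlin_layer.B, w, marlin_layer.s, marlin_layer.workspace_fp)
    return w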

Thanks again for your work!

mobicham commented on Apr 10 '24