Marlin slower than fp16 on larger batches
I have been running some benchmarks with Marlin, but the speed-up is far from what is reported. In fact, for the larger batch it is actually slower than fp16.
GPU: A6000 Ada
matrix_shape: [11008, 4096]
input_shape: [1, 1024, 11008]
time (fp16): 0.0007191438674926758
time (marlin): 0.0006200861930847168 (1.16x)
input_shape: [16, 1024, 11008]
time (fp16): 0.010448209762573242
time (marlin): 0.01280400848388672 (0.82x)
Code below:
import torch
import marlin

def forward_marlin(marlin_layer, x):
    # Output feature count is the second dim of the scales s (shape [k / groupsize, n])
    y = torch.empty(x.shape[:-1] + (marlin_layer.s.shape[1],), dtype=x.dtype, device=x.device)
    # Marlin expects 2D inputs/outputs, so flatten all leading dims into one
    marlin.mul(x.view((-1, x.shape[-1])), marlin_layer.B, y.view((-1, y.shape[-1])), marlin_layer.s, marlin_layer.workspace_fp)
    return y

print(time_it(lambda: torch.matmul(x, ref)))
print(time_it(lambda: forward_marlin(marlin_layer, x)))
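(For completeness, time_it is a simple timing helper roughly like the sketch below; the warm-up and iteration counts are just placeholders and my actual helper may differ in details.)

import time
import torch

def time_it(fn, warmup=5, iters=20):
    # Warm up so one-time setup (kernel caching, autotuning) is not measured
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # make sure all queued GPU work has finished
    return (time.time() - start) / iters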
What could be the issue? Thanks in advance!
Hi, Marlin is primarily optimized for generative inference (a few tokens at a time, e.g. input shapes of (16, 1, 11008)), which is memory-bound and can hence be sped up via weight quantization. Note that for batch sizes > 128 (meaning the overall number of tokens, in your case 16 * 1024), inference stops being memory-bound and weight-only quantization generally cannot be faster (though Marlin is sometimes still a bit faster for not-too-large batch sizes, due to slightly better partitioning than the default torch kernels).
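For illustration, a decode-style measurement with your helpers would look roughly like this (a sketch; forward_marlin, time_it, ref and marlin_layer are reused from your snippet, and the shape is just an example):

import torch

# 16 sequences decoding 1 token each -> 16 tokens total, well below the ~128-token
# regime where the matmul stops being memory-bound
x_decode = torch.randn((16, 1, 11008), dtype=torch.half, device='cuda')

print(time_it(lambda: torch.matmul(x_decode, ref)))             # fp16 baseline
print(time_it(lambda: forward_marlin(marlin_layer, x_decode)))  # Marlin int4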
Thanks for your answer @efrantar, understood. I am trying to integrate Marlin with our quantization method; below are benchmarks for the forward pass on an RTX 3090, Llama2-7B, batch-size=1, context-size=2048:
fp16: 0.4604 + model compile: 0.4218
int4 (torch compile): 0.4554
Marlin (int4): 0.4221 + model compile: 0.3841
It is about 10% faster than fp16 with this setup, though the LLM eval score drops a bit (51.57 -> 51.27).
Is there a way to dequantize the weights without calling the matmul with the identity matrix?
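(For context, by "calling the matmul with the identity matrix" I mean roughly the sketch below, reusing forward_marlin from my first post; in_features is passed explicitly so as not to rely on the packed layout of B.)

import torch

def dequantize_via_identity(marlin_layer, in_features):
    # Feeding the in_features x in_features identity through the Marlin matmul
    # returns the dequantized fp16 weight of shape (in_features, out_features)
    eye = torch.eye(in_features, dtype=torch.half, device=marlin_layer.s.device)
    return forward_marlin(marlin_layer, eye)

# e.g. W_deq = dequantize_via_identity(marlin_layer, 11008)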
Thanks again for your work!