Corrupted outputs with Marlin int4 kernels as parallelization increases
When using MarlinInt4WeightQBitsTensor with its associated optimized gemm kernel, the weights, scales, and zero-points are read back incorrectly as soon as parallelization increases.
As a consequence, output features beyond 128 are corrupted once a sufficient number of inputs are processed in parallel.
Test to reproduce the issue here: https://github.com/huggingface/optimum-quanto/blob/852bb9cb6fb707a6fcebff7e068dc6bbdda779cb/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py#L134
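
To narrow down which output features are affected, it can help to compare the kernel output column by column against a reference result (for example, a matmul on the dequantized weights). The helper below is only a minimal sketch: `corrupted_features` is a hypothetical name that is not part of quanto, the tolerances are arbitrary, and the data is synthetic rather than actual kernel output. It works on plain nested lists, so torch tensors would first need to be converted with `.tolist()`.

```python
import math


def corrupted_features(reference, actual, rel_tol=1e-2, abs_tol=1e-3):
    """Return the sorted indices of output features (columns) that differ.

    `reference` and `actual` are 2D nested lists of shape
    (inputs, output features). A feature is reported if it mismatches
    in at least one row. NaN values never compare as close, so they
    are reported as well.
    """
    bad = set()
    for ref_row, act_row in zip(reference, actual):
        for j, (r, a) in enumerate(zip(ref_row, act_row)):
            if not math.isclose(r, a, rel_tol=rel_tol, abs_tol=abs_tol):
                bad.add(j)
    return sorted(bad)


# Synthetic data only (an assumption for illustration, not measured output):
# 4 inputs, 256 output features, with the upper half zeroed out.
reference = [[1.0] * 256 for _ in range(4)]
actual = [row[:] for row in reference]
for row in actual:
    for j in range(128, 256):
        row[j] = 0.0

bad = corrupted_features(reference, actual)
print(f"{len(bad)} corrupted features, from {bad[0]} to {bad[-1]}")
```

Running the same comparison for increasing numbers of inputs should show at which point the corruption starts to appear.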