Paddle
Paddle copied to clipboard
[PHI] add int4 weight only quant kernel, add int4 weight only permute kernel
PR Category
Inference
PR Types
New features
Description
给paddle添加int4量化的kernel和int4量化进行permute的kernel。
TL;DR
支持了一个GPU kernel,它能做int4 weight only量化的工作。并且能支持weight_only_linear (同时也能和反量化接口对齐,如果你想单纯做量化反量化看看。你可以这么执行代码)
import paddle
x = paddle.randn(shape=[4096, 2048], dtype=paddle.float16)
qt, scale = paddle.nn.quant.weight_quantize(x, algo='weight_only_int4')
## 啊 paddle暂时还不可以形状推导。 但是PR已经在合了
## view之前的shape应该是[1024, 4096],这个shape是做weight only linear用的。后续也可以加一个接口判断是否矩阵乘法来判断是否在c++侧reshape
qt = qt.view([2048, 2048])
x_dq = paddle.nn.quant.weight_dequantize(qt, scale, algo='weight_only_int4')
当然,weight only linear也是支持的
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize, weight_dequantize
x = paddle.rand(shape=(2, 4096), dtype='float16')
weight = paddle.randn(shape=(4096, 2048), dtype='float32')
weight = weight.astype('float16')
quant_weight, quant_scale = weight_quantize(x=weight, algo='weight_only_int4')
quant_out = weight_only_linear(x=quant_x, weight=quant_weight, weight_scale=quant_scale, weight_dtype="int4")
## 能和它大概对齐吧,毕竟int4量化的精度低的离谱 out = paddle.matmul(x=x, y=weight)
int4 weight only quant总结
参考CPU的实现,SM70以上kernel的实现分几个步骤:
- 按行进行pack(2int4pack成一个int8)
- permute_B_rows_for_mixed_gemm:排布列方向的元素
- subbyte_transpose:把列主序的weight变成行主序的,并且由按行进行pack转化成按列进行pack。
- interleave_column_major_tensor:每64个元素进行interleave
- add_bias_and_interleave_int4s_inplace:把int8转换成uint8(+8)
但是我们其实不需要这么复杂的实现,我们可以直接就按列进行pack。也能达到一样的效果。并且只需要两个kernel(加上量化需要三个kernel)。方法如下:
int4量化kernel
对于int4量化来说,我们分别实现了按行pack和按列pack。(为了让SM70版本的显卡也能正常工作QAQ) 对按列pack来说,它需要让两个int4pack成一个int8的数进行实现。在代码里,我们让上下两行组成一个int8的数,也就是按列进行的pack。
int4 permute kernel
对于int4量化,我们需要对输入数据进行重排来适配cutlass的快速反量化kernel。 在int4反量化端,我们观察反量化算子实现可以发现。最后所需的输出是:
0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27
4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
参考cutlass的快速反量化实现。 int4快速反量化4个int8一组,能把int8的数据转换为fp16的。但它会改变数据的排布:
0 2 4 6 1 3 5 7 -> 0 1 2 3 4 5 6 7
则我们可以推得在快速反量化之前,我们需要的数据是
// 0 8 16 24 1 9 17 25 2 10 18 26 3 11 19 27
// 4 12 20 28 5 13 21 29 6 14 22 30 7 15 23 31
上面一组数看上去没有任何的规律,但是我们可以给它做一点小小的调整,调整成下面的形式,只需要一些简单的位运算即可
// 0 1 16 17 8 9 24 25 2 3 18 19 10 11 26 27
// 4 5 20 21 12 13 28 29 6 7 22 23 14 15 30 31
我们知道,两个int4 pack成了一个int8,我们也可以把上面的数调整成int8的index
0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15
那么从
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> 0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15
的坐标为
0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15
得到这个新的permute_kk(代码里的变量,描述列之间的permute),可以通过int8的permute_kk做一点小小的改变 从int8 permute转换为int4 permute int8
0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
可以把它变成
0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15
% 8 * 2
0 4 8 12 2 6 10 14 0 4 8 12 2 6 10 14
add 1 for 0 4 8 12 2 6 10 14 [0 4 8 12 2 6 10 14]
简单的位运算kernel(最后执行)
从
// (0 1) (16 17) (8 9) (24 25) (2 3) (18 19) (10 11) (26 27)
// (4 5) (20 21) (12 13) (28 29) (6 7) (22 23) (14 15) (30 31)
到
// 0 8 16 24 1 9 17 25 2 10 18 26 3 11 19 27
// 4 12 20 28 5 13 21 29 6 14 22 30 7 15 23 31
我们可以每四个数一组,然后02 13 之间做低四位和高四位的交换即可。
int4 row interleave
对于int8的case,代码在相邻的两行中,每64个元素进行交织。但是对于int4的情况。代码就会在相邻的四行中,每32个元素进行交织。所以在permute的处理时,写成了
int permute_index = permute_kk % 32 + permute_kk / 32 * 128 +
32 * (n_id % 4) + total_k * 4 * (n_id / 4);
这样也符合预期。(写着写着天都亮了zzz)