lightseq
Question about qkv_weight initialization
I'm confused about the qkv weight initialization in lightseq. If qkv_w is initialized from existing q/k/v weights, there are two ways to do it (assume qkv has shape [m,k], and each q/k/v weight has shape [k,n]); a quick sanity check after this list shows the two agree:
- Concatenate along dim n and perform a plain matmul without trans_b:
qkv_w = torch.cat([q_w, k_w, v_w], dim=1) ##[k,3n]
qkv_out = torch.matmul(qkv, qkv_w) ##[m,3n]
- Stack along a new batch dim and perform a batched GEMM that broadcasts qkv over the batch dim:
qkv_w = torch.stack([q_w, k_w, v_w], dim=0) ##[3,k,n]
qkv_out = torch.matmul(qkv, qkv_w) ##[3,m,n]
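(A quick sanity check of the above, with made-up sizes and plain PyTorch: method (2)'s [3,m,n] output is just method (1)'s [m,3n] output with the q/k/v axis moved to the front.)
import torch

m, k, n = 2, 4, 3
q_w, k_w, v_w = (torch.randn(k, n) for _ in range(3))
qkv = torch.randn(m, k)

out1 = torch.matmul(qkv, torch.cat([q_w, k_w, v_w], dim=1))   ## [m,3n]
out2 = torch.matmul(qkv, torch.stack([q_w, k_w, v_w], dim=0)) ## [3,m,n]

## splitting out1 along dim n recovers the same q/k/v projections as out2
assert torch.allclose(out1.view(m, 3, n).permute(1, 0, 2), out2)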
But in lightseq, the compute logic seems to be the following:
qkv_w = torch.cat([q_w, k_w, v_w], dim=0) ##[3k,n]
qkv_out = cublas_gemm(qkv, qkv_w, trans_b=True) ##[m, 3n]
It seems lightseq wants to perform [m,k] * [3n,k] (trans_b=True) -> [m,3n], but in fact qkv_w's shape is [3k,n].
Correct me if I'm wrong. Thanks!
In fact, the calculation lightseq performs is method (1) as you describe it. Specifically, qkv_w in lightseq is a 1-D array laid out as [q_1 k_1 v_1 ... q_i k_i v_i ...], where x_i is the i-th row of q/k/v; read row-major as a [k,3n] matrix, this is exactly the method-(1) weight. One thing needs attention: cublas_gemm reads matrices in column-major order (https://en.wikipedia.org/wiki/Row-_and_column-major_order), so cuBLAS sees that same buffer as its [3n,k] transpose, which is why trans_b=True appears in the call.
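A minimal PyTorch sketch (made-up sizes, not the actual lightseq kernel) that checks both points: the interleaved 1-D buffer is exactly the method-(1) weight when read row-major, and the same memory read column-major is its [3n,k] transpose.
import torch

m, k, n = 2, 4, 3
q_w, k_w, v_w = (torch.randn(k, n) for _ in range(3))
qkv = torch.randn(m, k)  ## input, shape [m,k]

ref_w = torch.cat([q_w, k_w, v_w], dim=1)  ## method (1) weight, [k,3n]
ref_out = torch.matmul(qkv, ref_w)         ## [m,3n]

## lightseq-style 1-D buffer: [q_1 k_1 v_1 q_2 k_2 v_2 ...]
flat = torch.stack([q_w, k_w, v_w], dim=1).reshape(-1)  ## [3*k*n]

## read row-major as [k,3n], the buffer IS the method-(1) weight
assert torch.equal(flat.view(k, 3 * n), ref_w)

## the same memory read column-major is the [3n,k] transpose, which is what
## cuBLAS "sees"; trans_b=True transposes it back to the intended [k,3n]
b_as_cublas_sees_it = flat.view(k, 3 * n).t()        ## [3n,k]
out = torch.matmul(qkv, b_as_cublas_sees_it.t())     ## [m,3n]
assert torch.allclose(out, ref_out)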