Question about qkv_weight initialization

Open kimbaol opened this issue 3 years ago • 1 comment

I'm confused about the qkv weight initialization in lightseq. If qkv_w is initialized from existing weights, there are two ways to initialize it (assume the input qkv has shape [m,k], and each of the q/k/v weights has shape [k,n]):

  1. Concatenate along dim n and run matmul without trans_b:
       qkv_w = torch.cat([q_w, k_w, v_w], dim=1)  ## [k,3n]
       qkv_out = torch.matmul(qkv, qkv_w)  ## [m,3n]
  2. Stack along a new leading dim and run a batched GEMM that broadcasts over the batch dim:
       qkv_w = torch.stack([q_w, k_w, v_w], dim=0)  ## [3,k,n]
       qkv_out = torch.matmul(qkv, qkv_w)  ## [3,m,n]
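For reference, the two initializations above can be checked against each other on a tiny example. This is only a pure-Python sketch so it runs without torch; the `matmul` helper, the sizes, and the weight values are made up for illustration and are not taken from lightseq.

```python
def matmul(a, b):
    """Multiply row-major nested lists: [m,k] x [k,n] -> [m,n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Illustrative sizes and values (assumptions, not lightseq data)
m, k, n = 2, 3, 2
x = [[1, 2, 3], [4, 5, 6]]        # input, shape [m,k]
q_w = [[1, 0], [0, 1], [1, 1]]    # each weight has shape [k,n]
k_w = [[2, 0], [0, 2], [2, 2]]
v_w = [[0, 1], [1, 0], [3, 3]]

# Method 1: concatenate along the column dim -> [k,3n], one matmul -> [m,3n]
qkv_w_cat = [q + kk + v for q, kk, v in zip(q_w, k_w, v_w)]
out_cat = matmul(x, qkv_w_cat)

# Method 2: stack -> [3,k,n], one matmul per stacked weight -> [3,m,n]
out_stack = [matmul(x, w) for w in (q_w, k_w, v_w)]

# Slicing the [m,3n] result into three [m,n] blocks recovers the stacked result
out_cat_split = [[row[i * n:(i + 1) * n] for row in out_cat] for i in range(3)]
```

Both methods carry the same numbers; they differ only in how the output is laid out ([m,3n] versus [3,m,n]).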

But in lightseq, the compute logic seems to be as follows:

  qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  ## [3k,n]
  qkv_out = cublas_gemm(qkv, qkv_w, trans_b=True)  ## [m,3n]

It seems lightseq wants to compute [m,k] * [3n,k] (trans_b=True) -> [m,3n], but in fact qkv_w's shape is [3k,n].

Correct me if I'm wrong. Thanks!

kimbaol avatar Aug 19 '21 10:08 kimbaol

In fact, lightseq computes this the way you describe in method (1). Specifically, qkv_w in lightseq is a 1-D array laid out as [q_1, k_1, v_1, ..., q_i, k_i, v_i, ...], where x_i is the i-th row of the q/k/v weight. One thing to keep in mind: cublas_gemm reads matrices in column-major order: https://en.wikipedia.org/wiki/Row-_and_column-major_order
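To make the layout point concrete, here is a small pure-Python sketch of how a row-major [k,3n] buffer looks to a column-major reader. It assumes the [k,n] weight convention from the question; the values and the `read_col_major` helper are hypothetical and this is not lightseq's actual code.

```python
# Illustrative sizes and values (assumptions, not lightseq data)
k, n = 3, 2
q_w = [[1, 2], [3, 4], [5, 6]]          # each weight has shape [k,n]
k_w = [[7, 8], [9, 10], [11, 12]]
v_w = [[13, 14], [15, 16], [17, 18]]

# Row i of the fused weight is q_i, k_i, v_i back to back -> shape [k,3n]
fused = [q + kk + v for q, kk, v in zip(q_w, k_w, v_w)]

# Row-major flattening gives the 1-D buffer [q_1, k_1, v_1, q_2, k_2, v_2, ...]
flat = [e for row in fused for e in row]

def read_col_major(buf, rows, cols):
    """Interpret buf as a column-major rows x cols matrix (hypothetical helper)."""
    return [[buf[c * rows + r] for c in range(cols)] for r in range(rows)]

# A column-major reader with 3n rows sees the same buffer as a [3n,k] matrix
as_seen = read_col_major(flat, 3 * n, k)

# That [3n,k] matrix is exactly the transpose of the row-major [k,3n] weight
transposed = [list(col) for col in zip(*fused)]
```

Under these assumptions, the same memory is [k,3n] to a row-major reader and [3n,k] to a column-major one, which is one way a transpose flag can appear in the GEMM call even though the weights were concatenated as in method (1).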

neopro12 avatar Aug 19 '21 13:08 neopro12