
WIP: Attn bwd prototype 2 (reload QKV)

Open rosenrodt opened this issue 3 years ago • 0 comments

Bare-minimum batched multi-head attention backward kernel. Several features are still missing:

  • ~alpha(QK) scaling~ implemented
  • masking
  • dropout

There are also some quirks that need to be ironed out, e.g.:

  • The A/B/B1/C tensor names sometimes refer to the Q/K/V/Y tensors
  • The tuning parameters currently exposed are the same as for attention fwd; not sure what we should expose now, given the added complexity
  • Some sizes / init methods can report a validation failure; not sure whether it is a bug or fp16 quantization error
  • Higher than expected register spills; given the 128x128x32 tile size, the initial estimate is 192 accumulator VGPRs plus some auxiliary VGPRs for other uses, but actual usage exceeds 256 VGPRs and spills quite a lot into global memory
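
For reference, here is one way the 192-VGPR estimate could be reproduced. This is only a sketch: the issue does not say how the figure breaks down, so the three live 128x128 fp32 accumulator tiles (`LIVE_TILES`) and the one-register-per-fp32-element layout are assumptions that happen to match the number.

```python
# Back-of-the-envelope VGPR estimate for the accumulators.
# Assumptions (not stated in the issue): fp32 accumulators, one 32-bit VGPR
# per element per thread, and three accumulator tiles live at the same time.

THREADS_PER_BLOCK = 256      # block_dim reported in the example output
TILE_M, TILE_N = 128, 128    # per-block tile size quoted in the issue
VGPR_BUDGET = 256            # per-thread budget quoted in the issue
LIVE_TILES = 3               # hypothetical: chosen because 3 * 64 == 192


def accumulator_vgprs(tile_rows, tile_cols, threads, bytes_per_elem=4):
    """VGPRs each thread needs to hold its share of one accumulator tile."""
    elems_per_thread = tile_rows * tile_cols // threads
    return elems_per_thread * bytes_per_elem // 4  # one VGPR holds 4 bytes


per_tile = accumulator_vgprs(TILE_M, TILE_N, THREADS_PER_BLOCK)
estimate = LIVE_TILES * per_tile
headroom = VGPR_BUDGET - estimate

print(f"VGPRs per accumulator tile : {per_tile}")
print(f"estimated accumulator VGPRs: {estimate}")
print(f"left for auxiliary uses    : {headroom}")
```

Under these assumptions only 64 VGPRs remain for addresses, loop state, and operand staging, which would make the observed spilling less surprising.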

Example output

$ CXX=hipcc cmake . -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/opt/rocm -DAMDGPU_TARGETS=gfx90a
$ cmake --build build -t example_batched_multihead_attention_backward_fp16
$ build/bin/example_batched_multihead_attention_backward_fp16
q_gs_ms_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
k_gs_ns_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
v_gs_os_ns: dim 4, lengths {3, 2, 128, 512}, strides {131072, 65536, 1, 128}
y_gs_ms_os: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
lse_gs_ms_os: dim 3, lengths {3, 2, 512}, strides {1024, 512, 1}
launch_and_time_kernel: grid_dim {24, 1, 1}, block_dim {256, 1, 1} 
Warm up 1 time
Start running 10 times...
Perf: 0.365166 ms, 5.51328 TFlops, 17.2627 GB/s, DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 64, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
Checking qgrad:
Checking kgrad:
Checking vgrad:
pass
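
The `Perf:` line can be cross-checked from the printed tensor shapes. The sketch below is inferred from the numbers rather than from the example's source: it assumes five GEMMs per batch (recomputing S = QK^T, then dP, dV, dQ, dK), eight fp16 tensors moved (Q, K, V, Y and their four gradient counterparts), and an fp32 LSE tensor. With those assumptions it matches the reported TFlops and GB/s.

```python
# Reproduce the reported throughput from the shapes in the example output.
# The GEMM count and the list of tensors counted as traffic are assumptions.

G0, G1 = 3, 2        # batch and head counts (first two lengths)
M = N = 512          # sequence lengths
K = O = 128          # head dimensions
AVG_MS = 0.365166    # average kernel time from the Perf line
TILE = 128           # tile size along the sequence dimension

batches = G0 * G1

# Three GEMMs contract or produce the K dimension (S = Q K^T, dQ, dK) and
# two involve O (dP = dY V^T, dV = P^T dY); each costs 2 * M * N * dim.
flop = batches * (3 * 2 * M * N * K + 2 * 2 * M * N * O)

# fp16 inputs Q, K, V, Y plus dY, dQ, dK, dV (2 bytes each), fp32 LSE.
fp16_elems = 2 * (M * K + N * K + N * O + M * O)
num_bytes = batches * (fp16_elems * 2 + M * 4)

tflops = flop / 1e9 / AVG_MS
gb_per_s = num_bytes / 1e6 / AVG_MS

# One workgroup per 128-long slice of the 512-long sequence, per batch.
grid_dim = batches * ((M + TILE - 1) // TILE)

print(f"{tflops:.5f} TFlops, {gb_per_s:.4f} GB/s, grid_dim {grid_dim}")
```

Because M and N are both 512 here, the output alone does not show which sequence dimension the grid is tiled over.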

rosenrodt, Dec 27 '22 11:12