composable_kernel
WIP: Attn bwd prototype 2 (reload QKV)
A bare-minimum batched multihead attention backward kernel. Several pieces of functionality are still missing (a naive reference of the targeted computation is sketched after this list):
- ~alpha(QK) scaling~ implemented
- masking
- dropout
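For orientation, the computation being differentiated is plain scaled-dot-product attention. The sketch below is a naive single-batch, single-head fp32 host reference of the backward pass (my own illustration, not the device kernel), with masking and dropout omitted just like in the current prototype:

```cpp
// Naive single-batch, single-head fp32 host reference for the computation this
// kernel targets (a sketch for orientation only, NOT the CK device kernel).
// Q: M x K, K_mat: N x K, V: N x O, dY: M x O; all row-major, no mask/dropout.
#include <algorithm>
#include <cmath>
#include <vector>

void attention_backward_reference(int M, int N, int K, int O,
                                  const std::vector<float>& Q,
                                  const std::vector<float>& K_mat,
                                  const std::vector<float>& V,
                                  const std::vector<float>& dY,
                                  std::vector<float>& dQ,  // M x K
                                  std::vector<float>& dK,  // N x K
                                  std::vector<float>& dV)  // N x O
{
    const float alpha = 1.f / std::sqrt(static_cast<float>(K));

    // Recompute P = softmax(alpha * Q * K^T) row by row.
    // (The device kernel reloads Q/K/V tiles and reuses the saved LSE instead.)
    std::vector<float> P(M * N);
    for(int i = 0; i < M; ++i)
    {
        float row_max = -INFINITY;
        for(int j = 0; j < N; ++j)
        {
            float s = 0.f;
            for(int k = 0; k < K; ++k)
                s += Q[i * K + k] * K_mat[j * K + k];
            P[i * N + j] = alpha * s;
            row_max      = std::max(row_max, P[i * N + j]);
        }
        float row_sum = 0.f;
        for(int j = 0; j < N; ++j)
        {
            P[i * N + j] = std::exp(P[i * N + j] - row_max);
            row_sum += P[i * N + j];
        }
        for(int j = 0; j < N; ++j)
            P[i * N + j] /= row_sum;
    }

    // dV = P^T * dY
    dV.assign(N * O, 0.f);
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
            for(int o = 0; o < O; ++o)
                dV[j * O + o] += P[i * N + j] * dY[i * O + o];

    // dS = P o (dP - rowsum(dP o P)), where dP = dY * V^T
    std::vector<float> dS(M * N);
    for(int i = 0; i < M; ++i)
    {
        float row_dot = 0.f; // sum_j dP_ij * P_ij
        for(int j = 0; j < N; ++j)
        {
            float dp = 0.f;
            for(int o = 0; o < O; ++o)
                dp += dY[i * O + o] * V[j * O + o];
            dS[i * N + j] = dp;
            row_dot += dp * P[i * N + j];
        }
        for(int j = 0; j < N; ++j)
            dS[i * N + j] = P[i * N + j] * (dS[i * N + j] - row_dot);
    }

    // dQ = alpha * dS * K,  dK = alpha * dS^T * Q
    dQ.assign(M * K, 0.f);
    dK.assign(N * K, 0.f);
    for(int i = 0; i < M; ++i)
        for(int j = 0; j < N; ++j)
            for(int k = 0; k < K; ++k)
            {
                dQ[i * K + k] += alpha * dS[i * N + j] * K_mat[j * K + k];
                dK[j * K + k] += alpha * dS[i * N + j] * Q[i * K + k];
            }
}
```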
There are also some quirks that need to be ironed out, e.g.:
- The A/B0/B1/C tensor names sometimes refer to the Q/K/V/Y tensors
- The currently exposed tuning parameters are the same as for attention fwd; it is not clear yet what we should expose given the added complexity
- Some sizes / init methods can report validation failures; it is not yet clear whether this is a bug or fp16 quantization error
- Higher-than-expected register spills; given the 128x128x32 tile size, the initial estimate was 192 accumulator VGPRs plus some auxiliary VGPRs for other uses, but the actual usage exceeds the 256-VGPR budget and spills quite a lot into global memory (see the rough accounting below)
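A rough back-of-the-envelope behind that estimate (my own accounting, assuming 256 threads per workgroup as in the launch config below and three live 128x128 fp32 accumulator tiles):

```
128 * 128 fp32 accumulator tile over 256 lanes = 64 VGPRs per lane
3 accumulator tiles * 64 VGPRs                 = 192 VGPRs
256 VGPR budget - 192                          = 64 VGPRs left for addresses,
                                                 loop state and everything else
```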
Example output
$ CXX=hipcc cmake . -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/opt/rocm -DAMDGPU_TARGETS=gfx90a
$ cmake --build build -t example_batched_multihead_attention_backward_fp16
$ build/bin/example_batched_multihead_attention_backward_fp16
q_gs_ms_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
k_gs_ns_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
v_gs_os_ns: dim 4, lengths {3, 2, 128, 512}, strides {131072, 65536, 1, 128}
y_gs_ms_os: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
lse_gs_ms_os: dim 3, lengths {3, 2, 512}, strides {1024, 512, 1}
launch_and_time_kernel: grid_dim {24, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.365166 ms, 5.51328 TFlops, 17.2627 GB/s, DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 64, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
Checking qgrad:
Checking kgrad:
Checking vgrad:
pass
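The printed lengths decode to batch 3, heads 2, sequence length 512 and head dimension 128. As a sanity check on the reported TFlops (my own accounting, assuming the driver counts five GEMM-equivalents per batch-head for the backward pass: the S = QK^T recompute plus the dV, dP, dQ and dK GEMMs):

```
flops = 5 * (2 * 512 * 512 * 128) * (3 * 2) = 2.013e9
2.013e9 / 0.365166 ms = 5.513e12 flop/s ~ 5.51 TFlops   (matches the log)
```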