Triton softmax: support multi-batch
- Support another batch dimension for softmax. In training or batch inference, we may add a batch dimension as the first dimension of some tensors. However, we currently use the third dimension (`tensor.shape[2]`) as the `head_dim`, which would be affected once that extra dimension is added. In this PR, I change it to `tensor.shape[-3]`, so the index is counted from the end (see the sketch after this list). The CUDA kernel is modified as well.
- Enable `test_atten_core`. This test is skipped by default and has never been used; I found it fails once the batch dimension becomes 2.
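
A minimal illustration of why counting the axis from the end helps. The tensor layout below is only an assumption for the example, not necessarily the one used by the softmax wrapper:

```python
import torch

# Hypothetical layouts, only to illustrate the indexing change; the real
# layout used by the softmax wrapper may differ.
scores = torch.randn(8, 128, 128)                  # (heads, seq_q, seq_k), no batch
batched = scores.unsqueeze(0).repeat(2, 1, 1, 1)   # (batch, heads, seq_q, seq_k)

# A fixed positive index changes meaning once a batch dimension is prepended,
# while an index counted from the end keeps pointing at the same logical axis:
assert scores.shape[-3] == 8    # heads axis without the batch dimension
assert batched.shape[-3] == 8   # still the heads axis with the batch dimension
```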
Update:
It also worked incorrectly before this commit, because the way `bias_ptr` is computed does not support multi-batch. I have fixed the Triton version (a rough sketch of the idea follows below). The CUDA version may support multi-batch one day :(
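
For illustration only, here is a minimal sketch of what "including the batch in the bias offset" can look like in a Triton softmax kernel. Everything in it (the kernel name `_softmax_bias_kernel`, the bias layout shared across heads, the `softmax_with_bias` wrapper) is a hypothetical example, not the actual code in this repository:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _softmax_bias_kernel(
    out_ptr, inp_ptr, bias_ptr,
    heads, seq_q, n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # One program instance handles one row of the (batch * heads * seq_q, seq_k)
    # view of the contiguous input.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    x = tl.load(inp_ptr + row * n_cols + cols, mask=mask, other=-float("inf"))

    # Recover (batch, m) from the flat row index. The `batch * seq_q` term is
    # exactly what a single-batch bias offset leaves out.
    batch = row // (heads * seq_q)
    m = row % seq_q
    bias = tl.load(bias_ptr + (batch * seq_q + m) * n_cols + cols, mask=mask, other=0.0)

    x = x + bias
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    tl.store(out_ptr + row * n_cols + cols, num / tl.sum(num, axis=0), mask=mask)


def softmax_with_bias(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Hypothetical layouts: inp is (batch, heads, seq_q, seq_k) and bias is
    # (batch, seq_q, seq_k), shared across heads; both contiguous.
    B, H, M, N = inp.shape
    assert bias.shape == (B, M, N) and inp.is_contiguous() and bias.is_contiguous()
    out = torch.empty_like(inp)
    grid = (B * H * M,)
    _softmax_bias_kernel[grid](
        out, inp, bias, H, M, N, BLOCK_SIZE=triton.next_power_of_2(N)
    )
    return out
```

The only point of the sketch is the `batch * seq_q` term in the bias offset: an offset derived from the row index within a single batch would silently reuse the first batch's bias for every batch, which is the kind of failure the enabled test surfaces once the batch dimension becomes 2.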