IMA (illegal memory access) with split-K kernel
Summary
For more details, see this PyTorch issue: https://github.com/pytorch/pytorch/issues/131257

I was able to reproduce the IMA on commit 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c with the following script:
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func


def test_flash_attention():
    device = "cuda"
    dtype = torch.bfloat16

    batch_size = 1
    n_heads = 1
    seq_len_q = 1
    seq_len_k = 257
    head_dim = 32

    is_causal = False
    dropout_p = 0.0
    scale = 1 / head_dim

    # Create input tensors
    query = torch.rand(
        batch_size,
        n_heads,
        seq_len_q,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )
    key = torch.rand(
        batch_size,
        n_heads,
        seq_len_k,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )
    value = torch.rand(
        batch_size,
        n_heads,
        seq_len_k,
        head_dim,
        device=device,
        dtype=dtype,
        requires_grad=True,
    )

    out_flash = flash_attn_func(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
    )
    out_flash = out_flash.transpose(1, 2)

    # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    #     out = F.scaled_dot_product_attention(
    #         query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
    #     )

    print("Output flash shape: ", out_flash.shape)
    # print("Output shape:", out.shape)
    # print("Output sum:", out.sum().item())


if __name__ == "__main__":
    test_flash_attention()
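Save the script above as ima.py and run it under compute-sanitizer; the memcheck report is written to ima.txt: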
TORCH_DISABLE_ADDR2LINE=1 PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --tool memcheck --log-file ima.txt python ima.py
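PYTORCH_NO_CUDA_MEMORY_CACHING=1 disables PyTorch's CUDA caching allocator so that compute-sanitizer sees each device allocation individually and can attribute the out-of-bounds access precisely; TORCH_DISABLE_ADDR2LINE=1 should just skip addr2line-based symbolization so the sanitized run is not slowed down by symbol resolution.

Once the IMA is fixed, a numerical cross-check along the lines of the commented-out SDPA call could look like the sketch below. This is my addition, not part of the original report; the helper name compare_with_sdpa and the tolerances are assumptions. Both paths use the default 1/sqrt(head_dim) softmax scaling, so the outputs should agree up to bf16 precision.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func


def compare_with_sdpa(query, key, value, atol=1e-2, rtol=1e-2):
    # query/key/value layout: (batch, n_heads, seq_len, head_dim), as in the repro above.
    # flash_attn_func expects (batch, seq_len, n_heads, head_dim), hence the transposes.
    out_flash = flash_attn_func(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    ).transpose(1, 2)

    # Reference result from the math backend of scaled_dot_product_attention.
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out_ref = F.scaled_dot_product_attention(query, key, value)

    torch.testing.assert_close(out_flash, out_ref, atol=atol, rtol=rtol)
    return out_flash, out_ref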