bug: start_position support for the fused attention kernel
Description
Using a start position index with the fused attention kernel does not work.
Steps to reproduce
import math

import torch

START_IDX = 128

def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Causal mask shifted by START_IDX: query position i may attend to keys up to i + START_IDX.
    mask_y = torch.full((1, 1, q.size(2), q.size(2)), float("-inf"))
    mask_y = torch.triu(mask_y, diagonal=START_IDX + 1).float()
    att_y = (q @ k.transpose(-2, -1)) * scale
    att_y = att_y + mask_y.to(att_y)
    att_y = torch.nn.functional.softmax(att_y, dim=-1)
    return att_y @ v

q = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
k = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
v = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
scale = 1 / math.sqrt(128)

# triton_fa: the fused attention kernel under test; the last two arguments are the causal flag
# and the start position index.
x = triton_fa(q, k, v, scale, True, START_IDX)
y = attention_reference(q, k, v)

print(torch.max(torch.abs(x - y)))
print(torch.sum(x - y))
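For context, this is how the shifted causal mask behaves on a small example (illustration only, not part of the repro; the size and start_idx values are arbitrary): with diagonal=start_idx + 1, query row i can attend to every key j <= i + start_idx, and only strictly later keys are masked.

import torch

# Illustration: a 4x4 causal mask shifted by start_idx = 2.
start_idx = 2
mask = torch.full((4, 4), float("-inf"))
mask = torch.triu(mask, diagonal=start_idx + 1)
print(mask)
# tensor([[0., 0., 0., -inf],
#         [0., 0., 0.,   0.],
#         [0., 0., 0.,   0.],
#         [0., 0., 0.,   0.]])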
Expected Behavior
The output should be nearly identical to the reference (vanilla) implementation for any start position index.
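One way to phrase this expectation as a check (the tolerances here are illustrative, chosen loosely for fp16 accumulation differences, not taken from the kernl test suite):

# Illustrative tolerance check; atol/rtol values are assumptions, not project defaults.
torch.testing.assert_close(x, y, atol=1e-2, rtol=0)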
Actual Behavior
The kernel returns NaN for any START_IDX != 0.
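Note that the NaNs cannot come from the reference path: with diagonal=START_IDX + 1 every query row keeps at least its own key unmasked, so the reference softmax stays finite. A minimal check against the tensors from the repro above (illustration):

print(torch.isnan(y).any())  # False: the reference output stays finite for START_IDX = 128
print(torch.isnan(x).any())  # True in this report: the NaNs come from the fused kernel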
Your environment
torch==2.0.0
triton==2.0.0
Self-service
- [ ] I would be willing to help fix this bug myself.
Code of Conduct
- [X] I agree to follow this project's Code of Conduct