bug: start_position support for the fused attention kernel

Open · ipoletaev opened this issue 2 years ago · 0 comments

Description

Using a start position index with the fused attention kernel does not work.

Steps to reproduce

import math

import torch

START_IDX = 128


def attention_reference(q: torch.Tensor, k: torch.Tensor,
                        v: torch.Tensor) -> torch.Tensor:
    # Causal mask shifted by START_IDX: query i may attend to keys up to index i + START_IDX.
    mask_y = torch.full((1, 1, q.size(2), q.size(2)), float("-inf"))
    mask_y = torch.triu(mask_y, diagonal=START_IDX + 1).float()
    att_y = (q @ k.transpose(-2, -1)) * scale
    att_y = att_y + mask_y.to(att_y)
    att_y = torch.nn.functional.softmax(att_y, dim=-1)
    return att_y @ v


q = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
k = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
v = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
scale = 1 / math.sqrt(128)

# triton_fa is the fused attention kernel under test; its import is omitted in this report.
x = triton_fa(q, k, v, scale, True, START_IDX)
y = attention_reference(q, k, v)
print(torch.max(torch.abs(x - y)))
print(torch.sum(x - y))
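
A quick sanity check like the following (a sketch reusing the reproducer's tensors and the same triton_fa call; it is not part of kernl's API) shows how much of the fused output is NaN:

x = triton_fa(q, k, v, scale, True, START_IDX)
print(torch.isnan(x).any())            # per the report, True whenever START_IDX != 0
print(torch.isnan(x).float().mean())   # fraction of NaN entries in the output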

Expected Behavior

The fused kernel's output should be nearly identical to the vanilla implementation for any start position index.

Actual Behavior

The fused kernel returns NaN for any START_IDX != 0.
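
For context (an assumption about the failure mode, not a confirmed diagnosis): NaNs commonly appear in attention when softmax is taken over a fully masked row, i.e. a row whose scores are all -inf, which can happen if the diagonal offset is applied incorrectly. A minimal PyTorch illustration:

import torch

row = torch.full((4,), float("-inf"))
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])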

Your environment

  • torch==2.0.0
  • triton==2.0.0

Self-service

  • [ ] I would be willing to help fix this bug myself.

Code of Conduct

  • [X] I agree to follow this project's Code of Conduct

ipoletaev · Jul 21, 2023