
I compared the flash_attn_func with torch.nn.functional.scaled_dot_product_attention and found that the results were not as expected. The scaled_dot_product_attention was actually faster.

Harry040 opened this issue on Jul 22, 2024 · 8 comments

I compared the flash_attn_func with torch.nn.functional.scaled_dot_product_attention and found that the results were not as expected. The scaled_dot_product_attention was actually faster.

r = []
for _ in range(100000):
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    t1 = time.time()
    out, _ = flash_attn_func(q, k, v)
    t2 = time.time()
    r.append(t2 - t1)
print("max", max(r))  # 
print("mean", sum(r)/len(r)) #  result is 2.6926e-05



r = []
for _ in range(100000):
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    q, k, v = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (q, k, v))
    t1 = time.time()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    t2 = time.time()
    r.append(t2 - t1)
print("max", max(r))  # 
print("mean", sum(r)/len(r)) #  result is 1.848e-05

Harry040 · Jul 22, 2024

complete code:

def test():
    import torch
    import time
    import einops
    from flash_attn_interface import flash_attn_func
    device = "cuda"
    dtype = torch.float16
    # set seed
    torch.random.manual_seed(0)
    batch_size = 9
    nheads = 48
    nheads_kv = 48
    d = 64
    seqlen_q = 256
    seqlen_k = 256 
    r = []
    for _ in range(100000):
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)

        t1 = time.time()
        out, lse = flash_attn_func(q, k, v)
        t2 = time.time()
        #q, k, v = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (q, k, v))
        #out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # out = einops.rearrange(out, 'B H S D -> B S H D')
        r.append(t2 - t1)
    
    print("max", max(r))
    print("mean", sum(r)/len(r))
    return

Harry040 · Jul 22, 2024

Please look at existing issues on numerical error. The right thing to compare is (flashattn in fp16 - reference attn in fp32) vs (reference attn in fp16 - reference attn in fp32).
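
For illustration, here is a minimal sketch of that comparison. The naive ref_attn below and the shapes are assumptions based on the snippets in this thread, not part of the original comment:

import torch
from flash_attn_interface import flash_attn_func

def ref_attn(q, k, v):
    # naive reference attention; expects (B, H, S, D)
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)

batch_size, seqlen, nheads, d = 9, 256, 48, 64
q, k, v = (torch.randn(batch_size, seqlen, nheads, d, device="cuda", dtype=torch.float16)
           for _ in range(3))

out_fa, _ = flash_attn_func(q, k, v)                 # fp16 FlashAttention, (B, S, H, D) layout
qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))  # (B, H, S, D) for the reference
out_ref_fp32 = ref_attn(qt.float(), kt.float(), vt.float())
out_ref_fp16 = ref_attn(qt, kt, vt)

err_fa = (out_fa.transpose(1, 2).float() - out_ref_fp32).abs().max()
err_ref = (out_ref_fp16.float() - out_ref_fp32).abs().max()
print(err_fa.item(), err_ref.item())  # err_fa is expected to be in the same ballpark as err_ref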

tridao · Jul 22, 2024

Please look at existing issues on numerical error. The right thing to compare is (flashattn in fp16 - reference attn in fp32) vs (reference attn in fp16 - reference attn in fp32).

Thanks for the reply. What should I do if I want to know which one is faster?

Harry040 · Jul 22, 2024

https://pytorch.org/tutorials/recipes/recipes/benchmark.html
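
For illustration, a minimal torch.utils.benchmark.Timer setup along the lines of that tutorial might look like this (shapes copied from the snippets above; not part of the original comment):

import torch
import torch.utils.benchmark as benchmark
from flash_attn_interface import flash_attn_func

device, dtype = "cuda", torch.float16
batch_size, seqlen, nheads, d = 9, 256, 48, 64
q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

# Timer takes care of CUDA synchronization around the timed statement,
# unlike wrapping an asynchronous kernel launch with time.time().
t = benchmark.Timer(
    stmt="flash_attn_func(q, k, v)",
    globals={"flash_attn_func": flash_attn_func, "q": q, "k": k, "v": v},
)
print(t.timeit(1000))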

tridao · Jul 22, 2024

The PyTorch benchmark shows that scaled_dot_product_attention is faster.

torch.nn.functional.scaled_dot_product_attention is 40.81 us

<torch.utils.benchmark.utils.common.Measurement object at 0x7f8d0d767950> torch_fattn(q, k, v) setup: from main import torch_fattn device = "cuda" dtype = torch.float16 torch.random.manual_seed(0) batch_size = 9 nheads = 48 nheads_kv = 48 d = 64 seqlen_q = 256 seqlen_k = 256

q = torch.randn(batch_size, nheads, seqlen_q, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, nheads, seqlen_k, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, nheads, seqlen_k, d, device=device, dtype=dtype, requires_grad=True)

40.81 us 1 measurement, 50000 runs , 1 thread

flash_attn3 is 43.72 us

<torch.utils.benchmark.utils.common.Measurement object at 0x7f8d0d9669d0> fattn3(q, k, v) setup: from main import fattn3 device = "cuda" dtype = torch.float16 torch.random.manual_seed(0) batch_size = 9 nheads = 48 nheads_kv = 48 #if mha_type == "mha" else (2 if mha_type == "gqa" else 1) d = 64 seqlen_q = 256 seqlen_k = 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) 43.72 us 1 measurement, 50000 runs , 1 thread

Harry040 · Jul 25, 2024


test code:

def fattn3(q, k, v):
    out, lse = flash_attn_func(q, k, v)
    return out

def torch_fattn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

Harry040 · Jul 25, 2024

I think you should synchronize, since the GPU and CPU are not synchronized by default. Something like:

sync() # wait for all other kernels to finish
start timing
run
sync() # wait for finishing the execution of the current kernel
end timing
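
In PyTorch terms, a sketch of that pattern (assuming q, k, v and flash_attn_func are set up as in the earlier snippets) could look like:

import time
import torch

torch.cuda.synchronize()           # wait for previously queued kernels to finish
t1 = time.time()
out, _ = flash_attn_func(q, k, v)  # kernel launch itself is asynchronous
torch.cuda.synchronize()           # wait for this kernel to actually finish
t2 = time.time()
print(t2 - t1)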

jiuzhengWang · Jul 26, 2024

I think you should synchronize, since the GPU and CPU are not synchronized by default. Something like:

sync() # wait for all other kernels to finish
start timing
run
sync() # wait for finishing the execution of the current kernel
end timing

Thanks for your reply.

The PyTorch benchmark handles synchronizing CUDA devices: https://pytorch.org/tutorials/recipes/recipes/benchmark.html

Harry040 · Aug 21, 2024