flash-attention
I compared flash_attn_func with torch.nn.functional.scaled_dot_product_attention and the results were not what I expected: scaled_dot_product_attention was actually faster.
r = []
for _ in range(100000):
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    t1 = time.time()
    out, _ = flash_attn_func(q, k, v)
    t2 = time.time()
    r.append(t2 - t1)
print("max", max(r))
print("mean", sum(r)/len(r))  # result is 2.6926e-05
r = []
for _ in range(100000):
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    q, k, v = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (q, k, v))
    t1 = time.time()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    t2 = time.time()
    r.append(t2 - t1)
print("max", max(r))
print("mean", sum(r)/len(r))  # result is 1.848e-05
complete code:
def test():
    import torch
    import time
    import einops
    from flash_attn_interface import flash_attn_func

    device = "cuda"
    dtype = torch.float16
    # set seed
    torch.random.manual_seed(0)
    batch_size = 9
    nheads = 48
    nheads_kv = 48
    d = 64
    seqlen_q = 256
    seqlen_k = 256

    r = []
    for _ in range(100000):
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
        t1 = time.time()
        out, lse = flash_attn_func(q, k, v)
        t2 = time.time()
        # q, k, v = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (q, k, v))
        # out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # out = einops.rearrange(out, 'B H S D -> B S H D')
        r.append(t2 - t1)
    print("max", max(r))
    print("mean", sum(r)/len(r))
    return
Please look at existing issues on numerical error. The right thing to compare is (flashattn in fp16 - reference attn in fp32) vs (reference attn in fp16 - reference attn in fp32).
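For concreteness, here is a minimal sketch of that comparison, assuming the same shapes as above. The ref_attn helper is a hypothetical plain-PyTorch reference, not part of this repo:

import math
import torch
from flash_attn_interface import flash_attn_func

def ref_attn(q, k, v):
    # plain attention on (B, S, H, D) inputs, computed per head
    q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # -> (B, H, S, D)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.transpose(1, 2)  # back to (B, S, H, D)

q = torch.randn(9, 256, 48, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out_flash, _ = flash_attn_func(q, k, v)
out_ref_fp32 = ref_attn(q.float(), k.float(), v.float())
out_ref_fp16 = ref_attn(q, k, v)

# FlashAttention's fp16 error vs an fp32 reference should be comparable to
# the error of simply running the reference attention itself in fp16.
print((out_flash.float() - out_ref_fp32).abs().max())
print((out_ref_fp16.float() - out_ref_fp32).abs().max())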
Thanks for the reply. How should I benchmark these if I want to know which one is faster?
https://pytorch.org/tutorials/recipes/recipes/benchmark.html
torch benchmark shows scaled_dot_product_attention is faster.
torch.nn.functional.scaled_dot_product_attention is 40.81 us
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8d0d767950>
torch_fattn(q, k, v)
setup:
  from main import torch_fattn
  device = "cuda"
  dtype = torch.float16
  torch.random.manual_seed(0)
  batch_size = 9
  nheads = 48
  nheads_kv = 48
  d = 64
  seqlen_q = 256
  seqlen_k = 256
  q = torch.randn(batch_size, nheads, seqlen_q, d, device=device, dtype=dtype, requires_grad=True)
  k = torch.randn(batch_size, nheads, seqlen_k, d, device=device, dtype=dtype, requires_grad=True)
  v = torch.randn(batch_size, nheads, seqlen_k, d, device=device, dtype=dtype, requires_grad=True)
  40.81 us
  1 measurement, 50000 runs, 1 thread
flash_attn3 is 43.72 us
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8d0d9669d0>
fattn3(q, k, v)
setup:
  from main import fattn3
  device = "cuda"
  dtype = torch.float16
  torch.random.manual_seed(0)
  batch_size = 9
  nheads = 48
  nheads_kv = 48  # if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  d = 64
  seqlen_q = 256
  seqlen_k = 256
  q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
  k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
  v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
  43.72 us
  1 measurement, 50000 runs, 1 thread
test code:
def fattn3(q, k, v):
    out, lse = flash_attn_func(q, k, v)
    return out

def torch_fattn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)
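The Measurement output above is consistent with a torch.utils.benchmark harness roughly like the following. This is a sketch reconstructed from the setup strings shown in the measurements; the exact call (timeit vs. blocked_autorange) and the 50000 run count are assumptions based on the printed output:

import torch
import torch.utils.benchmark as benchmark

setup = """
from main import fattn3
device = "cuda"
dtype = torch.float16
torch.random.manual_seed(0)
batch_size, nheads, nheads_kv, d = 9, 48, 48, 64
seqlen_q = seqlen_k = 256
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
"""

# benchmark.Timer handles CUDA synchronization around the timed statement,
# unlike the time.time() loop in the original script.
t = benchmark.Timer(stmt="fattn3(q, k, v)", setup=setup)
print(t.timeit(50000))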
I think you should synchronize, since the GPU and CPU do not run in lockstep by default. Something like:
sync() # wait for all other kernels to finish
start timing
run
sync() # wait for finishing the execution of the current kernel
end timing
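Applied to the original loop, that pattern would look roughly like this (a sketch; torch.cuda.synchronize() is the standard way to wait for pending kernels before reading the host clock):

import time
import torch
from flash_attn_interface import flash_attn_func

q = torch.randn(9, 256, 48, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

torch.cuda.synchronize()   # wait for all previously launched kernels to finish
t1 = time.time()
out, _ = flash_attn_func(q, k, v)
torch.cuda.synchronize()   # wait for the attention kernel itself to finish
t2 = time.time()
print(t2 - t1)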
Thanks for your reply.
The PyTorch benchmark module manages synchronizing CUDA devices: https://pytorch.org/tutorials/recipes/recipes/benchmark.html