nanoGPT
Implements torch SDPA for FlashAttention Kernel
Implements torch SDPA for mem_efficient kernel support!
Using the mem_efficient kernel results in ~15.5% faster training time per batch, going from a ~154 ms/batch baseline to ~130 ms/batch. (Run on 8x NVIDIA A100 SXM4 80GB (GA100).)
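For reference, a minimal sketch of how per-batch time can be measured; `model`, `optimizer`, `X`, and `Y` are placeholders standing in for nanoGPT's training-loop objects, not code from this PR:

```python
import time
import torch

def time_one_batch(model, optimizer, X, Y, n_warmup=5, n_iters=20):
    """Rough ms-per-batch estimate: average forward + backward + step wall time."""
    def step():
        _, loss = model(X, Y)          # assumes a nanoGPT-style (logits, loss) return
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    for _ in range(n_warmup):          # warm up kernels / compilation before timing
        step()
    torch.cuda.synchronize()           # flush queued GPU work before starting the clock
    t0 = time.time()
    for _ in range(n_iters):
        step()
    torch.cuda.synchronize()           # wait for all GPU work before stopping the clock
    return (time.time() - t0) / n_iters * 1000.0
```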
Potentially also improves overflow protection, since the mem_efficient kernel scales the q and k matrices before multiplying them.
Just FYI, since dropout is already 0.0 in this branch: setting dropout to 0.0 is currently necessary for fused kernel support, but this will no longer be needed once https://github.com/pytorch/pytorch/pull/92917 lands.
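For context, here is roughly what the SDPA path looks like in a nanoGPT-style CausalSelfAttention.forward. This is a minimal sketch following nanoGPT's naming (c_attn, c_proj, n_head, n_embd), not the exact diff merged in this PR:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd=768, n_head=12):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head, self.n_embd = n_head, n_embd
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)      # output projection

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # one fused call replaces the manual matmul / mask / softmax / matmul chain;
        # dropout_p must be 0.0 for the fused kernels, as noted above
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                           dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble the heads
        return self.c_proj(y)
```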
great!! running some benchmarking and adjusting the code slightly...
(Also, I think this code is wrong because the line
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
was not deleted, but it's okay, I'll adjust.)
I'm seeing slightly different results:
(y-y2).abs().max()
0.0078
slightly unsettling. Any idea where this is from?
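For anyone wanting to reproduce this kind of check, a sketch comparing the manual attention path against SDPA on random inputs; the shapes and the bfloat16 dtype are assumptions matching nanoGPT's small-config defaults, not the exact code used here:

```python
import math
import torch
from torch.nn import functional as F

B, nh, T, hs = 2, 12, 1024, 64
q, k, v = (torch.randn(B, nh, T, hs, device='cuda', dtype=torch.bfloat16) for _ in range(3))
mask = torch.tril(torch.ones(T, T, device='cuda')).view(1, 1, T, T)

# manual (math) attention, as in the pre-SDPA forward pass
att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v

# fused SDPA
y2 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# a small nonzero difference is expected: the two paths round and accumulate differently
print((y - y2).abs().max())
```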
I also get a lot of really scary warnings from torch.compile ...
Heads up I merged a slight modification in this commit: https://github.com/karpathy/nanoGPT/commit/ae06d0b15a9111cbe2ce66b0f1be9ae29c1ecbbe
Let me know if you have any comments.
I also get a lot of really scary warnings from torch.compile ...
If the torch.compile warnings look like:
[2023-01-30 23:39:16,738] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default
then this is okay.
For aten._scaled_dot_product_flash_attention.default and aten._scaled_dot_product_efficient_attention.default we actually don't want inductor to generate any Triton code for these two kernels, because, at least as of right now, it would not be as performant. Hence the warning is saying that it is falling back to the 'eager' implementation, which in this particular case is fast.
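If you want to confirm which fused kernel SDPA dispatches to, the nightlies from this period expose a context manager for restricting the allowed backends. A sketch, assuming `torch.backends.cuda.sdp_kernel` is available in your build (later releases move this to `torch.nn.attention.sdpa_kernel`):

```python
import torch
from torch.nn import functional as F

q, k, v = (torch.randn(2, 12, 1024, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))

# allow only the mem_efficient backend; SDPA raises if it cannot run these inputs with it
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=False,
                                    enable_mem_efficient=True):
    y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
```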
I'm seeing slightly different results:
(y-y2).abs().max()
0.0078
slightly unsettling. Any idea where this is from?
Ah, yeah, I also noticed some similar discrepancies. My best guess is that this is due to float overflows and differences in rounding, etc. There's a more detailed explanation of what I think may be going on in the 'Overflow protection' section here: https://dev-discuss.pytorch.org/t/performance-gains-w-nanogpt-using-sdpa-custom-kernel/1015#overflow-protection-8
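To make the 'scales q & k before multiplying' point concrete, a toy sketch (not the kernel's actual code): splitting the 1/sqrt(hs) factor across q and k is mathematically the same, but it keeps the q @ k^T intermediate smaller, which matters most in float16 where the representable range is narrow:

```python
import math
import torch

hs = 64
# deliberately large float16 activations to make the effect visible
q = torch.randn(1, 1, 1024, hs, device='cuda', dtype=torch.float16) * 50
k = torch.randn(1, 1, 1024, hs, device='cuda', dtype=torch.float16) * 50

# scale after the matmul: the raw q @ k^T values can already overflow to inf in fp16
att_after = (q @ k.transpose(-2, -1)) / math.sqrt(hs)

# scale q and k first (each by hs ** -0.25): same math, smaller intermediates
s = hs ** -0.25
att_before = (q * s) @ (k * s).transpose(-2, -1)

# typically the first reports inf entries while the second stays finite
print(torch.isinf(att_after).any(), torch.isinf(att_before).any())
```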
@drisspg it's much worse than that. Just running train.py prints:
compiling the model... (takes a ~minute)
[2023-01-30 23:47:24,269] torch._inductor.graph: [WARNING] Creating implicit fallback for:
target: aten._scaled_dot_product_efficient_attention.default
args[0]: TensorBox(
PermuteView(data=View(
ReinterpretView(
StorageBox(
ExternKernelOut(
name=buf5,
layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
inputs=[ReinterpretView(
StorageBox(
ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg76_1, i1 + 1024 * i0)
tmp1 = load(arg25_1, i2 + 768 * (tmp0))
tmp2 = index_expr(i1, torch.int64)
tmp3 = load(arg26_1, i2 + 768 * (tmp2))
tmp4 = tmp1 + tmp3
tmp5 = load(buf2, i1 + 1024 * i0)
tmp6 = index_expr(768, torch.float32)
tmp7 = tmp5 / tmp6
tmp8 = tmp4 - tmp7
tmp9 = load(buf1, i1 + 1024 * i0)
tmp10 = index_expr(768, torch.float32)
tmp11 = tmp9 / tmp10
tmp12 = constant(1e-05, torch.float32)
tmp13 = tmp11 + tmp12
tmp14 = rsqrt(tmp13)
tmp15 = tmp8 * tmp14
tmp16 = load(arg0_1, i2)
tmp17 = tmp15 * tmp16
tmp18 = to_dtype(tmp17, torch.bfloat16)
return tmp18
,
ranges=[12, 1024, 768],
origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
))
),
FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
no origins?
), ReinterpretView(
StorageBox(
ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg27_1, i1 + 768 * i0)
tmp1 = to_dtype(tmp0, torch.bfloat16)
return tmp1
,
ranges=[2304, 768],
origins={convert_element_type, arg27_1}
))
),
FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
no origins?
)],
constant_args=(),
kwargs={},
output_view=None,
origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
)
),
FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1]),
no origins?
),
size=(12, 1024, 12, 64),
reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, view_2, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, split}
), dims=[0, 2, 1, 3])
)
args[1]: TensorBox(
PermuteView(data=View(
ReinterpretView(
StorageBox(
ExternKernelOut(
name=buf5,
layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
inputs=[ReinterpretView(
StorageBox(
ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg76_1, i1 + 1024 * i0)
tmp1 = load(arg25_1, i2 + 768 * (tmp0))
tmp2 = index_expr(i1, torch.int64)
tmp3 = load(arg26_1, i2 + 768 * (tmp2))
tmp4 = tmp1 + tmp3
tmp5 = load(buf2, i1 + 1024 * i0)
tmp6 = index_expr(768, torch.float32)
tmp7 = tmp5 / tmp6
tmp8 = tmp4 - tmp7
tmp9 = load(buf1, i1 + 1024 * i0)
tmp10 = index_expr(768, torch.float32)
tmp11 = tmp9 / tmp10
tmp12 = constant(1e-05, torch.float32)
tmp13 = tmp11 + tmp12
tmp14 = rsqrt(tmp13)
tmp15 = tmp8 * tmp14
tmp16 = load(arg0_1, i2)
tmp17 = tmp15 * tmp16
tmp18 = to_dtype(tmp17, torch.bfloat16)
return tmp18
,
ranges=[12, 1024, 768],
origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
))
),
FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
no origins?
), ReinterpretView(
StorageBox(
ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg27_1, i1 + 768 * i0)
tmp1 = to_dtype(tmp0, torch.bfloat16)
return tmp1
,
ranges=[2304, 768],
origins={convert_element_type, arg27_1}
))
),
FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
no origins?
)],
constant_args=(),
kwargs={},
output_view=None,
origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
)
),
FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1], offset=768),
no origins?
),
size=(12, 1024, 12, 64),
reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
origins={view_1, mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, split}
), dims=[0, 2, 1, 3])
)
args[2]: TensorBox(
PermuteView(data=View(
ReinterpretView(
StorageBox(
ExternKernelOut(
name=buf5,
layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
inputs=[ReinterpretView(
StorageBox(
ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg76_1, i1 + 1024 * i0)
tmp1 = load(arg25_1, i2 + 768 * (tmp0))
tmp2 = index_expr(i1, torch.int64)
tmp3 = load(arg26_1, i2 + 768 * (tmp2))
tmp4 = tmp1 + tmp3
tmp5 = load(buf2, i1 + 1024 * i0)
tmp6 = index_expr(768, torch.float32)
tmp7 = tmp5 / tmp6
tmp8 = tmp4 - tmp7
tmp9 = load(buf1, i1 + 1024 * i0)
tmp10 = index_expr(768, torch.float32)
tmp11 = tmp9 / tmp10
tmp12 = constant(1e-05, torch.float32)
tmp13 = tmp11 + tmp12
tmp14 = rsqrt(tmp13)
tmp15 = tmp8 * tmp14
tmp16 = load(arg0_1, i2)
tmp17 = tmp15 * tmp16
tmp18 = to_dtype(tmp17, torch.bfloat16)
return tmp18
,
ranges=[12, 1024, 768],
origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
))
),
FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
no origins?
), ReinterpretView(
StorageBox(
ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(arg27_1, i1 + 768 * i0)
tmp1 = to_dtype(tmp0, torch.bfloat16)
return tmp1
,
ranges=[2304, 768],
origins={convert_element_type, arg27_1}
))
),
FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
no origins?
)],
constant_args=(),
kwargs={},
output_view=None,
origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
)
),
FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1], offset=1536),
no origins?
),
size=(12, 1024, 12, 64),
reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, view_3, split}
), dims=[0, 2, 1, 3])
)
args[3]: False
args[4]: True
[2023-01-30 23:47:24,279] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default
step 0: train loss 10.9649, val loss 10.9658
[2023-01-30 23:47:32,725] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default
[2023-01-30 23:47:35,128] torch._inductor.graph: [WARNING] Creating implicit fallback for:
target: aten._scaled_dot_product_efficient_attention_backward.default
args[0]: TensorBox(
ReinterpretView(
StorageBox(
ExternKernelOut(
name=buf46,
layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 768], stride=[768, 1]),
inputs=[ComputedBuffer(name='buf44', layout=FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
tmp0 = load(buf43, i1 + 768 * ModularIndexing(i0, 1, 1024) + 786432 * ModularIndexing(i0, 1024, 12))
tmp1 = to_dtype(tmp0, torch.bfloat16)
return tmp1
,
ranges=(12288, 768),
origins={view_155}
)), InputBuffer(name='permute_111', layout=FixedLayout('cuda', torch.bfloat16, size=[768, 768], stride=[768, 1]))],
constant_args=(),
kwargs={},
output_view=None,
origins={mm_56, view_155, permute_111}
)
),
FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[786432, 64, 768, 1]),
no origins?
)
)
args[1]: TensorBox(StorageBox(
InputBuffer(name='permute_90', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='permute_89', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
))
args[3]: TensorBox(StorageBox(
InputBuffer(name='permute_91', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
))
args[4]: TensorBox(StorageBox(
InputBuffer(name='getitem_104', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[786432, 64, 768, 1]))
))
args[5]: TensorBox(StorageBox(
InputBuffer(name='getitem_105', layout=FixedLayout('cuda', torch.float32, size=[12, 12, 1024], stride=[12288, 1024, 1]))
))
args[6]: True
[2023-01-30 23:47:35,135] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention_backward.default
We should really lower the logging level of those messages to INFO. As @drisspg says, it's just saying that it's using the eager kernel, but it prints out all the associated inductor IR in the process of saying that.
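Until that happens, a user-side workaround is to turn the logger down with the standard logging module; this assumes inductor routes these messages through stdlib logging, which the log format above suggests:

```python
import logging

# hide inductor's "Creating implicit fallback" / "Using FallbackKernel" warnings
logging.getLogger("torch._inductor.graph").setLevel(logging.ERROR)
```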
This was merged now, so closing the issue.
For aten._scaled_dot_product_flash_attention.default and aten._scaled_dot_product_efficient_attention.default we actually don't want inductor to generate any Triton code for these two kernels because at least as of right now it would not be as performant
What's the current perf gap between the triton version and the aten version? Or can you directly map to triton's custom flash attention implementation?