nanoGPT

Implements torch SDPA for FlashAttention Kernel

LucasLLC opened this issue on Jan 30 '23 • 9 comments

Implements torch SDPA for mem_efficient kernel support!

Using the mem_efficient kernel gives a ~15.5% reduction in training time per batch, from a ~154 ms/batch baseline to ~130 ms/batch. (Run on 8 x NVIDIA A100 SXM4 80GB [GA100])

Potentially also improves overflow protection, since the mem_efficient kernel scales the q and k matrices before multiplying them.

Just FYI, since dropout is already 0.0 in this branch: setting dropout to 0.0 is currently necessary for kernel support, but this will no longer be required once https://github.com/pytorch/pytorch/pull/92917 lands.
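For context, the change essentially replaces nanoGPT's manual attention (masked q @ k^T, softmax, then att @ v) with a single fused call. Below is a minimal sketch of what that call looks like, with shapes and dtype assumed rather than taken from the PR:

import torch
import torch.nn.functional as F

B, nh, T, hs = 12, 12, 1024, 64  # batch, heads, sequence length, head size (assumed)
q, k, v = (torch.randn(B, nh, T, hs, device='cuda', dtype=torch.bfloat16) for _ in range(3))

# One fused kernel covers the (q @ k^T) / sqrt(hs) scaling, causal masking,
# softmax, dropout, and the final att @ v product. is_causal=True supplies the
# causal mask, and dropout_p=0.0 matches the dropout requirement noted above.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)  # (B, nh, T, hs)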

LucasLLC, Jan 30 '23

great!! running some benchmarking and adjusting the code slightly...

karpathy, Jan 30 '23

(also I think this code is wrong because the line

y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

was not deleted, but it's ok, I'll adjust.)

karpathy, Jan 30 '23

I'm seeing slightly different results:

(y-y2).abs().max()
0.0078

slightly unsettling. Any idea where this is from?

karpathy, Jan 30 '23

I also get a lot of really scary warnings from torch.compile ...

karpathy, Jan 30 '23

Heads up I merged a slight modification in this commit: https://github.com/karpathy/nanoGPT/commit/ae06d0b15a9111cbe2ce66b0f1be9ae29c1ecbbe

Let me know if you have any comments.

karpathy, Jan 30 '23

I also get a lot of really scary warnings from torch.compile ...

If the torch.compile errors look like [2023-01-30 23:39:16,738] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default, then this is okay.

For aten._scaled_dot_product_flash_attention.default and aten._scaled_dot_product_efficient_attention.default we actually don't want inductor to generate any Triton code for these two kernels, because at least as of right now it would not be as performant. Hence the warning is saying that it is falling back to the 'eager' implementations, which in this particular case are fast.
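As a side note, one way to check which SDPA backend actually runs, rather than relying on the inductor warnings, is to restrict the allowed kernels with the torch.backends.cuda.sdp_kernel context manager from the PyTorch 2.0 era. A rough sketch, not part of this PR:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(12, 12, 1024, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))

# Allow only the mem_efficient backend; if SDPA cannot dispatch to it for these
# inputs, it raises an error instead of silently using the math fallback.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)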

drisspg, Jan 30 '23

I'm seeing slightly different results:

(y-y2).abs().max()
0.0078

slightly unsettling. Any idea where this is from?

Ah, yeah, I also noticed some similar discrepancies. My best guess is that this is due to float overflows and some differences in rounding, etc. There's a more detailed explanation of what I think may be going on in the 'Overflow protection' section here: https://dev-discuss.pytorch.org/t/performance-gains-w-nanogpt-using-sdpa-custom-kernel/1015#overflow-protection-8
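For anyone who wants to reproduce the comparison, here is a rough sketch (shapes and dtype assumed, not taken from the PR). Since bfloat16 carries only about 8 significand bits, element-wise differences on the order of 1e-2 between two mathematically equivalent attention implementations are plausible:

import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 12, 256, 64
q, k, v = (torch.randn(B, nh, T, hs, device='cuda', dtype=torch.bfloat16) for _ in range(3))
mask = torch.tril(torch.ones(T, T, device='cuda')).view(1, 1, T, T)

# Manual attention, as in the pre-SDPA nanoGPT forward pass
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v

# Fused SDPA path
y2 = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print((y - y2).abs().max())  # small but nonzero in bfloat16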

LucasLLC, Jan 30 '23

@drisspg it's much worse than that. Just running train.py prints:

compiling the model... (takes a ~minute)
[2023-01-30 23:47:24,269] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten._scaled_dot_product_efficient_attention.default
  args[0]: TensorBox(
    PermuteView(data=View(
      ReinterpretView(
        StorageBox(
          ExternKernelOut(
            name=buf5,
            layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
            inputs=[ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg76_1, i1 + 1024 * i0)
                  tmp1 = load(arg25_1, i2 + 768 * (tmp0))
                  tmp2 = index_expr(i1, torch.int64)
                  tmp3 = load(arg26_1, i2 + 768 * (tmp2))
                  tmp4 = tmp1 + tmp3
                  tmp5 = load(buf2, i1 + 1024 * i0)
                  tmp6 = index_expr(768, torch.float32)
                  tmp7 = tmp5 / tmp6
                  tmp8 = tmp4 - tmp7
                  tmp9 = load(buf1, i1 + 1024 * i0)
                  tmp10 = index_expr(768, torch.float32)
                  tmp11 = tmp9 / tmp10
                  tmp12 = constant(1e-05, torch.float32)
                  tmp13 = tmp11 + tmp12
                  tmp14 = rsqrt(tmp13)
                  tmp15 = tmp8 * tmp14
                  tmp16 = load(arg0_1, i2)
                  tmp17 = tmp15 * tmp16
                  tmp18 = to_dtype(tmp17, torch.bfloat16)
                  return tmp18
                  ,
                  ranges=[12, 1024, 768],
                  origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
              no origins?
            ), ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg27_1, i1 + 768 * i0)
                  tmp1 = to_dtype(tmp0, torch.bfloat16)
                  return tmp1
                  ,
                  ranges=[2304, 768],
                  origins={convert_element_type, arg27_1}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
              no origins?
            )],
            constant_args=(),
            kwargs={},
            output_view=None,
            origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
          )
        ),
        FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1]),
        no origins?
      ),
      size=(12, 1024, 12, 64),
      reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
      origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, view_2, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, split}
    ), dims=[0, 2, 1, 3])
  )
  args[1]: TensorBox(
    PermuteView(data=View(
      ReinterpretView(
        StorageBox(
          ExternKernelOut(
            name=buf5,
            layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
            inputs=[ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg76_1, i1 + 1024 * i0)
                  tmp1 = load(arg25_1, i2 + 768 * (tmp0))
                  tmp2 = index_expr(i1, torch.int64)
                  tmp3 = load(arg26_1, i2 + 768 * (tmp2))
                  tmp4 = tmp1 + tmp3
                  tmp5 = load(buf2, i1 + 1024 * i0)
                  tmp6 = index_expr(768, torch.float32)
                  tmp7 = tmp5 / tmp6
                  tmp8 = tmp4 - tmp7
                  tmp9 = load(buf1, i1 + 1024 * i0)
                  tmp10 = index_expr(768, torch.float32)
                  tmp11 = tmp9 / tmp10
                  tmp12 = constant(1e-05, torch.float32)
                  tmp13 = tmp11 + tmp12
                  tmp14 = rsqrt(tmp13)
                  tmp15 = tmp8 * tmp14
                  tmp16 = load(arg0_1, i2)
                  tmp17 = tmp15 * tmp16
                  tmp18 = to_dtype(tmp17, torch.bfloat16)
                  return tmp18
                  ,
                  ranges=[12, 1024, 768],
                  origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
              no origins?
            ), ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg27_1, i1 + 768 * i0)
                  tmp1 = to_dtype(tmp0, torch.bfloat16)
                  return tmp1
                  ,
                  ranges=[2304, 768],
                  origins={convert_element_type, arg27_1}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
              no origins?
            )],
            constant_args=(),
            kwargs={},
            output_view=None,
            origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
          )
        ),
        FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1], offset=768),
        no origins?
      ),
      size=(12, 1024, 12, 64),
      reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
      origins={view_1, mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, split}
    ), dims=[0, 2, 1, 3])
  )
  args[2]: TensorBox(
    PermuteView(data=View(
      ReinterpretView(
        StorageBox(
          ExternKernelOut(
            name=buf5,
            layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 2304], stride=[2304, 1]),
            inputs=[ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf3', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[786432, 768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg76_1, i1 + 1024 * i0)
                  tmp1 = load(arg25_1, i2 + 768 * (tmp0))
                  tmp2 = index_expr(i1, torch.int64)
                  tmp3 = load(arg26_1, i2 + 768 * (tmp2))
                  tmp4 = tmp1 + tmp3
                  tmp5 = load(buf2, i1 + 1024 * i0)
                  tmp6 = index_expr(768, torch.float32)
                  tmp7 = tmp5 / tmp6
                  tmp8 = tmp4 - tmp7
                  tmp9 = load(buf1, i1 + 1024 * i0)
                  tmp10 = index_expr(768, torch.float32)
                  tmp11 = tmp9 / tmp10
                  tmp12 = constant(1e-05, torch.float32)
                  tmp13 = tmp11 + tmp12
                  tmp14 = rsqrt(tmp13)
                  tmp15 = tmp8 * tmp14
                  tmp16 = load(arg0_1, i2)
                  tmp17 = tmp15 * tmp16
                  tmp18 = to_dtype(tmp17, torch.bfloat16)
                  return tmp18
                  ,
                  ranges=[12, 1024, 768],
                  origins={convert_element_type_1, arg76_1, embedding_1, add_1, arg26_1, rsqrt, arange, unsqueeze, add, arg25_1, arg0_1, var_mean, mul_1, embedding, sub, mul}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]),
              no origins?
            ), ReinterpretView(
              StorageBox(
                ComputedBuffer(name='buf4', layout=FixedLayout('cuda', torch.bfloat16, size=[2304, 768], stride=[768, 1]), data=Pointwise(
                  'cuda',
                  torch.bfloat16,
                  tmp0 = load(arg27_1, i1 + 768 * i0)
                  tmp1 = to_dtype(tmp0, torch.bfloat16)
                  return tmp1
                  ,
                  ranges=[2304, 768],
                  origins={convert_element_type, arg27_1}
                ))
              ),
              FixedLayout('cuda', torch.bfloat16, size=[768, 2304], stride=[1, 768]),
              no origins?
            )],
            constant_args=(),
            kwargs={},
            output_view=None,
            origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, var_mean, mul_1, sub, embedding, mul}
          )
        ),
        FixedLayout('cuda', torch.bfloat16, size=[12, 1024, 768], stride=[2359296, 2304, 1], offset=1536),
        no origins?
      ),
      size=(12, 1024, 12, 64),
      reindex=lambda i0, i1, i2, i3: [i0, i1, 64*i2 + i3],
      origins={mm, convert_element_type_1, arg76_1, convert_element_type, arg27_1, embedding_1, add_1, arg26_1, rsqrt, view, arange, unsqueeze, permute, add, arg25_1, arg0_1, _unsafe_view, var_mean, mul_1, sub, embedding, mul, view_3, split}
    ), dims=[0, 2, 1, 3])
  )
  args[3]: False
  args[4]: True
[2023-01-30 23:47:24,279] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default
step 0: train loss 10.9649, val loss 10.9658
[2023-01-30 23:47:32,725] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention.default
[2023-01-30 23:47:35,128] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten._scaled_dot_product_efficient_attention_backward.default
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        ExternKernelOut(
          name=buf46,
          layout=FixedLayout('cuda', torch.bfloat16, size=[12288, 768], stride=[768, 1]),
          inputs=[ComputedBuffer(name='buf44', layout=FixedLayout('cuda', torch.bfloat16, size=(12288, 768), stride=[768, 1]), data=Pointwise(
            'cuda',
            torch.bfloat16,
            tmp0 = load(buf43, i1 + 768 * ModularIndexing(i0, 1, 1024) + 786432 * ModularIndexing(i0, 1024, 12))
            tmp1 = to_dtype(tmp0, torch.bfloat16)
            return tmp1
            ,
            ranges=(12288, 768),
            origins={view_155}
          )), InputBuffer(name='permute_111', layout=FixedLayout('cuda', torch.bfloat16, size=[768, 768], stride=[768, 1]))],
          constant_args=(),
          kwargs={},
          output_view=None,
          origins={mm_56, view_155, permute_111}
        )
      ),
      FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[786432, 64, 768, 1]),
      no origins?
    )
  )
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='permute_90', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='permute_89', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
  ))
  args[3]: TensorBox(StorageBox(
    InputBuffer(name='permute_91', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[2359296, 64, 2304, 1]))
  ))
  args[4]: TensorBox(StorageBox(
    InputBuffer(name='getitem_104', layout=FixedLayout('cuda', torch.bfloat16, size=[12, 12, 1024, 64], stride=[786432, 64, 768, 1]))
  ))
  args[5]: TensorBox(StorageBox(
    InputBuffer(name='getitem_105', layout=FixedLayout('cuda', torch.float32, size=[12, 12, 1024], stride=[12288, 1024, 1]))
  ))
  args[6]: True
[2023-01-30 23:47:35,135] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._scaled_dot_product_efficient_attention_backward.default

karpathy, Jan 30 '23

We should really lower the logging level of those messages to INFO. As @drisspg says, it's just saying that it's using the eager kernel, but it prints out all the associated inductor IR in the process of saying that.
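Until that happens, one possible workaround (untested, and assuming inductor routes these messages through the standard logging module under the logger name shown in the output above) is to raise that logger's threshold:

import logging

# Hide the "Creating implicit fallback" / "Using FallbackKernel" warnings and
# the inductor IR dump that accompanies them, while keeping errors visible.
logging.getLogger("torch._inductor.graph").setLevel(logging.ERROR)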

bertmaher, Jan 31 '23

this was merged now so closing the issue

karpathy, Feb 03 '23

For aten._scaled_dot_product_flash_attention.default and aten._scaled_dot_product_efficient_attention.default we actually don't want inductor to generate any Triton code for these two kernels, because at least as of right now it would not be as performant

What's the current perf gap between the Triton version and the ATen version? Or can you directly map to Triton's custom flash attention implementation?

Jokeren, Feb 08 '23