Flex attention error

Open ye-jin-shop opened this issue 1 year ago • 4 comments

I noticed the following error message while running a full finetune of the Llama 3.1 70B model:

...
(task, pid=3369) [rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/kernel/flex_attention.py", line 2129, in flex_attention_backward
(task, pid=3369) [rank7]:     broadcasted_grad_key = autotune_select_algorithm(
(task, pid=3369) [rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/select_algorithm.py", line 1887, in autotune_select_algorithm
(task, pid=3369) [rank7]:     return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
(task, pid=3369) [rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/select_algorithm.py", line 1357, in __call__
(task, pid=3369) [rank7]:     raise NoValidChoicesError(
(task, pid=3369) [rank7]: torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 

This happens at the training step

(task, pid=3369)   0%|          | 0/98 [00:00<?, ?it/s]Using flex attention for attention computation since a BlockMask was passed in.

while loss.backward() is running.

I am using the nightly build of torch. I noticed that a change to this file landed in torch today: https://github.com/pytorch/pytorch/commit/7830c213d7547173f060acd406b1f01181a06474. I wonder whether the two are related.

ye-jin-shop avatar Dec 05 '24 01:12 ye-jin-shop

hey, thanks for the issue! If you have availability, do you think you could try a few days old torch nightlies to confirm it?

felipemello1 avatar Dec 05 '24 02:12 felipemello1

I can confirm that it works with at least the Nov 30th nightly. The environment is:

  python3 -m pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241130%2Bcu124-cp310-cp310-linux_x86_64.whl
  python3 -m pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cu124/torchvision-0.20.0.dev20241130%2Bcu124-cp310-cp310-linux_x86_64.whl
  python3 -m pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cu124/torchao-0.7.0.dev20241130%2Bcu124-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl
  python3 -m pip install --no-cache-dir https://download.pytorch.org/whl/nightly/cu124/torchtune-0.5.0.dev20241130%2Bcu124-py3-none-any.whl

I am not sure whether you see "Using flex attention for attention computation since a BlockMask was passed in." with your test data during training. It seems that my data (or the default configuration) triggers the block mask. Also, "max_autotune_gemm_backends" is not passed as an env variable, so that function failed.

ye-jin-shop avatar Dec 05 '24 05:12 ye-jin-shop

I am seeing the same issue with the 2.6.0 RC from https://download.pytorch.org/whl/test/cu124. Downgrading to the 20241130 nightly fixes it for me.

erikwijmans avatar Jan 07 '25 01:01 erikwijmans

Is that the absolute latest that works for you?? Can you try with anything more recent?

joecummings avatar Jan 07 '25 01:01 joecummings

dev20241203 works but dev20241204 does not, so something that changed between dev20241203 and dev20241204 is probably the culprit. I also spot-checked a couple of versions between dev20241204 and the RC to see if it was sporadically fixed, but no luck -- I haven't checked every single version, however.
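
For anyone repeating this kind of search over nightlies, here is a minimal sketch of a date bisection. It assumes the regression is monotonic (once a nightly breaks, later ones stay broken), which the spot checks above suggest but do not prove. The names nightly_tag, first_bad_nightly and is_good are hypothetical helpers, not part of torch or torchtune.

```python
from datetime import date, timedelta


def nightly_tag(day: date) -> str:
    """Format a date the way the nightly wheels in this thread are tagged."""
    return f"dev{day:%Y%m%d}"


def first_bad_nightly(good: date, bad: date, is_good) -> date:
    """Binary-search for the first failing nightly.

    `good` is a date known to work, `bad` a later date known to fail, and
    `is_good(day)` is a caller-supplied check (hypothetical: for example,
    install that nightly and run the repro script).
    """
    if good >= bad:
        raise ValueError("`good` must be earlier than `bad`")
    # Invariant: `good` works and `bad` fails; shrink the gap to one day.
    while (bad - good).days > 1:
        mid = good + timedelta(days=(bad - good).days // 2)
        if is_good(mid):
            good = mid
        else:
            bad = mid
    return bad
```

In practice is_good would wrap whatever install-and-run procedure you use; each call costs one install plus one training step, so a five-week window needs about six checks instead of one per day.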

erikwijmans avatar Jan 09 '25 18:01 erikwijmans

Okay this gives me a good starting point to investigate - thank you

joecummings avatar Jan 09 '25 21:01 joecummings

Hi @joecummings just wanted to follow up here. I was also able to narrow this down a bit more. It is the backward kernel that is causing the issue. The following example reproduces this error on an A100 GPU for me:

import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, dynamic=True)

device = "cuda"
dtype = torch.bfloat16

B = 1
Hq = 28
Hkv = 4
L = 512
S = 512
E = 128

q = torch.randn((B, Hq, L, E), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((B, Hkv, S, E), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((B, Hkv, S, E), dtype=dtype, device=device, requires_grad=True)

attn_output = flex_attention(q, k, v, enable_gqa=True)


attn_output.mean().backward()

I have tried the released version of 2.6.0 and a 2.7.0 nightly, and the error still persists.

In the example above I have GQA enabled, but the error is the same with and without GQA.

erikwijmans avatar Feb 05 '25 15:02 erikwijmans

@bdhirsh / @drisspg , do you think you could take a look?

felipemello1 avatar Feb 05 '25 16:02 felipemello1

I can repro. The workarounds:

  1. Compile with max-autotune: flex_attention = torch.compile(flex_attention, dynamic=True, mode='max-autotune'). Compilation will take longer, but you will get better performance (and we will pick a kernel that doesn't use too much shmem).
  2. Sweep over kernel params to find a set that doesn't use too much shared memory. Our default choice uses too much, but you can actually use autotune to find the best one that will work.

For instance this is the best in the repro:

  triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

So if I pass these in via kwargs:

import torch
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, dynamic=False)

device = "cuda"
dtype = torch.bfloat16

B = 1
Hq = 28
Hkv = 4
L = 512
S = 512
E = 128

q = torch.randn((B, Hq, L, E), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((B, Hkv, S, E), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((B, Hkv, S, E), dtype=dtype, device=device, requires_grad=True)

# Override the default backward tile sizes so the kernel fits in shared memory.
attn_output = flex_attention(q, k, v, enable_gqa=True, kernel_options={"BLOCK_M1": 16, "BLOCK_M2": 16, "BLOCK_N1": 16, "BLOCK_N2": 16})

attn_output.mean().backward()

You can't change num_stages via input args on 2.6, though; that is only available in nightlies.

drisspg avatar Feb 05 '25 20:02 drisspg

hey @drisspg , thanks for taking a look! Can you please clarify:

Solution 1: Add max-autotune: flex_attention = torch.compile(flex_attention, dynamic=True, mode='max-autotune')

Solution 2: Make dynamic=False and hardcode the kernel_options

Solution 2 would probably not work well for torchtune, since users have many types of hardware. Is that right?

felipemello1 avatar Feb 05 '25 21:02 felipemello1

The problem is that the amount of shmem used depends on the specific score_mod and mask_mod used, and the available shared memory depends on which GPU you use. We try to provide sensible defaults, but depending on the specific call to flex attention, the default can be wrong, and we run out of shared memory with the default choice.

Max autotune will try a much wider spread of options and ultimately find one that works. I have a tracker for this: https://github.com/pytorch/pytorch/issues/139131

I think in terms of practical solutions you will want 1. Also, unless you really need to, I would not recommend specifying dynamic=True; let automatic dynamic dims kick in when needed.
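
A minimal sketch of option 1 applied to the repro from earlier in the thread, with dynamic=True left out as suggested. It assumes a CUDA GPU and a PyTorch build that ships flex_attention; the wrapper name run_flex_attention_backward is hypothetical, and the shapes are copied from the repro. Whether autotuning finds a kernel that fits may still depend on the GPU and the PyTorch version.

```python
def run_flex_attention_backward(seq_len: int = 512):
    """Run the repro's forward and backward pass with max-autotune enabled."""
    # Imported inside the function so this file can be loaded without a GPU.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # mode="max-autotune" lets Inductor benchmark several kernel configs
    # instead of relying on the single default one; no dynamic=True here.
    compiled_flex_attention = torch.compile(flex_attention, mode="max-autotune")

    device = "cuda"
    dtype = torch.bfloat16
    # Shapes from the repro: 28 query heads sharing 4 key/value heads (GQA).
    q = torch.randn((1, 28, seq_len, 128), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((1, 4, seq_len, 128), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((1, 4, seq_len, 128), dtype=dtype, device=device, requires_grad=True)

    attn_output = compiled_flex_attention(q, k, v, enable_gqa=True)
    # The backward kernel is the one that failed to lower in this issue.
    attn_output.mean().backward()
    return q.grad
```

Calling run_flex_attention_backward() on the affected machine should take noticeably longer on the first call, since autotuning happens at compile time.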

drisspg avatar Feb 05 '25 21:02 drisspg

@erikwijmans, would you like to submit a PR with the change and see if it fixes the issue for you? It's ok if you don't have bandwidth.

felipemello1 avatar Feb 05 '25 21:02 felipemello1

Running out of shared memory being the issue makes a lot of sense. Thank you for the suggestions on how to fix it!

@felipemello1 I don't currently have the bandwidth to submit a PR. If you do, go for it!

erikwijmans avatar Feb 06 '25 17:02 erikwijmans

Closing the issue. Fixed in https://github.com/pytorch/torchtune/pull/2357

Thanks for the debugging and the solution!!

felipemello1 avatar Feb 07 '25 01:02 felipemello1