No support for 4D attention? `RuntimeError: cu_seqlens_q must have shape (batch_size + 1)`
It seems like 2D attention masks work with flash attention, but 4D attention masks do not. In the example below, the two calls use identical inputs except for the attention mask. The 4D mask here is just all-to-all attention for demonstration; my actual code builds prefix-LM masks.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    attn_implementation="flash_attention_2",
    dtype=torch.float16,
    device_map="auto"
)
hidden_size = model.config.hidden_size
input_embeds = torch.randn(1, 5, hidden_size, device=model.device, dtype=model.dtype)
position_ids = torch.arange(5, device=model.device).unsqueeze(0)

attention_mask_2d = torch.ones(1, 5, device=model.device)
outputs = model(
    inputs_embeds=input_embeds,
    position_ids=position_ids,
    attention_mask=attention_mask_2d,
    output_hidden_states=True
)
print(f"Hidden states shape: {outputs.hidden_states[-1].shape}")
print(f"Logits shape: {outputs.logits.shape}")

# All to all attention, for example
attention_mask_4d = torch.ones(1, 1, 5, 5, device=model.device)
outputs = model(
    inputs_embeds=input_embeds,
    position_ids=position_ids,
    attention_mask=attention_mask_4d,
    output_hidden_states=True
)
print(f"Hidden states shape: {outputs.hidden_states[-1].shape}")
print(f"Logits shape: {outputs.logits.shape}")
The output is:
Hidden states shape: torch.Size([1, 5, 512])
Logits shape: torch.Size([1, 5, 50304])
followed by this error:
Traceback (most recent call last):
File "/pkg/modal/_runtime/container_io_manager.py", line 778, in handle_input_exception
yield
File "/pkg/modal/_container_entrypoint.py", line 243, in run_input_sync
res = io_context.call_finalized_function()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/pkg/modal/_runtime/container_io_manager.py", line 197, in call_finalized_function
res = self.finalized_function.callable(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/app/training/modal.py", line 143, in sample_remote
outputs = model(
^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 581, in forward
outputs: BaseModelOutputWithPast = self.gpt_neox(
^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 479, in forward
outputs = layer(
^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 221, in forward
attn_output, attn_weights = self.attention(
^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 178, in forward
attn_output, attn_weights = attention_interface(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
attn_output = _flash_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 616, in _flash_attention_forward
out_unpad = flash_varlen_fn(
^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func
return FlashAttnVarlenFunc.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward
out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_ops.py", line 836, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
Looking at NVIDIA's solution, they call squeeze(1) twice to reduce the attention_mask to 2D:
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cat((zero, cu_seqlens))
    return cu_seqlens
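A quick usage sketch, assuming the convention I read in that helper (True marks padded positions to drop, False marks valid tokens, and everything lives on CUDA because of the hard-coded device):

import torch

# batch of 2, max_seqlen 4: sample 0 has 3 valid tokens, sample 1 has 4
mask = torch.tensor([[False, False, False, True],
                     [False, False, False, False]], device="cuda").view(2, 1, 1, 4)
print(get_cu_seqlens(mask))  # tensor([0, 3, 7], device='cuda:0', dtype=torch.int32)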
Hope it helps
Thanks @mhoangvslev, however my attention mask is [batch_size, 1, max_seqlen, max_seqlen], not [batch_size, 1, 1, max_seqlen], so it can't be squeezed down like that.
This is also the case for me: the transformers attention_mask is 4D, as you described. Is there any reason you need to use 4D?
I need to use 4D for prefix-LM attention.
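Roughly, the mask I build looks like this (a sketch, not my exact code; the helper name and prefix_len are made up here, and the convention is 1 = attend / 0 = masked, like the 0/1 masks above):

import torch

def prefix_lm_4d_mask(batch_size, seq_len, prefix_len, device):
    # causal part: each query may attend to keys at or before its own position
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    # prefix part: every query may also attend to all prefix keys (bidirectional prefix)
    allowed[:, :prefix_len] = True
    # expand to [batch_size, 1, seq_len, seq_len]; 1.0 = attend, 0.0 = masked
    return allowed.float().unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len)

attention_mask_4d = prefix_lm_4d_mask(batch_size=1, seq_len=5, prefix_len=2, device="cpu")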
Hi @tridao, would really love support for this 🙏🏻
We're starting to have FlexAttention implemented on top of FA4, so that should eventually work for this case.
Hi @tridao - any progress on this? Sorry, I am not technical enough to understand all the low-level stuff here...
plz check the flexattn tests to see if any of those apply to your case
I had Claude run a benchmark. flex_attention works with 4D masks.
Everything ran on an NVIDIA DGX Spark
● Results: Batch=128, Seq=128, Realistic Masks
| Implementation | pythia-14m 2D | pythia-14m 4D | pythia-70m 2D | pythia-70m 4D |
|---|---|---|---|---|
| eager | 20.51 ms | 19.13 ms | 52.36 ms | 51.71 ms |
| sdpa | 19.14 ms | 14.56 ms | 44.25 ms | 41.47 ms |
| flash_attention_2 | 22.72 ms | ❌ FAIL | 47.77 ms | ❌ FAIL |
| flash_attention_3 | ❌ FAIL | ❌ FAIL | ❌ FAIL | ❌ FAIL |
| flex_attention | 14.55 ms | 14.78 ms | 41.62 ms | 42.01 ms |
| sdpa_paged | 14.40 ms | 14.45 ms | 41.99 ms | 41.39 ms |
| eager_paged | 18.30 ms | 18.95 ms | 49.51 ms | 51.32 ms |
Key Findings at Scale (batch=128)
- Best performers (virtually tied):
  - sdpa_paged: 14.40-14.45 ms (14m), 41.39-41.99 ms (70m)
  - flex_attention: 14.55-14.78 ms (14m), 41.62-42.01 ms (70m)
  - sdpa with 4D: 14.56 ms (14m), 41.47 ms (70m)
- flash_attention_2 still fails with 4D masks
- Model scaling: ~3x slowdown from 14m→70m (as expected for 5x more parameters)
- 4D mask support:
  - ✅ eager, sdpa, flex_attention, sdpa_paged, eager_paged
  - ❌ flash_attention_2, flash_attention_3
- Performance ranking for 4D masks (pythia-70m):
  1. sdpa_paged: 41.39 ms
  2. sdpa: 41.47 ms
  3. flex_attention: 42.01 ms
  4. eager_paged: 51.32 ms
  5. eager: 51.71 ms
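Based on this, a simple workaround for the original repro seems to be just swapping the implementation string (a sketch, only verified through the benchmark script below; flex_attention needs a reasonably recent PyTorch and transformers):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    attn_implementation="flex_attention",  # instead of "flash_attention_2"
    dtype=torch.float16,
    device_map="auto"
)
input_embeds = torch.randn(1, 5, model.config.hidden_size, device=model.device, dtype=model.dtype)
position_ids = torch.arange(5, device=model.device).unsqueeze(0)
attention_mask_4d = torch.ones(1, 1, 5, 5, device=model.device)
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=attention_mask_4d)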
Reproduction
Dockerfile
FROM nvcr.io/nvidia/pytorch:25.11-py3
RUN pip install --no-cache-dir transformers accelerate -q
WORKDIR /workspace
run.sh
#!/bin/bash
BATCH_SIZE=128
SEQ_LENGTH=128
NUM_RUNS=20

run_test() {
    local impl="$1"
    local model="$2"
    local mask="$3"
    result=$(docker run --rm --gpus all --ipc=host \
        -v /home/amit/dev/tmp/flash-attention-test:/workspace \
        -v /shared/.cache/huggingface:/root/.cache/huggingface \
        flash-test python -c "
import torch
import time
from transformers import AutoModelForCausalLM

BATCH_SIZE = ${BATCH_SIZE}
SEQ_LENGTH = ${SEQ_LENGTH}
NUM_RUNS = ${NUM_RUNS}

try:
    model = AutoModelForCausalLM.from_pretrained(
        '${model}',
        torch_dtype=torch.float16,
        device_map='auto',
        attn_implementation='${impl}'
    )
    hidden_size = model.config.hidden_size
    input_embeds = torch.randn(BATCH_SIZE, SEQ_LENGTH, hidden_size, device=model.device, dtype=model.dtype)
    position_ids = torch.arange(SEQ_LENGTH, device=model.device).unsqueeze(0).expand(BATCH_SIZE, -1)

    if '${mask}' == '4d':
        # 4D mask: half 0s, half 1s (first half of keys are masked out)
        attention_mask = torch.zeros(BATCH_SIZE, 1, SEQ_LENGTH, SEQ_LENGTH, device=model.device, dtype=model.dtype)
        attention_mask[:, :, :, SEQ_LENGTH//2:] = 1.0
    else:
        # 2D mask: alternating 0s and 1s
        attention_mask = torch.zeros(BATCH_SIZE, SEQ_LENGTH, device=model.device)
        attention_mask[:, ::2] = 1.0  # every other token is valid

    # Warmup
    with torch.no_grad():
        for _ in range(3):
            _ = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=attention_mask)

    # Timed
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(NUM_RUNS):
            _ = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=attention_mask)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) * 1000 / NUM_RUNS
    print(f'OK:{elapsed:.2f}')
except Exception as e:
    print(f'FAIL:{str(e)[:60]}')
" 2>&1 | grep -E "^(OK|FAIL):" | tail -1)
    echo "$impl,$model,$mask,$result"
}

# Header
echo ""
echo "impl,model,mask,result"

for impl in eager sdpa flash_attention_2 flash_attention_3 flex_attention sdpa_paged eager_paged; do
    for model in EleutherAI/pythia-14m EleutherAI/pythia-70m; do
        for mask in 2d 4d; do
            run_test "$impl" "$model" "$mask"
        done
    done
done
i mean flexattn in this repo (flash_attn.cute) https://github.com/Dao-AILab/flash-attention/blob/main/tests/cute/test_score_mod.py
in general a 4D attn mask isn't the right abstraction for prefix-lm (the 4D mask is too general and you'll pay for it with a slowdown). @drisspg do we have an example of prefix lm w/ flexattn in this repo?
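(For reference, at the PyTorch FlexAttention level a prefix-LM is expressed as a mask_mod plus a block mask rather than a dense 4D tensor; a rough sketch with made-up sizes, independent of the flash_attn.cute path:)

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 1024, 64  # made-up sizes
prefix_len = 128             # made-up prefix length

def prefix_lm(b, h, q_idx, kv_idx):
    # attend if the key is inside the prefix, or is causally visible
    return (kv_idx < prefix_len) | (q_idx >= kv_idx)

block_mask = create_block_mask(prefix_lm, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)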
I don't think there is any current example; I just generated one by calling flex_attention with `BACKEND = "FLASH"`:
# kernel path: /tmp/torchinductor_dev/y4/cy4su5k46zgriqpkqosly2lalii733le3qyfgzilhvjzbncfvvyc.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [flex_attention]
# Source node to ATen node mapping:
# flex_attention => flex_attention
# Graph fragment:
# %arg0_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg0_1]
# %arg1_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg1_1]
# %arg2_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %buf1 : Tensor "f32[4, 32, 8192][262144, 8192, 1]cuda:0" = PlaceHolder[target=buf1]
# %arg4_1 : Tensor "i32[4, 32, 32][1024, 32, 1]cuda:0" = PlaceHolder[target=arg4_1]
# %arg3_1 : Tensor "i32[4, 32, 32, 64][65536, 2048, 64, 1]cuda:0" = PlaceHolder[target=arg3_1]
# %arg5_1 : Tensor "i32[4, 32, 32][1024, 32, 1]cuda:0" = PlaceHolder[target=arg5_1]
# %arg6_1 : Tensor "i32[4, 32, 32, 64][65536, 2048, 64, 1]cuda:0" = PlaceHolder[target=arg6_1]
# %flex_attention : [num_users=1] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (8192, 8192, %arg4_1, %arg3_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 256, 128, %sdpa_mask0), 0.08838834764831843, {BACKEND: FLASH, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
# return %getitem
cutedsl_fused_flex_attention_a32649c7 = async_compile.cutedsl('cutedsl_fused_flex_attention_a32649c7', r'''
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import cuda.bindings.driver as cuda
from cutlass._mlir.dialects import math as mlir_math
import operator
from torch._inductor.codegen.cutedsl._cutedsl_utils import ssa_to_indexable, result_to_ssa
# Kernel function signature: cutedsl_fused_flex_attention_a32649c7
def cutedsl_fused_flex_attention_a32649c7_main(arg_Q, arg_K, arg_V, arg_LOGSUMEXP, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, stream):
    Q = arg_Q
    K = arg_K
    V = arg_V
    LOGSUMEXP = arg_LOGSUMEXP
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    from flash_attn.cute.interface import _flash_attn_fwd
    from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch

    # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D)
    q_transposed = Q.transpose(1, 2)
    k_transposed = K.transpose(1, 2)
    v_transposed = V.transpose(1, 2)

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
        tmp0 = tSrS_ssa
        tSrS_ssa = tmp0
        return tSrS_ssa

    score_mod.__cute_hash__ = "cutedsl_fused_flex_attention_a32649c7_score"

    # (B,M,H,D) -> (B,H,M,D)
    output = out_ptr0
    output_transposed = output.transpose(1, 2)

    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
        tmp1 = kv_idx
        tmp2 = operator.lt(tmp1, cute.full_like(tmp1, 2048))
        tmp3 = q_idx
        tmp4 = operator.ge(tmp3, tmp1)
        mask_mod_output = (False | tmp2) | tmp4
        return mask_mod_output

    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_a32649c7_mask"

    block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX)
    # Collect any additional tensor buffers that were added during modifications
    buffers = None
    # Out and LSE filled inplace
    _flash_attn_fwd(
        q_transposed,
        k_transposed,
        v_transposed,
        softmax_scale=0.08838834764831843,
        return_lse=True,
        score_mod=score_mod,
        mask_mod=mask_mod,
        out=output_transposed,
        lse=LOGSUMEXP,
        block_sparse_tensors=block_sparse_tensors,
        aux_tensors=buffers
    )
'''
This is with a static prefix size of 2048.
So in non-autogenerated code:
import operator

prefix_len = 2048
...

@cute.jit
def mask_mod_cute(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    return (operator.lt(kv_idx, cute.full_like(kv_idx, prefix_len)) |
            operator.ge(q_idx, kv_idx))
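Calling it by hand would then look roughly like the autogenerated wrapper above. An untested sketch, assuming the block-sparse metadata and aux_tensors arguments are optional when only a mask_mod is supplied, and using the (batch, seqlen, heads, headdim) layout that the wrapper transposes into:

import torch
from flash_attn.cute.interface import _flash_attn_fwd

B, S, H, D = 1, 4096, 16, 128  # made-up sizes
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(q)
lse = torch.empty(B, H, S, device="cuda", dtype=torch.float32)

# out and lse are filled in place, as in the generated code above
_flash_attn_fwd(
    q, k, v,
    softmax_scale=D ** -0.5,
    return_lse=True,
    mask_mod=mask_mod_cute,
    out=out,
    lse=lse,
)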