
No support for 4D attention? `RuntimeError: cu_seqlens_q must have shape (batch_size + 1)`

AmitMY opened this issue 3 months ago · 6 comments

It seems like 2D attention masks work with flash attention, but 4D attention masks do not. In this example the inputs are identical except for the attention mask. The 4D mask here just creates all-to-all attention for demonstration; more concretely, my actual code does prefix-LM.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    attn_implementation="flash_attention_2",
    dtype=torch.float16,
    device_map="auto"
)

hidden_size = model.config.hidden_size
input_embeds = torch.randn(1, 5, hidden_size, device=model.device, dtype=model.dtype)
position_ids = torch.arange(5, device=model.device).unsqueeze(0)
attention_mask_2d = torch.ones(1, 5, device=model.device)

outputs = model(
    inputs_embeds=input_embeds,
    position_ids=position_ids,
    attention_mask=attention_mask_2d,
    output_hidden_states=True
)

print(f"Hidden states shape: {outputs.hidden_states[-1].shape}")
print(f"Logits shape: {outputs.logits.shape}")

# All-to-all attention, for example
attention_mask_4d = torch.ones(1, 1, 5, 5, device=model.device)
outputs = model(
    inputs_embeds=input_embeds,
    position_ids=position_ids,
    attention_mask=attention_mask_4d,
    output_hidden_states=True
)

print(f"Hidden states shape: {outputs.hidden_states[-1].shape}")
print(f"Logits shape: {outputs.logits.shape}")

The output is:

Hidden states shape: torch.Size([1, 5, 512])
Logits shape: torch.Size([1, 5, 50304])

And then the error:

Traceback (most recent call last):
  File "/pkg/modal/_runtime/container_io_manager.py", line 778, in handle_input_exception                                                                                                                                                
    yield                                                                                                                                                                                                                                
  File "/pkg/modal/_container_entrypoint.py", line 243, in run_input_sync                                                                                                                                                                
    res = io_context.call_finalized_function()                                                                                                                                                                                           
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                           
  File "/pkg/modal/_runtime/container_io_manager.py", line 197, in call_finalized_function                                                                                                                                               
    res = self.finalized_function.callable(*args, **kwargs)                                                                                                                                                                              
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                              
  File "/app/training/modal.py", line 143, in sample_remote                                                                                                                                                                              
    outputs = model(                                                                                                                                                                                                                     
              ^^^^^^                                                                                                                                                                                                                     
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/transformers/utils/generic.py", line 940, in wrapper                                                                                                                                     
    output = func(self, *args, **kwargs)                                                                                                                                                                                                 
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 581, in forward                                                                                                                 
    outputs: BaseModelOutputWithPast = self.gpt_neox(                                                                                                                                                                                    
                                       ^^^^^^^^^^^^^^                                                                                                                                                                                    
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper                                                                                                                                    
    outputs = func(self, *args, **kwargs)                                                                                                                                                                                                
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 479, in forward                                                                                                                 
    outputs = layer(                                                                                                                                                                                                                     
              ^^^^^^                                                                                                                                                                                                                     
  File "/opt/conda/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__                                                                                                                                   
    return super().__call__(*args, **kwargs)                                                                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                             
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 221, in forward                                                                                                                 
    attn_output, attn_weights = self.attention(                                                                                                                                                                                          
                                ^^^^^^^^^^^^^^^                                                                                                                                                                                          
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/opt/conda/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 178, in forward                                                                                                                 
    attn_output, attn_weights = attention_interface(                                                                                                                                                                                     
                                ^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                     
  File "/opt/conda/lib/python3.12/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward                                                                                                       
    attn_output = _flash_attention_forward(                                                                                                                                                                                              
                  ^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/opt/conda/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 616, in _flash_attention_forward                                                                                                   
    out_unpad = flash_varlen_fn(                                                                                                                                                                                                         
                ^^^^^^^^^^^^^^^^                                                                                                                                                                                                         
  File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 1443, in flash_attn_varlen_func                                                                                                                
    return FlashAttnVarlenFunc.apply(                                                                                                                                                                                                    
           ^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                    
  File "/opt/conda/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply                                                                                                                                          
    return super().apply(*args, **kwargs)  # type: ignore[misc]                                                                                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                
  File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 925, in forward                                                                                                                                
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(                                                                                                                                                    
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                    
  File "/opt/conda/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__                                                                                                                                                   
    return self._op(*args, **kwargs)                                                                                                                                                                                                     
           ^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                     
  File "/opt/conda/lib/python3.12/site-packages/torch/_library/autograd.py", line 111, in autograd_impl                                                                                                                                  
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))                                                                                                                                                                 
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                 
  File "/opt/conda/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad                                                                                                                                 
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)                                                                                                                                                          
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                          
  File "/opt/conda/lib/python3.12/site-packages/torch/_ops.py", line 836, in redispatch                                                                                                                                                  
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]                                                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                        
  File "/opt/conda/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 344, in backend_impl                                                                                                                                 
    result = self._backend_fns[device_type](*args, **kwargs)                                                                                                                                                                             
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                             
  File "/opt/conda/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner                                                                                                                                                    
    return disable_fn(*args, **kwargs)                                                                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                   
  File "/opt/conda/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn                                                                                                                                           
    return fn(*args, **kwargs)                                                                                                                                                                                                           
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                           
  File "/opt/conda/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 377, in wrapped_fn                                                                                                                                   
    return fn(*args, **kwargs)                                                                                                                                                                                                           
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                                           
  File "/opt/conda/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward                                                                                                             
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(                                                                                                                                                                    
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                    
RuntimeError: cu_seqlens_q must have shape (batch_size + 1)        

AmitMY, Sep 02 '25

Looking at NVIDIA's solution, they use squeeze(1) twice to make the attention_mask 2D.

def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    """
    Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32
    tensor of shape [batch_size + 1] containing the cumulative sequence lengths of
    the samples in a batch.
    """
    mask = mask.squeeze(1).squeeze(1)
    reduced_mask = mask.logical_not().sum(dim=1)
    cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cat((zero, cu_seqlens))

    return cu_seqlens
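
For concreteness, a quick sketch of how the helper above behaves, under the convention its logical_not() implies (True marks a padded position); the toy mask below is illustrative:

import torch

# Two sequences of max length 5: the first has 3 valid tokens, the second has 4.
pad_mask = torch.tensor(
    [[False, False, False, True, True],
     [False, False, False, False, True]],
    device="cuda",
).view(2, 1, 1, 5)

print(get_cu_seqlens(pad_mask))  # cumulative valid lengths: [0, 3, 7] (int32)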

Hope it helps

mhoangvslev, Sep 16 '25

Thanks @mhoangvslev, however my attention mask is [batch_size, 1, max_seqlen, max_seqlen], not [batch_size, 1, 1, max_seqlen].

AmitMY, Sep 16 '25

This is also the case for me. The transformers attention_mask is 4D like you described. Is there any reason you need to use 4D?

mhoangvslev, Sep 16 '25

I need to use 4D for the attention pattern shown in this image:

[Image]
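
For reference, the prefix-LM style pattern mentioned in the original post (bidirectional attention within a prefix, causal after it) can be expressed as a 4D mask roughly like this; the sizes and prefix_len below are illustrative, not from this thread:

import torch

# Sketch only: 1 = may attend, 0 = masked, same convention as the all-to-all example above.
batch_size, seq_len, prefix_len = 1, 5, 2

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
prefix = torch.zeros(seq_len, seq_len, dtype=torch.bool)
prefix[:, :prefix_len] = True  # every query may attend to every prefix key

attention_mask_4d = (causal | prefix).to(torch.float32).expand(batch_size, 1, -1, -1)
print(attention_mask_4d.shape)  # torch.Size([1, 1, 5, 5])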

AmitMY, Sep 16 '25

Hi @tridao, would really love support for this 🙏🏻

AmitMY, Nov 09 '25

We're starting to have FlexAttention implemented on top of FA4, so that should eventually work for this case.

tridao, Nov 09 '25

Hi @tridao - any progress on this? Sorry, I am not technical enough to understand all the low-level stuff here...

AmitMY, Dec 05 '25

Please check the flexattn tests to see if any of those apply to your case.

tridao, Dec 05 '25

I had Claude run a benchmark. flex_attention works with 4D masks. Everything ran on an NVIDIA DGX Spark.


Results: Batch=128, Seq=128, Realistic Masks

Implementation     | pythia-14m 2D | pythia-14m 4D | pythia-70m 2D | pythia-70m 4D
-------------------|---------------|---------------|---------------|--------------
eager              | 20.51 ms      | 19.13 ms      | 52.36 ms      | 51.71 ms
sdpa               | 19.14 ms      | 14.56 ms      | 44.25 ms      | 41.47 ms
flash_attention_2  | 22.72 ms      | ❌ FAIL       | 47.77 ms      | ❌ FAIL
flash_attention_3  | ❌ FAIL       | ❌ FAIL       | ❌ FAIL       | ❌ FAIL
flex_attention     | 14.55 ms      | 14.78 ms      | 41.62 ms      | 42.01 ms
sdpa_paged         | 14.40 ms      | 14.45 ms      | 41.99 ms      | 41.39 ms
eager_paged        | 18.30 ms      | 18.95 ms      | 49.51 ms      | 51.32 ms

Key Findings at Scale (batch=128)

  1. Best performers (virtually tied):
     - sdpa_paged: 14.40-14.45 ms (14m), 41.39-41.99 ms (70m)
     - flex_attention: 14.55-14.78 ms (14m), 41.62-42.01 ms (70m)
     - sdpa with 4D: 14.56 ms (14m), 41.47 ms (70m)
  2. flash_attention_2 still fails with 4D masks.
  3. Model scaling: ~3x slowdown from 14m to 70m (as expected for ~5x more parameters).
  4. 4D mask support (a flex_attention repro is sketched after this list):
     - ✅ eager, sdpa, flex_attention, sdpa_paged, eager_paged
     - ❌ flash_attention_2, flash_attention_3
  5. Performance ranking for 4D masks (pythia-70m):
     a. sdpa_paged: 41.39 ms
     b. sdpa: 41.47 ms
     c. flex_attention: 42.01 ms
     d. eager_paged: 51.32 ms
     e. eager: 51.71 ms
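
As a concrete workaround, the repro from the first post runs with the 4D mask once the attention implementation is switched to flex_attention; a sketch (exact transformers version support may vary):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",
    attn_implementation="flex_attention",
    dtype=torch.float16,
    device_map="auto"
)

hidden_size = model.config.hidden_size
input_embeds = torch.randn(1, 5, hidden_size, device=model.device, dtype=model.dtype)
position_ids = torch.arange(5, device=model.device).unsqueeze(0)
attention_mask_4d = torch.ones(1, 1, 5, 5, device=model.device)

outputs = model(
    inputs_embeds=input_embeds,
    position_ids=position_ids,
    attention_mask=attention_mask_4d,
    output_hidden_states=True
)
print(outputs.logits.shape)  # expected: torch.Size([1, 5, 50304])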

Reproduction

Dockerfile

FROM nvcr.io/nvidia/pytorch:25.11-py3

RUN pip install --no-cache-dir transformers accelerate -q

WORKDIR /workspace

run.sh (assumes the image above has been built and tagged flash-test, e.g. docker build -t flash-test .)

#!/bin/bash

BATCH_SIZE=128
SEQ_LENGTH=128
NUM_RUNS=20

run_test() {
    local impl="$1"
    local model="$2"
    local mask="$3"

    result=$(docker run --rm --gpus all --ipc=host \
        -v /home/amit/dev/tmp/flash-attention-test:/workspace \
        -v /shared/.cache/huggingface:/root/.cache/huggingface \
        flash-test python -c "
import torch
import time
from transformers import AutoModelForCausalLM

BATCH_SIZE = ${BATCH_SIZE}
SEQ_LENGTH = ${SEQ_LENGTH}
NUM_RUNS = ${NUM_RUNS}

try:
    model = AutoModelForCausalLM.from_pretrained(
        '${model}',
        torch_dtype=torch.float16,
        device_map='auto',
        attn_implementation='${impl}'
    )
    hidden_size = model.config.hidden_size
    input_embeds = torch.randn(BATCH_SIZE, SEQ_LENGTH, hidden_size, device=model.device, dtype=model.dtype)
    position_ids = torch.arange(SEQ_LENGTH, device=model.device).unsqueeze(0).expand(BATCH_SIZE, -1)

    if '${mask}' == '4d':
        # 4D mask: half 0s, half 1s (first half of keys are masked out)
        attention_mask = torch.zeros(BATCH_SIZE, 1, SEQ_LENGTH, SEQ_LENGTH, device=model.device, dtype=model.dtype)
        attention_mask[:, :, :, SEQ_LENGTH//2:] = 1.0
    else:
        # 2D mask: alternating 0s and 1s
        attention_mask = torch.zeros(BATCH_SIZE, SEQ_LENGTH, device=model.device)
        attention_mask[:, ::2] = 1.0  # every other token is valid

    # Warmup
    with torch.no_grad():
        for _ in range(3):
            _ = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=attention_mask)

    # Timed
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(NUM_RUNS):
            _ = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=attention_mask)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) * 1000 / NUM_RUNS
    print(f'OK:{elapsed:.2f}')
except Exception as e:
    print(f'FAIL:{str(e)[:60]}')
" 2>&1 | grep -E "^(OK|FAIL):" | tail -1)

    echo "$impl,$model,$mask,$result"
}

# Header
echo ""
echo "impl,model,mask,result"

for impl in eager sdpa flash_attention_2 flash_attention_3 flex_attention sdpa_paged eager_paged; do
    for model in EleutherAI/pythia-14m EleutherAI/pythia-70m; do
        for mask in 2d 4d; do
            run_test "$impl" "$model" "$mask"
        done
    done
done

AmitMY, Dec 05 '25

I mean flexattn in this repo (flash_attn.cute): https://github.com/Dao-AILab/flash-attention/blob/main/tests/cute/test_score_mod.py

tridao, Dec 06 '25

In general, a 4D attention mask isn't the right abstraction for prefix-LM (a 4D mask is too general and you'll pay for it with a slowdown). @drisspg do we have an example of prefix-LM with flexattn in this repo?
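
For what that looks like at the torch level, a minimal prefix-LM sketch with PyTorch's FlexAttention API (assuming torch >= 2.5; the shapes and prefix_len here are illustrative, not from this thread):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 1, 8, 128, 64
prefix_len = 32  # illustrative

def prefix_lm_mask(b, h, q_idx, kv_idx):
    # keys inside the prefix are always visible; everything else is causal
    return (kv_idx < prefix_len) | (q_idx >= kv_idx)

block_mask = create_block_mask(prefix_lm_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)

Wrapping flex_attention in torch.compile is what gets the fused kernel path; uncompiled it runs a slower eager fallback.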

tridao, Dec 06 '25

I don't think there is any current example; I just generated one by calling flex_attention with BACKEND = "FLASH":

# kernel path: /tmp/torchinductor_dev/y4/cy4su5k46zgriqpkqosly2lalii733le3qyfgzilhvjzbncfvvyc.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [flex_attention]
# Source node to ATen node mapping:
#   flex_attention => flex_attention
# Graph fragment:
#   %arg0_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg0_1]
#   %arg1_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg1_1]
#   %arg2_1 : Tensor "bf16[4, 32, 8192, 128][33554432, 1048576, 128, 1]cuda:0" = PlaceHolder[target=arg2_1]
#   %buf1 : Tensor "f32[4, 32, 8192][262144, 8192, 1]cuda:0" = PlaceHolder[target=buf1]
#   %arg4_1 : Tensor "i32[4, 32, 32][1024, 32, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %arg3_1 : Tensor "i32[4, 32, 32, 64][65536, 2048, 64, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %arg5_1 : Tensor "i32[4, 32, 32][1024, 32, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %arg6_1 : Tensor "i32[4, 32, 32, 64][65536, 2048, 64, 1]cuda:0" = PlaceHolder[target=arg6_1]
#   %flex_attention : [num_users=1] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (8192, 8192, %arg4_1, %arg3_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 256, 128, %sdpa_mask0), 0.08838834764831843, {BACKEND: FLASH, PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {})
#   return %getitem
cutedsl_fused_flex_attention_a32649c7 = async_compile.cutedsl('cutedsl_fused_flex_attention_a32649c7', r'''
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import cuda.bindings.driver as cuda
from cutlass._mlir.dialects import math as mlir_math
import operator
from torch._inductor.codegen.cutedsl._cutedsl_utils import ssa_to_indexable, result_to_ssa

# Kernel function signature: cutedsl_fused_flex_attention_a32649c7
def cutedsl_fused_flex_attention_a32649c7_main(arg_Q, arg_K, arg_V, arg_LOGSUMEXP, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, stream):
    Q = arg_Q
    K = arg_K
    V = arg_V
    LOGSUMEXP = arg_LOGSUMEXP
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX


    from flash_attn.cute.interface import _flash_attn_fwd
    from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch

    # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D)
    q_transposed = Q.transpose(1, 2)
    k_transposed = K.transpose(1, 2)
    v_transposed = V.transpose(1, 2)

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):

        tmp0 = tSrS_ssa
        tSrS_ssa = tmp0

        return tSrS_ssa
    score_mod.__cute_hash__ = "cutedsl_fused_flex_attention_a32649c7_score"

    # (B,M,H,D) -> (B,H,M,D)
    output = out_ptr0
    output_transposed = output.transpose(1, 2)


    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors):

        tmp1 = kv_idx
        tmp2 = operator.lt(tmp1, cute.full_like(tmp1, 2048))
        tmp3 = q_idx
        tmp4 = operator.ge(tmp3, tmp1)
        mask_mod_output = (False | tmp2) | tmp4

        return mask_mod_output
    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_a32649c7_mask"
    block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX)


    # Collect any additional tensor buffers that were added during modifications
    buffers = None
    # Out and LSE filled inplace
    _flash_attn_fwd(
        q_transposed,
        k_transposed,
        v_transposed,
        softmax_scale=0.08838834764831843,
        return_lse=True,
        score_mod=score_mod,
        mask_mod=mask_mod,
        out=output_transposed,
        lse=LOGSUMEXP,
        block_sparse_tensors=block_sparse_tensors,
        aux_tensors=buffers
    )
'''

This is with a static prefix size of 2048.

So, in non-autogenerated code:

import operator
import cutlass.cute as cute

prefix_len = 2048
...
    @cute.jit
    def mask_mod_cute(b_idx, h_idx, q_idx, kv_idx, aux_tensors):
        # allow: key is in the prefix, or key is at/before the query position (causal)
        return (operator.lt(kv_idx, cute.full_like(kv_idx, prefix_len)) |
                operator.ge(q_idx, kv_idx))

drisspg, Dec 06 '25