JVP/forward mode AD support for flash attention?

Open · crowsonkb opened this issue 3 years ago · 4 comments

❓ Questions and Help

I am a researcher working with Stable Diffusion and I need to be able to compute the product of a vector with the model's Jacobian. Some attention implementations that xformers dispatches to seem to work with this, but when it dispatches to flash attention, JVPs fail with the message:

│ /home/kat/venv/lib/python3.10/site-packages/xformers/ops/memory_efficient_attention.py:665 in    │
│ _flash_attn_forward                                                                              │
│                                                                                                  │
│   662 │   │   causal,                                                                            │
│   663 │   │   return_softmax,                                                                    │
│   664 │   ):                                                                                     │
│ ❱ 665 │   │   out, softmax_lse, *rest = _C_flashattention.fwd(                                   │
│   666 │   │   │   q,                                                                             │
│   667 │   │   │   k,                                                                             │
│   668 │   │   │   v,                                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage                        

Is it possible to enable forward mode AD for flash attention? Failing that, can I control what attention implementation it dispatches to, such that I can choose one that supports forward mode AD?

Thank you, Katherine Crowson

crowsonkb, Dec 05 '22 02:12
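
On the second question (choosing an implementation that supports forward-mode AD): the thread does not confirm which xformers backends handle JVPs, so the minimal sketch below sidesteps xformers entirely. It assumes PyTorch 2.x and computes a JVP through a plain, unfused softmax attention with torch.func.jvp, then checks it against a central finite difference. reference_attention is a hypothetical helper, and because it materializes the full N × N score matrix it gives up the memory savings of the fused kernels.

```python
import math

import torch
from torch.func import jvp


def reference_attention(q, k, v):
    # Hypothetical helper: plain softmax(Q K^T / sqrt(d)) V built from standard
    # PyTorch ops, so forward-mode AD can trace through every operation.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8  # batch, heads, sequence length, head dim (arbitrary)
q, k, v = (torch.randn(B, H, N, D, dtype=torch.float64) for _ in range(3))
tq, tk, tv = (torch.randn(B, H, N, D, dtype=torch.float64) for _ in range(3))

# jvp returns (f(primals), J_f(primals) applied to tangents) in one forward pass.
out, out_tangent = jvp(reference_attention, (q, k, v), (tq, tk, tv))

# Sanity check: central finite difference along the same tangent direction.
eps = 1e-6
fd = (
    reference_attention(q + eps * tq, k + eps * tk, v + eps * tv)
    - reference_attention(q - eps * tq, k - eps * tk, v - eps * tv)
) / (2 * eps)
print(torch.allclose(out_tangent, fd, atol=1e-6))
```

Swapping such a reference path in for the attention layers of a model trades memory for forward-mode AD support; it is a workaround, not a fix for the fused kernels.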

Hello,

Do you have a minimal example that reproduces this issue? @fmassa, maybe you have an idea...

danthe3rd, Dec 05 '22 11:12

Hi,

IIUC, computing the JVP requires a custom autograd.Function.jvp method, which we haven't implemented for any of our operators. Doing so is not in our plans for now, as I believe it would require substantial work.

@crowsonkb, are you sure there is a mode that works with AD and uses the optimized kernels? A quick repro, as @danthe3rd suggested, would help us better understand the error.

fmassa, Dec 05 '22 16:12
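
To illustrate what a custom autograd.Function.jvp method would compute for softmax attention, here is a minimal sketch assuming PyTorch 2.x. AttentionWithJVP is a hypothetical, unfused reference written in plain PyTorch ops, not an xformers operator; a real contribution would need the same rule inside the fused kernels. It is driven with torch.autograd.forward_ad dual tensors (using a custom Function under torch.func transforms requires the separate setup_context style, which this sketch avoids) and compared against torch.func.jvp applied to plain attention.

```python
import math

import torch
import torch.autograd.forward_ad as fwAD
from torch.func import jvp as func_jvp


class AttentionWithJVP(torch.autograd.Function):
    """Hypothetical unfused attention with a hand-written forward-mode rule.

    No backward is defined, so this only supports forward-mode AD.
    """

    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        p = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
        out = p @ v
        # Keep what the JVP rule needs (attribute storage, as in the PyTorch
        # forward-mode AD tutorial).
        ctx.q, ctx.k, ctx.v, ctx.p, ctx.out, ctx.scale = q, k, v, p, out, scale
        return out

    @staticmethod
    def jvp(ctx, tq, tk, tv):
        # This sketch assumes all three inputs carry tangents.
        # Tangent of the scaled scores S = scale * Q K^T.
        ts = (tq @ ctx.k.transpose(-2, -1) + ctx.q @ tk.transpose(-2, -1)) * ctx.scale
        # Softmax JVP folded into the output:
        # dO = (P * dS) V - rowsum(P * dS) * O + P dV
        pts = ctx.p * ts
        row = pts.sum(dim=-1, keepdim=True)
        return pts @ ctx.v - row * ctx.out + ctx.p @ tv


def reference_attention(q, k, v):
    scale = 1.0 / math.sqrt(q.shape[-1])
    return torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v


torch.manual_seed(0)
shape = (2, 4, 16, 8)  # batch, heads, sequence length, head dim (arbitrary)
q, k, v, tq, tk, tv = (torch.randn(shape, dtype=torch.float64) for _ in range(6))

with fwAD.dual_level():
    duals = [fwAD.make_dual(x, t) for x, t in ((q, tq), (k, tk), (v, tv))]
    custom_tangent = fwAD.unpack_dual(AttentionWithJVP.apply(*duals)).tangent.clone()

# Reference: let torch.func derive the JVP of the plain implementation.
_, reference_tangent = func_jvp(reference_attention, (q, k, v), (tq, tk, tv))
print(torch.allclose(custom_tangent, reference_tangent))
```

Before relying on a rule like this, torch.autograd.gradcheck with check_forward_ad=True (and the backward checks disabled) is the usual way to validate it.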

Hi,

Some recent works propose new flow-based frameworks, such as sCM (https://arxiv.org/pdf/2410.11081) and MeanFlow (https://arxiv.org/pdf/2505.13447), which rely heavily on JVP computation.

However, these methods have not yet been validated on large-scale models and datasets, because neither xformers nor Flash Attention can compute JVPs.

Does xformers have any plans to support JVP?

Looking forward to your reply.

Thanks

leon532, May 21 '25 05:05
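
For context on why such methods need JVPs: as we understand the MeanFlow paper, training uses the total time derivative of the network output along the trajectory, d/dt u(z_t, r, t) = (∂u/∂z) v + ∂u/∂t, which a single forward-mode pass yields with tangents (v, 0, 1) for (z, r, t). The sketch below assumes PyTorch 2.x; the toy u is a hypothetical stand-in for the network, and its closed-form derivative is used as a check. With a real transformer the same pass would run through every attention layer, which is where missing JVP support in the attention kernels becomes a blocker.

```python
import torch
from torch.func import jvp


def u(z, r, t):
    # Hypothetical stand-in for a network u_theta(z, r, t).
    return torch.sin(z) * (t - r) + z * t**2


torch.manual_seed(0)
z = torch.randn(8, dtype=torch.float64)
v = torch.randn(8, dtype=torch.float64)  # velocity dz/dt along the trajectory
r = torch.tensor(0.2, dtype=torch.float64)
t = torch.tensor(0.7, dtype=torch.float64)

# One forward-mode pass with tangents (v, 0, 1) gives (du/dz) v + du/dt.
u_val, du_dt = jvp(u, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))

# Closed form for this toy u: du/dz = cos(z)(t - r) + t^2, du/dt = sin(z) + 2 z t.
expected = (torch.cos(z) * (t - r) + t**2) * v + torch.sin(z) + 2 * z * t
print(torch.allclose(du_dt, expected))
```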

We don't currently have plans to support it. If you'd like to implement it yourself, and if it's not too complex, we can consider accepting the contribution.

lw, May 21 '25 08:05
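
For anyone weighing that contribution, the rule a jvp method would have to compute for softmax attention follows from a standard derivation (not taken from the xformers code). With scale τ, scores S = τQKᵀ, row-wise softmax P, output O = PV, and input tangents Q̇, K̇, V̇:

```latex
\begin{aligned}
\dot S &= \tau\left(\dot Q K^\top + Q \dot K^\top\right) \\
\dot P &= P \circ \left(\dot S - r\,\mathbf{1}^\top\right), \qquad r = \left(P \circ \dot S\right)\mathbf{1} \\
\dot O &= \dot P V + P \dot V = \left(P \circ \dot S\right) V - \operatorname{diag}(r)\, O + P \dot V
\end{aligned}
```

Each term is a sum over keys of quantities available tile by tile (P, Ṡ, V, V̇), plus a per-row correction that reuses O. That suggests, without establishing it, that the JVP could be accumulated in the same tiled loop as the forward pass; how much kernel work that takes in practice is the open question.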