
How to obtain differentiable softmax_lse

Open albert-cwkuo opened this issue 1 year ago • 7 comments

Hi @tridao ,

I'd like to use the value of softmax_lse in my model and back-propagate gradients through it. However, I saw another discussion saying that it is not taken into account during the backward pass.

Does a newer version support back-prop for softmax_lse? If not, how easy/difficult would it be to modify the CUDA code to support that? Thanks in advance for any advice!

albert-cwkuo avatar Aug 08 '24 17:08 albert-cwkuo

Backprop on softmax_lse is not supported. Feel free to work on it if you need it. You just have to work out the gradient and then implement it. I suspect it's not too bad.

tridao avatar Aug 08 '24 17:08 tridao

Thanks for your reply :). I am new to CUDA, fused kernels, etc. Would you mind pointing me to which part I should dig into, and giving an abstract idea of how this could be implemented? Thanks a lot.

albert-cwkuo avatar Aug 08 '24 18:08 albert-cwkuo

It depends on what the gradient looks like. What's the gradient for softmax_lse?

tridao avatar Aug 08 '24 18:08 tridao

Sorry I don't quite get what you mean :sweat_smile:.

What I expect is that, given the softmax_lse computed from q and k (with attention scores of shape, let's say, N x nhead x Sq x Sk), I can compute a gradient for each element of softmax_lse, and that gradient can then be propagated further back to q and k. Only the first-order gradient is needed in my case.

albert-cwkuo avatar Aug 08 '24 18:08 albert-cwkuo

Well, you need to work out how to compute the gradient mathematically (e.g. see Appendix B.2 of the FlashAttention paper) before implementing it.

tridao avatar Aug 08 '24 18:08 tridao

Thanks a lot @tridao for the reference! It took me some time to derive the gradient of LSE w.r.t. q and k analytically:

For dq: $\frac{\partial}{\partial q}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})k$

For dk: $\frac{\partial}{\partial k}LSE(q k^T, \ \text{axis=1}) = \text{softmax}(q k^T,\ \text{axis=1})^T q$
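
To sanity-check these formulas, here is a minimal PyTorch sketch (illustrative only, not the fused kernel): a custom autograd function that computes LSE(q k^T, axis=1) on unscaled scores and implements exactly the backward above, verified with gradcheck. Note that the softmax_lse returned by flash-attention is computed on the scaled scores (softmax_scale * q k^T), which the sketch omits for clarity.

```python
import torch

class LSEQK(torch.autograd.Function):
    """LSE(q k^T, axis=1) with the hand-derived backward above (unscaled scores)."""

    @staticmethod
    def forward(ctx, q, k):
        scores = q @ k.T                              # (Sq, Sk)
        lse = torch.logsumexp(scores, dim=-1)         # (Sq,)
        ctx.save_for_backward(q, k, lse)
        return lse

    @staticmethod
    def backward(ctx, grad_lse):
        q, k, lse = ctx.saved_tensors
        p = torch.exp(q @ k.T - lse[:, None])         # softmax(q k^T, axis=1)
        dq = (grad_lse[:, None] * p) @ k              # d LSE_i / d q_i = sum_j p_ij k_j
        dk = (grad_lse[:, None] * p).T @ q            # d LSE_i / d k_j = p_ij q_i
        return dq, dk

q = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
k = torch.randn(6, 8, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(LSEQK.apply, (q, k))         # should return True
```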

Any advice on how and where I should plug this into the CUDA code so that the returned softmax_lse supports backprop?

albert-cwkuo avatar Aug 09 '24 07:08 albert-cwkuo

The bwd pass code is here: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h. You can follow acc_dk and acc_dq to see how they're being computed right now.

tridao avatar Aug 09 '24 16:08 tridao

@tridao

  1. Does "Backprop on softmax_lse is not supported" mean that backprop wouldn't work correctly even if I use softmax_lse only and immediately to merge attention with another kv set's softmax_lse value? I'm not sure whether the OP wants to do operations on softmax_lse or only use it for merging attention. My use case is just merging attention, so I'm wondering if the existing repo can already support that?
  2. So would that mean we can use softmax_lse as an inference-time trick to combine attention for the same set of q but different sets of kv, but we cannot train with it?

@albert-cwkuo Did you end up implementing it? I'd love to take a look at your code if you have!

Ashwin-Ramesh2607 avatar Nov 21 '24 04:11 Ashwin-Ramesh2607

@Ashwin-Ramesh2607 To train with combined attention, you can simply pass combined_lse as lse and combined_out as out to the backward kernel and accumulate dq, because a subset of dk and dv depends only on the corresponding subset of k and v, plus the combined out, dout, and lse.
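
For reference, a minimal PyTorch sketch of the combine step itself (names and shapes here are illustrative: single head, out1/out2 are the partial outputs over two disjoint kv sets for the same q, lse1/lse2 their log-sum-exp values):

```python
import torch

def merge_attention(out1, lse1, out2, lse2):
    """Merge two partial attention results over disjoint kv sets for the same q.

    out1, out2: (Sq, d) partial outputs; lse1, lse2: (Sq,) their log-sum-exp values.
    """
    combined_lse = torch.logaddexp(lse1, lse2)          # log(exp(lse1) + exp(lse2))
    combined_out = (out1 * torch.exp(lse1 - combined_lse)[:, None]
                    + out2 * torch.exp(lse2 - combined_lse)[:, None])
    return combined_out, combined_lse
```

These combined_out and combined_lse are what get fed to each per-subset backward call, as in the ring-flash-attention implementation linked below.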


see: https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/ring_flash_attn.py

Alternatively, you can just use a detached lse but replace out with combined_out, since out only affects the preprocessing kernel and D.

See also: https://x.com/YouJiacheng/status/1930148252982108169

YouJiacheng avatar Jun 04 '25 06:06 YouJiacheng