Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward)

Open wenscarl opened this issue 1 year ago • 3 comments

This is the second part of #15092. NOTE: this feature relies on cudnn-frontend v1.6.1, which is not in XLA yet.

wenscarl avatar Jul 25 '24 19:07 wenscarl

Given that cuDNN's FlashAttention is meant to remain behind a flag (as discussed previously), I wonder whether it still makes sense to integrate this within XLA.

I believe that we should already support calls to scaled dot product attention through JAX directly, is that correct?

bchetioui avatar Sep 27 '24 14:09 bchetioui

> Given that cuDNN's FlashAttention is meant to remain behind a flag (as discussed previously), I wonder whether it still makes sense to integrate this within XLA.
>
> I believe that we should already support calls to scaled dot product attention through JAX directly, is that correct?

Are you referring to https://github.com/jax-ml/jax/pull/22670? If so, JAX's SDPA API still calls cuDNN SDPA through XLA behind the scenes. Plus, the forward pass PR is already merged.

wenscarl avatar Sep 30 '24 14:09 wenscarl

> Are you referring to https://github.com/jax-ml/jax/pull/22670? If so, JAX's SDPA API still calls cuDNN SDPA through XLA behind the scenes. Plus, the forward pass PR is already merged.

You're right, that seems reasonable, thanks for the clarification.

bchetioui avatar Sep 30 '24 15:09 bchetioui

Gentle ping @wenscarl :)

bchetioui avatar Oct 07 '24 13:10 bchetioui