Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward)
This is the second part of #15092. NOTE: this feature relies on cudnn-frontend v1.6.1, which is not in XLA yet.
Given that cuDNN's FlashAttention is meant to remain behind a flag (as discussed previously), I wonder whether it still makes sense to integrate this within XLA.
I believe we already support calls to scaled dot product attention directly through JAX; is that correct?
Do you refer to https://github.com/jax-ml/jax/pull/22670? If so, JAX's SDPA API still calls cuDNN SDPA from XLA behind the scenes. Plus, the forward-pass PR is already merged.
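For context, a minimal sketch of the routing described above: JAX's public SDPA entry point, `jax.nn.dot_product_attention`, accepts an `implementation` argument, and passing `"cudnn"` lowers to XLA's cuDNN fused-attention custom call. The shapes and dtype below are illustrative, not taken from the PR.

```python
import jax
import jax.numpy as jnp

B, T, N, H = 4, 1024, 8, 64  # batch, sequence length, heads, head dim

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# implementation="cudnn" requests the cuDNN fused attention kernel via XLA;
# the FP8 variant of that kernel is what this PR adds the backward pass for.
out = jax.nn.dot_product_attention(q, k, v, implementation="cudnn")
```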
You're right, that seems reasonable, thanks for the clarification.
Gentle ping @wenscarl :)