
[NVIDIA] Support custom dtype convert in jax.nn.dot_product_attention

Open · kaixih opened this issue 1 year ago · 1 comment

This addresses the issue raised in #24047.

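This PR changes the dtypes used in the backward pass of the QK matmul, as described in this code comment from the change: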

# We use this custom dot_general in the QK einsum op of attention to match the
# dtypes used in the Flash Attention implementation. For bf16 inputs, for
# example, the fprop is:
#   bf16 -> dot -> fp32
# and the bprop is:
# (1) Without this change:
#   fp32 -> dot -> fp32 -> cvt -> bf16
# (2) With this change:
#   fp32 -> cvt -> bf16 -> dot -> bf16
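A minimal sketch of the dtype flow described in the comment, assuming a plain 2-D layout where queries are `[T, D]` and keys are `[S, D]`. It uses `jax.custom_vjp`, which is a swapped-in technique: the PR itself wraps a custom dot_general with a custom JVP, and this is not its code. The helper name `qk_dot` is hypothetical. The forward pass multiplies bf16 inputs and returns fp32 logits; the backward pass casts the fp32 cotangent to bf16 *before* the two matmuls, which is variant (2) above.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not the PR's code): attention logits q @ k^T.
# Assumed shapes: q is [T, D], k is [S, D]; the result is [T, S].
@jax.custom_vjp
def qk_dot(q, k):
    # fprop: bf16 -> dot -> fp32 (the result is returned in fp32).
    return jnp.einsum("td,sd->ts", q, k, preferred_element_type=jnp.float32)


def _qk_dot_fwd(q, k):
    # Save the low-precision inputs as residuals for the backward pass.
    return qk_dot(q, k), (q, k)


def _qk_dot_bwd(res, g):
    q, k = res
    # bprop variant (2): fp32 -> cvt -> bf16 -> dot -> bf16.
    # Cast the fp32 cotangent down first, so both matmuls take bf16 operands.
    g_lo = g.astype(q.dtype)
    dq = jnp.einsum("ts,sd->td", g_lo, k).astype(q.dtype)  # dL/dq = g @ k
    dk = jnp.einsum("ts,td->sd", g_lo, q).astype(k.dtype)  # dL/dk = g^T @ q
    return dq, dk


qk_dot.defvjp(_qk_dot_fwd, _qk_dot_bwd)

# Usage: bf16 inputs give fp32 logits, while the gradients stay in bf16.
q = jnp.ones((4, 8), dtype=jnp.bfloat16)
k = jnp.ones((6, 8), dtype=jnp.bfloat16)
logits = qk_dot(q, k)
dq, dk = jax.grad(lambda q, k: qk_dot(q, k).sum(), argnums=(0, 1))(q, k)
print(logits.dtype, dq.dtype, dk.dtype)  # float32 bfloat16 bfloat16
```

Relative to variant (1), the only difference in this sketch is where the cast happens: converting the cotangent first means the two backward matmuls run on bf16 operands instead of promoting the bf16 inputs to fp32, at the cost of rounding the cotangent to bf16 before the dot.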

In addition, this PR slightly adjusts the test tolerances (atol/rtol).

cc @sbodenstein @superbobry

kaixih · Oct 16 '24 23:10

I think this is best fixed using the new precision API rather than a custom JVP. What is the motivation for this approach, other than being able to land it a little faster?

sbodenstein · Oct 21 '24 10:10

Right, the main motivation is to get it implemented faster. I'm okay with using the new precision API. Do you know when it will be available, especially with the PJRT plugin mentioned here? If it’s coming soon, we can close this PR. If not, do you think we should keep this one open and migrate to the new method later? Also, pinging @dfm for comments.

kaixih · Oct 21 '24 16:10

An update! With the release of JAX v0.4.35, passing a dot algorithm as the `precision` argument of `lax.dot_general` now works (https://github.com/jax-ml/jax/pull/24480). So perhaps we can try to fix https://github.com/jax-ml/jax/issues/24047 using that?

dfm · Oct 23 '24 13:10

Thanks! We'll close this PR and open a new one later that implements the dot-algorithm mechanism proposed in the comment above.

kaixih · Feb 19 '25 00:02

@kaixih I've fixed this in https://github.com/jax-ml/jax/commit/d5e5b42de85d65fcdc7ab69f38062e00464cf6d9. Let me know if you think it's correct.

sbodenstein · Feb 19 '25 12:02