[NVIDIA] Support custom dtype convert in jax.nn.dot_product_attention
Addressing the issue brought up in #24047.
This PR does the following:
We use a custom dot_general in the QK einsum op of attention to match the dtypes used in the Flash Attention implementation. For bf16 inputs, for example, the fprop is:

bf16 -> dot -> fp32

and the bprop is:

(1) Without this change: fp32 -> dot -> fp32 -> cvt -> bf16
(2) With this change: fp32 -> cvt -> bf16 -> dot -> bf16
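To make the intended dtype flow concrete, here is a minimal sketch, not the PR's actual code: the function name `qk_dot`, the shapes, and the use of `jax.custom_vjp` are illustrative assumptions (the PR itself wires a custom dot_general into the existing einsum):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def qk_dot(q, k):
    # q: [seq_q, head_dim], k: [seq_k, head_dim], both bfloat16 (hypothetical shapes).
    # Forward: bf16 -> dot -> fp32 (accumulate the logits in fp32).
    return jnp.einsum("qd,kd->qk", q, k, preferred_element_type=jnp.float32)

def _qk_dot_fwd(q, k):
    return qk_dot(q, k), (q, k)

def _qk_dot_bwd(res, g):
    q, k = res
    # Backward: fp32 cotangent -> cvt -> bf16 -> dot -> bf16 gradients,
    # matching the dtype flow of the Flash Attention implementation.
    g = g.astype(q.dtype)
    dq = jnp.einsum("qk,kd->qd", g, k)  # bf16 x bf16 -> bf16
    dk = jnp.einsum("qk,qd->kd", g, q)  # bf16 x bf16 -> bf16
    return dq, dk

qk_dot.defvjp(_qk_dot_fwd, _qk_dot_bwd)
```

With this, `jax.grad` of a loss built on `qk_dot` yields bf16 gradients for bf16 inputs without the extra fp32 -> bf16 cast at the end of the backward dot.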
In addition, we adjust the atol/rtol a bit.
cc. @sbodenstein @superbobry
I think that this is best fixed using the new precision API, rather than custom JVP. What is the motivation for this approach, other than being able to land it a little bit faster?
Right, the main motivation is to get it implemented faster. I'm okay with using the new precision API. Do you know when it will be available, especially with the PJRT plugin mentioned here? If it’s coming soon, we can close this PR. If not, do you think we should keep this one open and migrate to the new method later? Also, pinging @dfm for comments.
An update! With the release of JAX v0.4.35, passing a dot algorithm as an argument to lax.dot_general now works (https://github.com/jax-ml/jax/pull/24480), so perhaps we can try to fix https://github.com/jax-ml/jax/issues/24047 using that?
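For reference, a minimal sketch of what that could look like, assuming JAX >= 0.4.35; the function name, shapes, and the particular preset are illustrative assumptions, not the actual fix:

```python
import jax.numpy as jnp
from jax import lax

def qk_dot(q, k):
    # q: [seq_q, head_dim], k: [seq_k, head_dim], both bfloat16 (hypothetical shapes).
    # Request a bf16 x bf16 -> f32 dot algorithm explicitly instead of
    # inserting manual casts around the matmul.
    return lax.dot_general(
        q, k,
        dimension_numbers=(((1,), (1,)), ((), ())),   # contract over head_dim
        precision=lax.DotAlgorithmPreset.BF16_BF16_F32,
        preferred_element_type=jnp.float32,           # keep the logits in fp32
    )
```

Here BF16_BF16_F32 asks for bf16 operands with fp32 accumulation, which is the combination the comment above is after.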
Thanks! We'll close this one and open a new PR later to implement the newly proposed mechanism mentioned in the comment above.
@kaixih I've fixed this in https://github.com/jax-ml/jax/commit/d5e5b42de85d65fcdc7ab69f38062e00464cf6d9. Let me know if you think it's correct.