[cuDNN SDPA] combine mask with bias to support public SDPA API

Open • Cjkkkk opened this issue 1 year ago • 5 comments

  • cuDNN SDPA no longer supports a separate mask input, so we combine the bias and the mask manually to align with the public SDPA API design.
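A minimal sketch of the general idea in plain JAX, assuming a boolean mask where `True` means "may attend". A mask can be folded into an additive bias by replacing masked-out positions with a very large negative value, which softmax turns into (near-)zero weight. The helper names `combine_mask_with_bias`, `attention_with_bias` and `attention_with_mask` are hypothetical; this is not the code from this PR and it does not call cuDNN.

```python
import jax
import jax.numpy as jnp


def combine_mask_with_bias(bias, mask, dtype=jnp.float32):
    """Fold a boolean mask (True = attend) into an additive attention bias.

    Hypothetical helper for illustration only; not the code from this PR.
    """
    # Most negative finite value for the dtype. Using it instead of -inf
    # avoids NaNs (-inf minus -inf) if a whole row is masked out.
    large_negative = jnp.finfo(dtype).min
    if bias is None:
        bias = jnp.zeros(mask.shape, dtype=dtype)
    # Keep the bias where attention is allowed; elsewhere use large_negative.
    return jnp.where(mask, bias.astype(dtype), large_negative)


def attention_with_bias(q, k, v, bias):
    """Single-head reference attention that accepts only an additive bias."""
    logits = (q @ k.T) / jnp.sqrt(q.shape[-1]) + bias
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v


def attention_with_mask(q, k, v, mask):
    """Single-head reference attention that applies a boolean mask directly."""
    logits = (q @ k.T) / jnp.sqrt(q.shape[-1])
    logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v


# Usage: a causal mask folded into a bias should match direct masking.
T, D = 4, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (T, D))
k = jax.random.normal(kk, (T, D))
v = jax.random.normal(kv, (T, D))
causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))

out_combined = attention_with_bias(q, k, v, combine_mask_with_bias(None, causal_mask))
out_masked = attention_with_mask(q, k, v, causal_mask)
```

Using the dtype's most negative finite value rather than `-inf` is a common design choice: both give zero weight after softmax, but the finite value keeps a fully masked row from producing NaNs.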

Cjkkkk • Jun 25 '24

@superbobry @kaixih, could you help review this? @kaixih, could you also make sure it is aligned with the public API?

Cjkkkk • Jun 25 '24

I think the minimum cuDNN version JAX supports is 9.0. Does it make sense to remove mask entirely from the JAX API, given that?

CC @hawkinsp

superbobry • Jun 27 '24

Removing the mask completely should be fine, as long as it stays aligned with the public API.

Cjkkkk • Jun 28 '24

Please fix the type checker errors.

superbobry • Jun 30 '24

Thanks, you just need to squash the commits and the PR is good to go.

superbobry • Jul 01 '24

Hi @superbobry, is the PR good to go?

Cjkkkk • Jul 05 '24