[cuDNN SDPA] combine mask with bias to support public SDPA API
- cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design.
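For reviewers, here is a minimal sketch of what folding a mask into the bias means, written in plain JAX. It assumes a boolean mask where `True` means "attend", a `[batch, seq, heads, head_dim]` (BTNH) layout, and a bias of shape `[batch, heads, q_len, kv_len]`. `combine_bias_and_mask` and `reference_attention` are hypothetical names for illustration only, not the code in this PR or any cuDNN API.

```python
import jax
import jax.numpy as jnp


def combine_bias_and_mask(bias, mask, dtype=jnp.float32):
    """Fold a boolean mask (True = attend) into an additive bias.

    Hypothetical helper for illustration, not the PR's implementation.
    """
    # A large negative (not -inf) value: masked logits vanish after softmax,
    # and there is headroom so adding a bias does not overflow to -inf.
    large_neg = jnp.asarray(-0.7 * jnp.finfo(dtype).max, dtype=dtype)
    mask_bias = jnp.where(mask, jnp.zeros((), dtype), large_neg)
    if bias is None:
        return mask_bias
    # Broadcasting lets a [1, 1, T, S] mask combine with a [B, N, T, S] bias.
    return bias.astype(dtype) + mask_bias


def reference_attention(q, k, v, bias=None):
    """Plain-JAX SDPA that only takes an additive bias (assumed BTNH layout)."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("btnh,bsnh->bnts", q, k) * scale
    if bias is not None:
        logits = logits + bias
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bnts,bsnh->btnh", probs, v)


# Example: a causal mask combined with a random bias.
B, T, N, H = 2, 4, 2, 8
kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))
bias = jax.random.normal(kb, (B, N, T, T))
mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None]

out = reference_attention(q, k, v, combine_bias_and_mask(bias, mask))
```

Because the masked positions only ever enter through the bias, a kernel that accepts just a bias (as cuDNN SDPA now does) can still serve callers of a mask-taking public API. Note that a row with every position masked ends up with uniform weights rather than zeros under this scheme.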
@superbobry @kaixih could you help review this? @kaixih, could you also make sure it is aligned with the public API?
I think the minimum cuDNN version JAX supports is 9.0. Does it make sense to remove mask entirely from the JAX API, given that?
CC @hawkinsp
Removing the mask completely should be fine as long as it is aligned with the public API.
Please fix the type checker errors.
Thanks, you just need to squash the commits and the PR is good to go.
Hi @superbobry, is the PR good to go?