Add Scaled Dot Product Attention for FP8
Scaled Dot Product Attention for FP8. @kaixih @zhangqiaorjc
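For context on what the PR implements, here is a minimal NumPy reference of the scaled dot product attention math, softmax(QK^T / sqrt(d)) V. This is only an illustration of the operation being fused; it is not the PR's cuDNN FP8 kernel, and the layout convention here ([batch, heads, seq, head_dim]) is an assumption.

```python
import numpy as np

def sdpa_reference(q, k, v):
    """Reference SDPA: softmax(Q K^T / sqrt(head_dim)) V.

    q, k, v: arrays of shape [batch, heads, seq, head_dim] (assumed layout).
    """
    d = q.shape[-1]
    # Attention scores, scaled by 1/sqrt(head_dim).
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)
    # Numerically stable softmax over the key dimension.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8, 16))
out = sdpa_reference(q, q, q)
print(out.shape)  # (2, 4, 8, 16)
```

The fused cuDNN path the PR targets computes the same quantity in one kernel, with the FP8-specific scaling handled on-device.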
@wenscarl, is this ready for review? I can help with that. If possible, could you resolve the conflicts and run pre-commit for format consistency?
It's ready for review and the conflicts are fixed. Thanks!
Compact all changes into fused_attention_stablehlo.py
It looks like this PR forked jax/jax/_src/cudnn/fused_attention_stablehlo.py and then modified it for FP8. Does that sound right? If yes, could we instead add support for FP8 to jax/jax/_src/cudnn/fused_attention_stablehlo.py directly?
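One reason an FP8 path needs dedicated handling when merged into the shared file: FP8 tensors are typically paired with per-tensor scale/descale factors. A toy NumPy round trip illustrating that convention (the function name and the simulation are assumptions; E4M3's largest finite value really is 448, but a real kernel would also round to FP8 mantissa precision, which this sketch skips):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value representable in float8 E4M3

def quantize_dequantize_fp8(x, scale):
    # Divide by the scale to move values into the FP8 dynamic range,
    # clip to the representable maximum (saturation), then multiply
    # back -- simulating the range loss of an FP8 round trip.
    # (Mantissa rounding, which a real FP8 cast also performs, is omitted.)
    scaled = np.clip(x / scale, -E4M3_MAX, E4M3_MAX)
    return scaled * scale

x = np.array([0.5, 1000.0])
print(quantize_dequantize_fp8(x, 1.0))  # large values saturate at 448
```

In the fused attention setting, each of Q, K, V and the output carries such a scale, which is why folding FP8 support into the existing fused_attention_stablehlo.py entry point touches its argument handling rather than just its math.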
Merged into fused_attention_stablehlo.py. Gentle ping @superbobry :)
Gently ping @mattjj and @superbobry.
Hey @wenscarl, sorry for the silence. I'm a bit swamped this week, but I will try to look through the PR over the weekend or first thing next week.
Gently ping @superbobry. Thanks!
Thanks, can you squash the commits please?
Done.