Add Scaled Dot Product Attention for FP8

Open wenscarl opened this issue 1 year ago • 4 comments

Scaled Dot Product Attention for FP8. @kaixih @zhangqiaorjc

wenscarl avatar Jul 25 '24 19:07 wenscarl

@wenscarl is this ready to review? I can help with that. If possible, could you resolve the conflicts and run pre-commit for format consistency?

kaixih avatar Aug 22 '24 16:08 kaixih

It's ready for review, and the conflicts are fixed. Thanks!

wenscarl avatar Aug 25 '24 13:08 wenscarl

Compacted all changes into fused_attention_stablehlo.py.

wenscarl avatar Sep 17 '24 18:09 wenscarl

It looks like this PR forked jax/jax/_src/cudnn/fused_attention_stablehlo.py and then modified it for FP8. Does that sound right? If yes, could we instead add support for FP8 to jax/jax/_src/cudnn/fused_attention_stablehlo.py directly?

Merged into fused_attention_stablehlo.py. Gentle ping @superbobry :)

wenscarl avatar Oct 18 '24 18:10 wenscarl

Gentle ping @mattjj and @superbobry.

wenscarl avatar Oct 30 '24 19:10 wenscarl

Hey @wenscarl, sorry for the silence. I'm a bit swamped this week, but I will try to look through the PR over the weekend or first thing next week.

superbobry avatar Nov 06 '24 21:11 superbobry

Hey @wenscarl, sorry for the silence. I'm a bit swamped this week, but I will try to look through the PR over the weekend or first thing next week.

Gentle ping @superbobry. Thanks!

wenscarl avatar Dec 03 '24 04:12 wenscarl

Thanks, can you squash the commits please?

superbobry avatar Dec 13 '24 14:12 superbobry

Thanks, can you squash the commits please?

Done.

wenscarl avatar Dec 13 '24 18:12 wenscarl