perf(Pallas): Controlling accumulation dtype for `dot_general` and Pallas `dot`

Open jon-chuang opened this issue 2 years ago • 3 comments

Triton's dot defaults its out_dtype to float32 (see here).

Currently, the Pallas lowering does not pass any out_dtype through to Triton.

However, perhaps it needs to be lowered based on lax.dot_general(preferred_element_type=).
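For context, here is a minimal sketch of how `preferred_element_type` already controls the result dtype at the `lax.dot_general` level, outside of Pallas. The shapes and values are made up for illustration, and this says nothing about what the Triton lowering does with the hint.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative float16 operands (shapes chosen arbitrarily).
a = jnp.ones((4, 8), dtype=jnp.float16)
b = jnp.ones((8, 4), dtype=jnp.float16)

# Plain matmul: contract dim 1 of `a` with dim 0 of `b`, no batch dims.
dims = (((1,), (0,)), ((), ()))

# Without a hint, the result dtype follows the input dtype.
out_default = lax.dot_general(a, b, dimension_numbers=dims)

# With the hint, the result is requested in float32.
out_f32 = lax.dot_general(
    a, b, dimension_numbers=dims, preferred_element_type=jnp.float32
)

print(out_default.dtype, out_f32.dtype)
```

The suggestion in this issue is that the same hint could decide which `out_dtype` the Pallas lowering hands to Triton.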

As a result, the accumulation dtype cannot be tuned for performance (see e.g. PyTorch benchmarks).

Note that for whatever reason (precision or otherwise), flash_attn accumulates elements in float32 (see here)

jon-chuang avatar Sep 14 '23 18:09 jon-chuang

Near duplicate: https://github.com/google/jax/issues/15683

jon-chuang avatar Sep 14 '23 20:09 jon-chuang

I think the temporary solution here is to add an out_type to pl.dot and pass that into lax.dot(preferred_element_type=...).
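A rough sketch of that shape of API, written against plain `lax.dot` rather than Pallas. `dot_with_out_dtype` is a hypothetical helper name used only for illustration; it is not the actual `pl.dot` signature, and the float32 default is an assumption chosen to mirror Triton's default.

```python
import jax.numpy as jnp
from jax import lax


def dot_with_out_dtype(a, b, out_dtype=jnp.float32):
    """Hypothetical wrapper: forward `out_dtype` as preferred_element_type."""
    # Assumption: float32 default, mirroring Triton's dot default.
    return lax.dot(a, b, preferred_element_type=out_dtype)


x = jnp.ones((16, 32), dtype=jnp.bfloat16)
y = jnp.ones((32, 16), dtype=jnp.bfloat16)

acc_f32 = dot_with_out_dtype(x, y)                           # float32 result
acc_bf16 = dot_with_out_dtype(x, y, out_dtype=jnp.bfloat16)  # bfloat16 result

print(acc_f32.dtype, acc_bf16.dtype)
```

Whether the Triton lowering then honours the requested dtype is a separate question that this sketch does not exercise.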

sharadmv avatar Sep 14 '23 21:09 sharadmv

Yes, I did that locally. It worked fairly well, except that I needed to force the output dtype conversion with dot(..., out_dtype=my_dtype).astype(my_dtype).

It would be nice to remove the extra verbosity of .astype(my_dtype).

I think the result of dot is coerced to jnp.float32 somewhere (let me look for it). Perhaps we don't need to be that strict if preferred_element_type is set.

jon-chuang avatar Sep 14 '23 21:09 jon-chuang