perf(Pallas): Controlling accumulation dtype for `dot_general` and Pallas `dot`
Triton's dot defaults out_dtype to float32 (see here).
Currently, we do not lower any out_dtype to Triton.
Perhaps it should be lowered from the preferred_element_type argument of lax.dot_general.
Because no out_dtype is lowered, performance cannot be tuned through the accumulation dtype (see e.g. PyTorch benchmarks).
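For reference, `lax.dot_general` already exposes the relevant knob: `preferred_element_type` selects the dtype of the result. A minimal CPU sketch (shapes chosen arbitrarily) showing that the requested dtype is what comes back; whether a given backend also accumulates internally in that dtype is exactly what the Pallas/Triton lowering would have to decide.

```python
import jax.numpy as jnp
from jax import lax

# Half-precision operands, as in a typical mixed-precision matmul.
a = jnp.ones((4, 8), dtype=jnp.float16)
b = jnp.ones((8, 16), dtype=jnp.float16)

# Contract dim 1 of `a` with dim 0 of `b`; no batch dimensions.
dims = (((1,), (0,)), ((), ()))

# preferred_element_type sets the result dtype; how the backend
# accumulates internally is left to the lowering.
out_f32 = lax.dot_general(a, b, dims, preferred_element_type=jnp.float32)
out_f16 = lax.dot_general(a, b, dims, preferred_element_type=jnp.float16)

print(out_f32.dtype, out_f16.dtype)  # float32 float16
```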
Note that, for whatever reason (precision or otherwise), flash_attn accumulates in float32 (see here).
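The float32-accumulation pattern mentioned above can be illustrated on a plain matmul. This is a pure-JAX sketch, not the flash_attn code: it tiles the contraction dimension, keeps the running sum in float32 even though the inputs are float16, and casts back once at the end. `blocked_matmul_f32_acc` and `block_k` are hypothetical names used only for illustration.

```python
import jax.numpy as jnp
from jax import lax


def blocked_matmul_f32_acc(a, b, block_k):
    """Hypothetical helper: compute a @ b over K in chunks of block_k,
    keeping the running sum in float32 and casting back at the end."""
    m, k = a.shape
    _, n = b.shape
    assert k % block_k == 0, "this sketch requires K divisible by block_k"
    acc = jnp.zeros((m, n), dtype=jnp.float32)
    for start in range(0, k, block_k):
        a_blk = a[:, start:start + block_k]
        b_blk = b[start:start + block_k, :]
        # Each partial product is requested in float32 and added to the float32 accumulator.
        acc = acc + lax.dot(a_blk, b_blk, preferred_element_type=jnp.float32)
    # Single cast back to the input dtype after all partial sums are added.
    return acc.astype(a.dtype)


a = jnp.full((8, 64), 0.01, dtype=jnp.float16)
b = jnp.ones((64, 8), dtype=jnp.float16)
out = blocked_matmul_f32_acc(a, b, block_k=16)
```

Only the final result is rounded to float16; the intermediate partial sums never are.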
Near duplicate: https://github.com/google/jax/issues/15683
I think the temporary solution here is to add an out_type to pl.dot and pass that into lax.dot(preferred_element_type=...).
Yes, I did that locally. It worked reasonably well, except that I had to force the output dtype conversion with dot(..., out_dtype=my_dtype).astype(my_dtype).
It would be nice to get rid of the extra .astype(my_dtype).
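The stopgap discussed above can be sketched as a small wrapper used inside a Pallas kernel. Assumptions: `dot_with_out_dtype` is a hypothetical helper (not the actual pl.dot signature), it calls `lax.dot(..., preferred_element_type=...)` directly, and the kernel runs with `interpret=True` so the sketch executes on CPU. Folding the cast into the helper keeps call sites free of the extra `.astype(my_dtype)`; whether the cast is still needed once the Triton lowering honors `preferred_element_type` is the open question in this thread.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl


def dot_with_out_dtype(a, b, out_dtype=jnp.float32):
    """Hypothetical helper: forward out_dtype as preferred_element_type and
    apply the cast here, so call sites need no trailing .astype()."""
    return lax.dot(a, b, preferred_element_type=out_dtype).astype(out_dtype)


def matmul_kernel(x_ref, y_ref, o_ref, *, out_dtype):
    # Load the full input blocks and store the product in the requested dtype.
    o_ref[...] = dot_with_out_dtype(x_ref[...], y_ref[...], out_dtype=out_dtype)


def matmul(x, y, out_dtype=jnp.float32):
    kernel = functools.partial(matmul_kernel, out_dtype=out_dtype)
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype),
        interpret=True,  # interpret mode so this sketch runs on CPU
    )(x, y)


x = jnp.ones((16, 32), dtype=jnp.float16)
y = jnp.ones((32, 16), dtype=jnp.float16)
print(matmul(x, y, out_dtype=jnp.float16).dtype)  # float16
print(matmul(x, y, out_dtype=jnp.float32).dtype)  # float32
```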
I think the result of dot is coerced to jnp.float32 somewhere (let me look for it). Perhaps we don't need to be that strict when preferred_element_type is set.