jax
lax.dot_general() doesn't support int4 datatype on CPU
I encountered the following error when running lax.dot_general(int4, int4) on CPU:
INVALID_ARGUMENT: during context [hlo verifier]: S4/U4 is currently only supported in convert instructions, but got instruction with S4/U4 input: %dot.3 = s4[1,256]{1,0} dot(s4[1,3136]{1,0} %Arg_0.1, s4[3136,256]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=(<Precision.DEFAULT: 0>, <Precision.DEFAULT: 0>) preferred_element_type=None]"
The same code runs fine on TPU.
Hi - thanks for the report! The narrow-width data types like [u]int4 and float8_* are still only experimentally supported, and it's not expected that they will work on all backends. If you need operations that will work regardless of device, it's best to stick to standard data types.
There is an obvious workaround: upcast to int8. But then we have to implement it at every jax.lax.dot_general call site. Would it make sense to put it inside jax.lax.dot_general itself?
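For reference, a minimal sketch of what the upcast workaround could look like at a call site. The wrapper name `dot_general_upcast` is hypothetical (it is not part of JAX), and the choice to accumulate in int32 via `preferred_element_type` is an assumption made here to avoid overflow in long contractions; it also assumes both operands have the same signedness.

```python
import jax.numpy as jnp
from jax import lax


def _widen(x):
    # Upcast 4-bit integers to their 8-bit counterparts; leave other dtypes alone.
    if x.dtype == jnp.int4:
        return x.astype(jnp.int8)
    if x.dtype == jnp.uint4:
        return x.astype(jnp.uint8)
    return x


def dot_general_upcast(lhs, rhs, dimension_numbers,
                       preferred_element_type=jnp.int32):
    """Hypothetical helper: call lax.dot_general after upcasting int4/uint4 inputs.

    Only the convert ops touch the 4-bit data, which is the one case the
    error message above says is supported.
    """
    return lax.dot_general(
        _widen(lhs),
        _widen(rhs),
        dimension_numbers,
        preferred_element_type=preferred_element_type,
    )


# Same contraction pattern as in the report: (((1,), (0,)), ((), ()))
lhs = jnp.array([[1, -2, 3]], dtype=jnp.int4)
rhs = jnp.array([[1, 2], [3, 4], [5, 6]], dtype=jnp.int4)
out = dot_general_upcast(lhs, rhs, (((1,), (0,)), ((), ())))
```

The result comes back as int32 rather than int4, so callers that need a 4-bit output would still have to cast (and clamp) it themselves.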