
lax.dot_general() doesn't support int4 datatype on CPU

Open · jianlijianli opened this issue 1 year ago · 2 comments

I encountered the following error when running lax.dot_general(int4, int4) on CPU:

INVALID_ARGUMENT: during context [hlo verifier]: S4/U4 is currently only supported in convert instructions, but got instruction with S4/U4 input: %dot.3 = s4[1,256]{1,0} dot(s4[1,3136]{1,0} %Arg_0.1, s4[3136,256]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=(<Precision.DEFAULT: 0>, <Precision.DEFAULT: 0>) preferred_element_type=None]"

The same code runs fine on TPU.

jianlijianli, Feb 07 '24 08:02

Hi - thanks for the report! The narrow-width data types like [u]int4 and float8_* are still only experimentally supported, and it's not expected that they will work on all backends. If you need operations that will work regardless of device, it's best to stick to standard data types.

jakevdp, Feb 07 '24 18:02

There is an obvious workaround: upcast the operands to int8. But then we have to implement it at every jax.lax.dot_general call site. Would it make sense to put it inside jax.lax.dot_general itself?

lukaszlew, Feb 16 '24 20:02