maxtext
Cannot do inference in float32
If we try to perform inference in float32, we get the error:
AssertionError: Key and Value Dtypes should match
This error comes from this line.
The origin of the error is that the cache dtype is set to jnp.int8 if quantize_kvcache else jnp.bfloat16, but never to jnp.float32.
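A minimal sketch of the mismatch and one possible fix. It assumes the assertion fires because one operand comes from the bfloat16 KV cache while the other stays in the float32 activation dtype. The helper names (cache_dtype_as_reported, cache_dtype_fixed, check_kv_dtypes, run) are hypothetical and are not MaxText functions. Dtypes are plain strings so the example runs without JAX.

```python
# Hypothetical sketch: dtypes are plain strings ("int8", "bfloat16", "float32")
# so this runs without JAX. None of these helpers exist in MaxText.

def cache_dtype_as_reported(quantize_kvcache: bool, dtype: str) -> str:
    """Selection as described in the issue: the configured dtype is ignored."""
    return "int8" if quantize_kvcache else "bfloat16"


def cache_dtype_fixed(quantize_kvcache: bool, dtype: str) -> str:
    """Possible fix (assumption): fall back to the configured activation dtype."""
    return "int8" if quantize_kvcache else dtype


def check_kv_dtypes(key_dtype: str, value_dtype: str) -> None:
    """Stand-in for the assertion that fires during float32 inference."""
    assert key_dtype == value_dtype, "Key and Value Dtypes should match"


def run(select, dtype: str, quantize_kvcache: bool = False) -> bool:
    """Return True if a tensor in `dtype` matches the dtype of the cached one."""
    cached_dtype = select(quantize_kvcache, dtype)
    try:
        check_kv_dtypes(dtype, cached_dtype)
        return True
    except AssertionError:
        return False


print(run(cache_dtype_as_reported, "float32"))   # False: float32 vs bfloat16
print(run(cache_dtype_fixed, "float32"))         # True
print(run(cache_dtype_as_reported, "bfloat16"))  # True: default config is unaffected
```

Deriving the unquantized cache dtype from the configured dtype, rather than hard-coding bfloat16, keeps the default bfloat16 path unchanged while letting float32 runs pass the check.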
What are you setting that triggers this? (Activations set to float32?)
Yes, it's the dtype setting: https://github.com/google/maxtext/blob/f52e6f7b49277688f68a58284d1ff7873122ca41/MaxText/configs/base.yml#L61