
Use bfloat16 for eval

Open • tbaker2 opened this issue 5 months ago • 1 comment

I'm running paxml on an Intel Xeon CPU server using the paxml/main.py program. I'm trying to build a model whose weights are created in bfloat16 and that uses that datatype during eval. I modified the LmCloudSpmd2B configuration with the following lines:

MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]

The training status output includes the following lines:

model.dtype : type/jax.numpy/float32
model.fprop_dtype : dtype[bfloat16]

All of the other operator datatypes are float32. When I run that model with the --eval switch, all of the computation is done in float32. How can I direct paxml to use bfloat16 for both the weights and the eval computation?
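To make the two dtypes in that status output concrete, here is a minimal plain-JAX sketch of the distinction I think is involved. It is not paxml code. It assumes that model.dtype is the dtype the variables are stored in and that model.fprop_dtype is the dtype used for the forward computation; the helper names init_params, cast_floats, and fprop are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim, out_dim, dtype=jnp.float32):
    """Create weights in the given storage dtype (stand-in for model.dtype)."""
    w = jax.random.normal(key, (in_dim, out_dim), dtype=dtype)
    b = jnp.zeros((out_dim,), dtype=dtype)
    return {"w": w, "b": b}


def cast_floats(tree, dtype):
    """Cast every floating-point leaf of a pytree to dtype; leave others alone."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        tree,
    )


def fprop(params, x, fprop_dtype=jnp.bfloat16):
    """Forward pass computed in fprop_dtype, whatever the storage dtype is."""
    p = cast_floats(params, fprop_dtype)
    return jnp.dot(x.astype(fprop_dtype), p["w"]) + p["b"]


# Weights stored in float32, computation done in bfloat16.
params = init_params(jax.random.PRNGKey(0), 4, 3)
x = jnp.ones((2, 4), dtype=jnp.float32)
y = fprop(params, x)
print(params["w"].dtype, y.dtype)

# Weights themselves converted to bfloat16 (what I am trying to get for eval).
bf16_params = cast_floats(params, jnp.bfloat16)
print(bf16_params["w"].dtype)
```

In this sketch the stored weights stay float32 unless they are cast explicitly, even though the forward pass already runs in bfloat16, which looks similar to the float32/bfloat16 split in the status output above.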

Tom

tbaker2 • Jan 26 '24 23:01