iree-jax
iree-jax copied to clipboard
iree-jax fails when exporting bfloat16 parameters
When I run the following program:
import jax
import jax.numpy as jnp
from iree.jax import Program, store_global
import flax

# A single-feature Dense layer whose computation dtype and parameter dtype
# are both bfloat16.
model = flax.linen.Dense(
    features=1, use_bias=False, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16
)
rng = jax.random.PRNGKey(0)
model_state = model.init(rng, jnp.ones((1, 1)))


# The generated MLIR module name will be the prefix before Program.
class TryTrainStateProgram(Program):
    # Export the initialized model parameters as a mutable program global.
    _params = Program.export_global(
        model_state["params"], initialize=True, mutable=True
    )

    def get_params(self):
        # Return the global declared above. The original referenced
        # `self._train_state.params`, but no `_train_state` attribute is
        # defined on this class.
        return self._params


# Write the textual MLIR module for the program to a file.
with open("/tmp/a.mlir", "w") as f:
    f.write(str(Program.get_mlir_module(TryTrainStateProgram)))
I got the following error:
ValueError: cannot include dtype 'E' in a buffer
After changing jnp.bfloat16 to another dtype such as jnp.float32, the error disappeared.