
iree-jax fails when exporting bfloat16 parameters

Status: Open. wangkuiyi opened this issue 1 year ago (0 comments).

When I run the following program:

import jax
import jax.numpy as jnp
from iree.jax import Program, store_global
import flax

model = flax.linen.Dense(
    features=1, use_bias=False, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16
)
rng = jax.random.PRNGKey(0)
model_state = model.init(rng, jnp.ones((1, 1)))


# The generated MLIR module name will be the prefix before Program.
class TryTrainStateProgram(Program):
    _params = Program.export_global(
        model_state["params"], initialize=True, mutable=True
    )

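    def get_params(self):
        # Return the exported global declared above (the original referenced
        # a nonexistent self._train_state attribute).
        return self._params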


with open("/tmp/a.mlir", "w") as f:
    f.write(str(Program.get_mlir_module(TryTrainStateProgram)))

I got this error:

ValueError: cannot include dtype 'E' in a buffer

After changing jnp.bfloat16 to another type such as jnp.float32, the error disappears.
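The message looks like it comes from NumPy's implementation of Python's buffer protocol, which has no format code for the bfloat16 extension dtype (type character 'E', registered by the ml_dtypes package that JAX uses). That would explain why float32 works and bfloat16 does not. The sketch below reproduces the same ValueError outside iree-jax. It assumes numpy and ml_dtypes are installed and illustrates the suspected cause rather than a confirmed diagnosis of iree-jax internals. The uint16 reinterpretation at the end is a hypothetical workaround, not an iree-jax feature.

```python
import numpy as np
import ml_dtypes  # registers the bfloat16 NumPy dtype behind jnp.bfloat16


def buffer_ok(arr: np.ndarray) -> bool:
    """Return True if the array can be exposed through the Python buffer protocol."""
    try:
        memoryview(arr)
        return True
    except ValueError as err:
        # For bfloat16 arrays NumPy reports: cannot include dtype 'E' in a buffer
        print(f"{arr.dtype}: {err}")
        return False


f32 = np.ones((1, 1), dtype=np.float32)
bf16 = np.ones((1, 1), dtype=ml_dtypes.bfloat16)

print(buffer_ok(f32))   # float32 can be exported as a buffer
print(buffer_ok(bf16))  # bfloat16 raises the same ValueError as in the issue

# Hypothetical workaround: reinterpret the raw bfloat16 bits as uint16,
# a dtype the buffer protocol does support (same 2-byte item size).
bits = bf16.view(np.uint16)
print(buffer_ok(bits))
print(hex(int(bits[0, 0])))  # bfloat16 1.0 is stored as 0x3f80
```

Reinterpreting the bits keeps the data intact. Whatever consumes the buffer would then need to treat those uint16 values as bfloat16 again.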

Posted by wangkuiyi on May 25 '23 at 04:05