mosmos6

2 comments by mosmos6

This might work: run `pip install dm-haiku==0.0.5` and put optax back to the default version.

> [#202 (comment)](https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576) Please follow this solution, it works. In addition to it, I ran `pip install chex==0.1.2`, `pip install jaxlib==0.1.74`, and `pip install dm-haiku==0.0.5`, and it worked for me.
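
If it helps, here is a rough sketch for checking whether an environment actually ended up with the pins mentioned above. The three version numbers come straight from the comment. The `PINS` dict and the `find_mismatches` helper are hypothetical names made up for this example, not part of mesh-transformer-jax. optax is left out because no specific version is given for it.

```python
from importlib import metadata

# Pins copied from the comment above; optax is omitted because
# no concrete version number is given for it.
PINS = {"chex": "0.1.2", "jaxlib": "0.1.74", "dm-haiku": "0.0.5"}


def find_mismatches(pins, get_version=metadata.version):
    """Return {package: (wanted, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed at all.
    `get_version` is injectable so the check can be exercised without
    touching the real environment.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    # Print one line per package whose installed version differs from the pin.
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: want {wanted}, found {found}")
```

Running it after the `pip install` commands should print nothing if all three pins took effect. Any line it prints names a package that is missing or at a different version.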