brax/training/agents/ppo/train.py fails to JSON serialize the config during checkpoint saving.
I encountered this bug when running the MuJoCo Playground tutorial with the following command:
python learning/train_jax_ppo.py --env_name CartpoleBalance
The above command effectively runs brax/training/agents/ppo/train.py.
I resolved the bug by referring to brax/training/agents/bc/train.py and modifying lines train.py#L601-L611 as follows:
specs_obs_shape = jax.tree_util.tree_map(
    lambda x: specs.Array(x.shape[-1:], jnp.dtype('float32')), env_state.obs
)
training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
    optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
    params=init_params,
    normalizer_params=running_statistics.init_state(
        _remove_pixels(specs_obs_shape)
    ),
    env_steps=types.UInt64(hi=0, lo=0),
)
Without this change, the training command terminates with the following error:
Traceback (most recent call last):
  File ".pyenv/versions/3.10.16/lib/python3.10/pdb.py", line 1723, in main
    pdb._runscript(mainpyfile)
  File ".pyenv/versions/3.10.16/lib/python3.10/pdb.py", line 1583, in _runscript
    self.run(statement)
  File ".pyenv/versions/3.10.16/lib/python3.10/bdb.py", line 598, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "work/mujoco_playground/learning/train_jax_ppo.py", line 506, in <module>
    app.run(main)
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
  File "work/mujoco_playground/learning/train_jax_ppo.py", line 432, in main
    make_inference_fn, params, _ = train_fn(  # pylint: disable=no-value-for-parameter
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/brax/training/agents/ppo/train.py", line 731, in train
    checkpoint.save(
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/brax/training/agents/ppo/checkpoint.py", line 35, in save
    return checkpoint.save(path, step, params, config, _CONFIG_FNAME)
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/brax/training/checkpoint.py", line 136, in save
    config_path.write_text(config.to_json())
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/ml_collections/config_dict/config_dict.py", line 1136, in to_json
    return self._json_dumps_wrapper(cls=json_encoder_cls, **kwargs)
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/ml_collections/config_dict/config_dict.py", line 1109, in _json_dumps_wrapper
    return json.dumps(self, **kwargs)
  File ".pyenv/versions/3.10.16/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File ".pyenv/versions/3.10.16/lib/python3.10/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File ".pyenv/versions/3.10.16/lib/python3.10/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "work/mujoco_playground/.venv/lib/python3.10/site-packages/ml_collections/config_dict/config_dict.py", line 1970, in default
    raise TypeError('{} is not JSON serializable. Instead use '
TypeError: <class 'brax.training.acme.specs.Array'> is not JSON serializable. Instead, use ConfigDict.to_json_best_effort()
This error occurs because obs_shape contains brax.training.acme.specs.Array objects, which are not JSON serializable, and obs_shape is included in the config that is written out when the checkpoint is saved.
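The failure mode can be reproduced without brax or ml_collections. The sketch below is only an illustration, not brax's code: FakeArraySpec is a hypothetical stand-in for specs.Array, and to_json_safe is a hypothetical helper showing that the config serializes once each spec object is replaced by a plain shape list.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArraySpec:
    """Hypothetical stand-in for brax.training.acme.specs.Array (shape + dtype)."""
    shape: tuple
    dtype: str


def to_json_safe(tree):
    """Hypothetical helper: recursively replace spec objects with plain shape lists."""
    if isinstance(tree, FakeArraySpec):
        return list(tree.shape)
    if isinstance(tree, dict):
        return {k: to_json_safe(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return [to_json_safe(v) for v in tree]
    return tree


# A config whose obs_shape leaf is a spec object, similar to the failing case.
config = {"num_timesteps": 1000, "obs_shape": {"state": FakeArraySpec((5,), "float32")}}

try:
    json.dumps(config)  # raises TypeError: FakeArraySpec is not JSON serializable
except TypeError as e:
    print("fails:", e)

print(json.dumps(to_json_safe(config)))  # {"num_timesteps": 1000, "obs_shape": {"state": [5]}}
```

This mirrors the idea behind the fix above: build spec objects only where they are needed (normalizer initialization) and keep the value stored in the config as plain, serializable data.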
Interestingly, this was once addressed in an earlier pull request, but that change was later reverted in a subsequent commit. I assume the original PR was reverted because it did not address the root cause.
Version info: python 3.10.16, brax 0.12.4, ml_collections 1.1.0
Sorry, I just realized that this is a complete duplicate of the following bug report: https://github.com/google/brax/issues/607
Issue #607 can most likely be resolved by the fix proposed above as well.