rlpd
Flax FrozenDict: dict.copy() takes no keyword arguments
Reproduce error
- flax 0.7.5
- jaxlib 0.4.21+cuda12.cudnn89
- Ubuntu 22.04
Running the command:
XLA_PYTHON_CLIENT_PREALLOCATE=false python train_finetuning_pixels.py --env_name=cheetah-run-v0 \
--start_training 5000 \
--max_steps 300000 \
--config=configs/rlpd_pixels_config.py \
--project_name=rlpd_vd4rl
fails with: TypeError: dict.copy() takes no keyword arguments
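The message itself can be reproduced without Flax. A FrozenDict's copy method accepts an add_or_replace keyword, but a plain Python dict's copy takes no arguments at all, so any code written for FrozenDict params breaks once it receives a plain dict. A minimal sketch, where the params dict is a hypothetical stand-in for the output of init(...)["params"]:

```python
# Hypothetical stand-in for actor_def.init(actor_key, observations)["params"]
# when init() returns a plain dict instead of a FrozenDict.
params = {"encoder": {"kernel": 1.0}, "head": {"kernel": 2.0}}

error_message = None
try:
    # FrozenDict.copy supports add_or_replace=...; dict.copy does not.
    params.copy(add_or_replace={"encoder": {"kernel": 3.0}})
except TypeError as err:
    error_message = str(err)
    print(f"TypeError: {err}")
```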
Possible fix
In rlpd/rlpd/agents/drq/drq_learner.py, wrap the parameters returned by init in a FrozenDict:
import flax.core.frozen_dict as frozen_dict

# init() may return a plain dict on newer Flax; wrapping it restores the
# FrozenDict API, e.g. .copy(add_or_replace=...).
actor_params = frozen_dict.FrozenDict(actor_def.init(actor_key, observations)["params"])  # line 121
critic_params = frozen_dict.FrozenDict(critic_def.init(critic_key, observations, actions)["params"])  # line 145
I believe this is related to Flax's migration from FrozenDict to regular Python dicts as the return type of Module.init, according to the issue here. Note the migration note here, which applies from Flax 0.7.1 onwards. I'm not sure which exact lines are erroring for you, but another possible workaround is to use the flax.core.frozen_dict utility functions (such as copy), described here, which work on both FrozenDicts and regular dicts.
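The idea behind that workaround can be sketched as a small container-agnostic helper. This is a hypothetical copy_with function written for illustration, not Flax's own code; it only mimics the behavior of FrozenDict.copy(add_or_replace=...) so the same call site works whether params is a plain dict or a FrozenDict. In real code, the flax.core.frozen_dict utilities mentioned above are the place to look for this.

```python
from collections.abc import Mapping
from typing import Any


def copy_with(params: Mapping[str, Any], add_or_replace: Mapping[str, Any]) -> Mapping[str, Any]:
    """Hypothetical helper: shallow-copy `params` and merge in `add_or_replace`.

    Works on a plain dict or on any Mapping whose constructor accepts a dict
    (such as a FrozenDict), and returns the same container type as `params`.
    The original `params` is never mutated.
    """
    merged = {**params, **add_or_replace}
    return type(params)(merged)


# Hypothetical parameter trees, standing in for init(...)["params"].
actor_params = {"encoder": {"kernel": 1.0}, "head": {"kernel": 2.0}}
critic_params = {"encoder": {"kernel": 5.0}}

# Replace the actor's "encoder" subtree without calling dict.copy(add_or_replace=...).
new_actor_params = copy_with(actor_params, {"encoder": critic_params["encoder"]})
print(new_actor_params)
```

Returning type(params)(merged) rather than a bare dict keeps FrozenDict inputs frozen, so downstream code that still expects the FrozenDict API is not surprised.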
Hope this is helpful!