__init__() got multiple values for argument 'device'
I ran ./onpolicy/scripts/train_mpe_scripts/train_mpe_spread.sh after changing 'algo' to mappo and 'user_name' to my wandb user name in train_mpe_spread.sh. My train_mpe_spread.sh is as follows:
algo="mappo" #"rmappo" "ippo"
echo "env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}"
for seed in `seq ${seed_max}`;
do
    echo "seed is ${seed}:"
    CUDA_VISIBLE_DEVICES=0 python ../train/train_mpe.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} \
    --scenario_name ${scenario} --num_agents ${num_agents} --num_landmarks ${num_landmarks} --seed ${seed} \
    --n_training_threads 1 --n_rollout_threads 128 --num_mini_batch 1 --episode_length 25 --num_env_steps 20000000 \
    --ppo_epoch 10 --use_ReLU --gain 0.01 --lr 7e-4 --critic_lr 7e-4 --wandb_name "xxx" --user_name "my_wandb_name"
done
Then I got the following error:
Traceback (most recent call last):
  File "/home/drl/Projects/on-policy/onpolicy/scripts/train_mpe_scripts/../train/train_mpe.py", line 174, in <module>
  File "/home/drl/Projects/on-policy/onpolicy/scripts/train_mpe_scripts/../train/train_mpe.py", line 158, in main
    runner = Runner(config)
  File "/home/drl/Projects/on-policy/onpolicy/runner/shared/mpe_runner.py", line 14, in __init__
    super(MPERunner, self).__init__(config)
  File "/home/drl/Projects/on-policy/onpolicy/runner/shared/base_runner.py", line 80, in __init__
    self.policy = Policy(self.all_args,
TypeError: __init__() got multiple values for argument 'device'
I found that lines 66~71 of on-policy/onpolicy/runner/shared/base_runner.py select R_MAPPOPolicy as Policy:
if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    from onpolicy.algorithms.mat.mat_trainer import MATTrainer as TrainAlgo
    from onpolicy.algorithms.mat.algorithm.transformer_policy import TransformerPolicy as Policy
else:
    from onpolicy.algorithms.r_mappo.r_mappo import R_MAPPO as TrainAlgo
    from onpolicy.algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy as Policy
The constructor of R_MAPPOPolicy is:
def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
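However, lines 80~85 of on-policy/onpolicy/runner/shared/base_runner.py pass incompatible arguments to the R_MAPPOPolicy constructor. The second-to-last argument (self.num_agents) is redundant: it is bound positionally to 'device', and then device = self.device supplies 'device' a second time.
self.policy = Policy(self.all_args,
                     self.envs.observation_space[0],
                     share_observation_space,
                     self.envs.action_space[0],
                     self.num_agents, # default 2
                     device = self.device)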
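The error can be reproduced without the repository. The sketch below uses a hypothetical StubPolicy class that only mirrors the parameter order of R_MAPPOPolicy.__init__ (with a plain string instead of torch.device for the default, so torch is not needed), and calls it the same way base_runner.py does: five positional arguments plus device as a keyword.

```python
class StubPolicy:
    """Hypothetical stand-in with the same parameter order as R_MAPPOPolicy.__init__."""

    def __init__(self, args, obs_space, cent_obs_space, act_space, device="cpu"):
        self.device = device


def try_construct():
    """Mimic the call in base_runner.py and return the TypeError message, if any."""
    try:
        # The 5th positional value (num_agents=3) is bound to `device`,
        # then device="cuda:0" tries to bind `device` a second time.
        StubPolicy("all_args", "obs_space", "share_obs_space", "act_space", 3, device="cuda:0")
    except TypeError as err:
        return str(err)
    return None


message = try_construct()
print(message)
```

The printed message contains "got multiple values for argument 'device'"; the exact prefix (for example, whether the class name is included) depends on the Python version.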
What should I do? Thanks!
Same issue. It appears to be caused by a mismatch in the passed arguments. One potential, albeit unsafe, solution is simply to remove self.num_agents from the call.
OK, thanks!
I found another piece of code related to this problem. Lines 90~93 of onpolicy/runner/shared/base_runner.py are as follows:
if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    self.trainer = TrainAlgo(self.all_args, self.policy, self.num_agents, device = self.device)
else:
    self.trainer = TrainAlgo(self.all_args, self.policy, device = self.device)
The code that defines TrainAlgo (together with Policy), in lines 66~71 of on-policy/onpolicy/runner/shared/base_runner.py, is:
if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    from onpolicy.algorithms.mat.mat_trainer import MATTrainer as TrainAlgo
    from onpolicy.algorithms.mat.algorithm.transformer_policy import TransformerPolicy as Policy
else:
    from onpolicy.algorithms.r_mappo.r_mappo import R_MAPPO as TrainAlgo
    from onpolicy.algorithms.r_mappo.algorithm.rMAPPOPolicy import R_MAPPOPolicy as Policy
TrainAlgo is selected in the same way as Policy, and here the author simply uses an if statement to pass self.num_agents only to the MAT trainer. I think the Policy problem can be solved by adding the same if statement around the call to Policy's constructor. One possible solution is to change lines 80~85 of on-policy/onpolicy/runner/shared/base_runner.py to:
if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    self.policy = Policy(self.all_args, self.envs.observation_space[0], share_observation_space, self.envs.action_space[0], self.num_agents, device = self.device)
else:
    self.policy = Policy(self.all_args, self.envs.observation_space[0], share_observation_space, self.envs.action_space[0], device = self.device)
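The proposed branch can be checked in isolation with two stub classes. Both are hypothetical: StubRMAPPOPolicy mirrors the R_MAPPOPolicy signature quoted above, while StubTransformerPolicy assumes (by analogy with the MATTrainer call) that the MAT policy takes num_agents positionally before device. The real TransformerPolicy signature should be checked in transformer_policy.py. The helper build_policy is also illustrative, not part of the repository.

```python
class StubRMAPPOPolicy:
    """Hypothetical stand-in mirroring R_MAPPOPolicy's constructor (no num_agents)."""

    def __init__(self, args, obs_space, cent_obs_space, act_space, device="cpu"):
        self.device = device


class StubTransformerPolicy:
    """Hypothetical stand-in assuming the MAT policy accepts num_agents before device."""

    def __init__(self, args, obs_space, cent_obs_space, act_space, num_agents, device="cpu"):
        self.num_agents = num_agents
        self.device = device


def build_policy(algorithm_name, policy_cls, all_args, obs_space, share_obs_space,
                 act_space, num_agents, device):
    """Pass num_agents only for the MAT variants, like the if statement in base_runner.py."""
    if algorithm_name in ("mat", "mat_dec"):
        return policy_cls(all_args, obs_space, share_obs_space, act_space,
                          num_agents, device=device)
    return policy_cls(all_args, obs_space, share_obs_space, act_space, device=device)


mappo = build_policy("mappo", StubRMAPPOPolicy, None, "obs", "share_obs", "act", 3, "cuda:0")
mat = build_policy("mat", StubTransformerPolicy, None, "obs", "share_obs", "act", 3, "cuda:0")
print(mappo.device, mat.num_agents)  # prints "cuda:0 3"
```

With the branch in place, neither constructor receives 'device' twice, and the MAT path keeps receiving num_agents.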