
Categorical DQN - dimension error

Open Colin1998 opened this issue 3 years ago • 0 comments

Hi,

I don't post issues very often, so I hope my problem is clear from the way I present it below. When trying to train a Categorical DQN (for batch RL, with no interaction with the environment), I run into the following error:

```
Traceback (most recent call last):
  File "", line 129, in <module>
    graph_manager.improve()
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\batch_rl_graph_manager.py", line 234, in improve
    self.train()
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\graph_manager.py", line 408, in train
    [manager.train() for manager in self.level_managers]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\graph_manager.py", line 408, in <listcomp>
    [manager.train() for manager in self.level_managers]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\level_manager.py", line 187, in train
    [agent.train() for agent in self.agents.values()]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\level_manager.py", line 187, in <listcomp>
    [agent.train() for agent in self.agents.values()]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\agents\agent.py", line 741, in train
    total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\agents\categorical_dqn_agent.py", line 113, in learn_from_batch
    self.q_values.add_sample(self.distribution_prediction_to_q_values(TD_targets))
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\agents\categorical_dqn_agent.py", line 82, in distribution_prediction_to_q_values
    return np.dot(prediction, self.z_values)
  File "<__array_function__ internals>", line 6, in dot
ValueError: shapes (128,2) and (51,) not aligned: 2 (dim 1) != 51 (dim 0)
```
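To sanity-check the last frame, here is a minimal NumPy sketch (random data, made-up `v_min`/`v_max`, not the actual coach code) of how I understand `distribution_prediction_to_q_values`: with a per-action distribution of shape `(batch, num_actions, num_atoms)`, the dot with `z_values` reduces the atoms axis, whereas a flat `(batch, num_actions)` prediction fails exactly as above:

```python
import numpy as np

batch_size, num_actions, num_atoms = 128, 2, 51
z_values = np.linspace(-10.0, 10.0, num_atoms)  # hypothetical support atoms

# What the agent seems to receive: a flat Q-head output of shape (batch, actions)
flat_prediction = np.random.rand(batch_size, num_actions)
try:
    np.dot(flat_prediction, z_values)
except ValueError as e:
    print(e)  # shapes (128,2) and (51,) not aligned: 2 (dim 1) != 51 (dim 0)

# What the dot product expects: one distribution over atoms per action
prediction = np.random.rand(batch_size, num_actions, num_atoms)
prediction /= prediction.sum(axis=2, keepdims=True)  # normalise over atoms
q_values = np.dot(prediction, z_values)              # contracts the atoms axis
print(q_values.shape)  # (128, 2)
```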

The 2 (dim 1) is the number of actions in my ActionSpace, and the 51 (dim 0) corresponds to the number of atoms set in the agent's parameters. The error suggests that these should be of equal length, which seems strange to me. Should they really be the same? When I set the number of atoms to 2 (to get rid of this error), I got the following error instead:

```
Traceback (most recent call last):
  File "", line 129, in <module>
    graph_manager.improve()
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\batch_rl_graph_manager.py", line 234, in improve
    self.train()
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\graph_manager.py", line 408, in train
    [manager.train() for manager in self.level_managers]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\graph_managers\graph_manager.py", line 408, in <listcomp>
    [manager.train() for manager in self.level_managers]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\level_manager.py", line 187, in train
    [agent.train() for agent in self.agents.values()]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\level_manager.py", line 187, in <listcomp>
    [agent.train() for agent in self.agents.values()]
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\agents\agent.py", line 741, in train
    total_loss, losses, unclipped_grads = self.learn_from_batch(batch)
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\rl_coach\agents\categorical_dqn_agent.py", line 116, in learn_from_batch
    target_actions = np.argmax(self.distribution_prediction_to_q_values(distributional_q_st_plus_1), axis=1)
  File "<__array_function__ internals>", line 6, in argmax
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\numpy\core\fromnumeric.py", line 1188, in argmax
    return _wrapfunc(a, 'argmax', axis=axis, out=out)
  File "C:\Users\colin.conda\envs\py36\lib\site-packages\numpy\core\fromnumeric.py", line 58, in _wrapfunc
    return bound(*args, **kwds)
AxisError: axis 1 is out of bounds for array of dimension 1
```

I tried setting the axis to zero, but that results in more complex errors, so I assume it is not the way to go. Does anyone have a clue how I can fix this? Any suggestions would be of great help. Thanks in advance!
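In case it helps pinpoint the second error: my understanding is that with the number of atoms forced to 2, the dot product no longer fails but collapses the result to one dimension, so `argmax(..., axis=1)` has no axis 1 left to reduce. A minimal NumPy sketch (random data, made-up support values):

```python
import numpy as np

batch_size = 128
num_actions = num_atoms = 2                      # atoms forced to match the action count
z_values = np.linspace(-10.0, 10.0, num_atoms)   # hypothetical support atoms

# A flat (batch, actions) prediction dotted with (atoms,) yields a 1-D result
prediction = np.random.rand(batch_size, num_actions)
q_values = np.dot(prediction, z_values)
print(q_values.shape)  # (128,)

# ...so selecting greedy actions along axis 1 fails, as in the traceback
try:
    np.argmax(q_values, axis=1)
except (ValueError, IndexError) as e:  # AxisError subclasses both
    print(e)  # axis 1 is out of bounds for array of dimension 1
```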

Colin1998 · May 12 '21 13:05