
sampling in continuous/complex action spaces with 'density prior' is not working

Open ManorZ opened this issue 1 year ago • 0 comments

Search before asking

  • [X] I have searched the MuZero issues and found no similar bug report.

🐛 Describe the bug

In "Learning and Planning in Complex Action Spaces" (Hubert et al.), there are basically two changes compared to MuZero:

  1. Replace the policy probabilities inside PUCB with the 'sampled policy' (pi_hat = beta_hat/beta * pi)
  2. Sample K actions instead of evaluating all possible actions (there are infinitely many in the continuous case)
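To make the two changes concrete, here is a minimal sketch of how the sampled-policy prior could be computed for K sampled actions. It is not taken from muzero-general: the helper names `gaussian_pdf` and `sampled_policy_prior`, the 1-D Gaussian policy, and the renormalization over the sampled actions are all assumptions made for illustration.

```python
import math
from collections import Counter


def gaussian_pdf(x, mu, sigma):
    """Density of a 1-D Gaussian at x (illustrative policy, an assumption)."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def sampled_policy_prior(samples, pi, beta):
    """Return {action: pi_hat(action)} for the distinct sampled actions.

    pi_hat = (beta_hat / beta) * pi, where beta_hat is the empirical
    distribution of the K samples. Renormalizing over the sampled
    actions is an assumption of this sketch.
    """
    k = len(samples)
    counts = Counter(samples)
    weights = {a: (n / k) / beta(a) * pi(a) for a, n in counts.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}


# K = 3 actions; in a real search they would be drawn from beta.
policy = lambda a: gaussian_pdf(a, mu=0.0, sigma=1.0)
actions = [-0.7, 0.1, 1.3]

# Sampling from the policy itself (beta = pi): the prior is uniform over the samples.
uniform_case = sampled_policy_prior(actions, pi=policy, beta=policy)

# Sampling from a wider proposal: actions the policy favors get more weight.
proposal = lambda a: gaussian_pdf(a, mu=0.0, sigma=3.0)
reweighted_case = sampled_policy_prior(actions, pi=policy, beta=proposal)
```

When beta equals pi the ratio pi/beta cancels, so only the empirical frequencies beta_hat remain. That is why equal weights on the sampled actions are consistent with sampling from the policy.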

In the code, I think I see differences on both points:

  1. No K samples are drawn at the root, only one.
  2. Regarding pi_hat = beta_hat/beta * pi, I see two options there: 'uniform prior' and 'density prior'.
       • Uniform prior gives equal density to all actions and weights the policy accordingly; the current code makes sense for this case.
       • Density prior needs to account for each action's CDF (at the parent), but it doesn't work (error message and description below).

Add an example

The error message I get:

    File "/home/user_231/muzero-general/self_play.py", line 401, in ucb_score
      child.prior / sum([child.prior for child in parent.children.values()])
    TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
    Last test reward: 0.00. Training step: 0/3. Played games: 0. Loss: 0.00

This happens because nothing assigns node.prior in the continuous branch. I think the parent has to set it in its expand method, equal to each child's CDF at the sampled point.
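The failure and one possible fix can be reproduced in isolation. The sketch below is hypothetical: `Node`, `normalized_prior`, and `expand_with_density_prior` are simplified stand-ins, not the classes in self_play.py. It stores the policy density at the sampled action as the prior, which is only one reading of 'density prior'; a different quantity could be stored in the same place.

```python
import math


class Node:
    """Hypothetical stand-in for the tree node; prior defaults to None."""

    def __init__(self, prior=None):
        self.prior = prior
        self.children = {}


def normalized_prior(parent, child):
    # Same expression as the one shown in the traceback from ucb_score.
    return child.prior / sum([c.prior for c in parent.children.values()])


def gaussian_pdf(x, mu, sigma):
    """Density of a 1-D Gaussian at x (illustrative policy, an assumption)."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def expand_with_density_prior(parent, actions, mu, sigma):
    """Create one child per sampled action and store a density value as its prior."""
    for a in actions:
        parent.children[a] = Node(prior=gaussian_pdf(a, mu, sigma))


# Children created without a prior: the normalization raises TypeError.
broken = Node()
broken.children = {0.2: Node(), -0.5: Node()}
try:
    normalized_prior(broken, broken.children[0.2])
    failed = False
except TypeError:
    failed = True

# Children created with a prior during expansion: the normalized priors sum to 1.
fixed = Node()
expand_with_density_prior(fixed, [0.2, -0.5, 1.1], mu=0.0, sigma=1.0)
total = sum(normalized_prior(fixed, c) for c in fixed.children.values())
```

The key point is that every child must receive a numeric prior when it is created, because the normalization sums over all siblings.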

Also, regarding the K, I think we need to make a small change in the expand method, and sample more than one action:

    action_value = distribution.sample(K).squeeze(0).detach().cpu().numpy()
    self.children[Action(action_value)] = Node()
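A library-free sketch of the same idea, drawing K actions and creating one child per action, could look like the following. Everything here is assumed for illustration: the `Node` stub, the `expand_continuous` helper, the 1-D Gaussian policy, and the tuple used as a hashable action key in place of `Action(...)`. If `distribution` is a PyTorch distribution, note that its `sample` method takes a shape such as `(K,)` and not a bare integer, so the snippet above would also need a loop over the K sampled rows.

```python
import random


class Node:
    """Hypothetical stand-in for the tree node."""

    def __init__(self):
        self.children = {}


def expand_continuous(node, mu, sigma, k, rng):
    """Draw k actions from a 1-D Gaussian policy and add one child per distinct action."""
    for _ in range(k):
        action = (rng.gauss(mu, sigma),)  # tuple as a hashable action key (assumption)
        node.children.setdefault(action, Node())


root = Node()
expand_continuous(root, mu=0.0, sigma=1.0, k=8, rng=random.Random(0))
```

Using `setdefault` keeps an existing child if the same action is drawn twice, so repeated samples do not reset a subtree.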

Environment

No response

Minimal Reproducible Example

python muzero.py mujoco_IP {"node_prior":"density"}

Additional

No response

ManorZ • Jul 09 '22 20:07