muzero-general
Policy target after MCTS should be in the form of probabilities
This issue appears only in the continuous-action implementation of MuZero.
When computing the child visits, we need to divide by sum_visits so that the policy target is a probability distribution.
However, it seems you forgot to divide by sum_visits. Here is the current implementation:
sum_visits = sum(child.visit_count for child in root.children.values())
self.child_visits.append(
numpy.array([child.visit_count for child in root.children.values()])
)
I think the correct version is the following:
sum_visits = sum(child.visit_count for child in root.children.values())
self.child_visits.append(
numpy.array([child.visit_count / sum_visits for child in root.children.values()])
)
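For clarity, here is a minimal standalone sketch of the same normalization. The normalized_visit_counts helper and the uniform fallback for the zero-visit case are my own additions for illustration, not code from the repository:

import numpy

def normalized_visit_counts(root):
    # Hypothetical helper: turn the root's MCTS visit counts into a
    # probability distribution over its children.
    visit_counts = numpy.array(
        [child.visit_count for child in root.children.values()],
        dtype=numpy.float64,
    )
    sum_visits = visit_counts.sum()
    if sum_visits == 0:
        # No simulations reached any child; fall back to a uniform target.
        return numpy.full(len(visit_counts), 1.0 / len(visit_counts))
    return visit_counts / sum_visits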
Yes, I also noticed this problem; it causes the KL loss to become negative.
The target_policy_action is calculated from the action values, and the current KL-loss calculation is not a proper KL divergence between two distributions.
# log-probability of target_policy_action under dist, summed over action dims
log_prob = dist.log_prob(target_policy_action[:, i, :]).sum(1)
# pointwise term exp(log_prob) * (log_prob - log_target); not a full KL, so it can be negative
policy_loss += torch.exp(log_prob) * (log_prob - log_target)
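To illustrate the point: a true KL divergence between two distributions is always non-negative. A small self-contained sketch follows (not the repository's code; the diagonal Gaussians here are only placeholders for the predicted and target policies):

import torch
from torch.distributions import Normal, kl_divergence

# Placeholder distributions: dist stands for the predicted policy and
# target_dist for the MCTS-derived target policy (both diagonal Gaussians).
dist = Normal(loc=torch.zeros(3), scale=torch.ones(3))
target_dist = Normal(loc=torch.tensor([0.5, -0.2, 0.1]), scale=0.8 * torch.ones(3))

# Closed-form KL per action dimension, summed over dimensions.
policy_loss = kl_divergence(target_dist, dist).sum()
assert policy_loss >= 0  # a genuine KL divergence can never be negative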