muzero-general
Why is root.visit_count initialized to 0 and root_predicted_value not included in root node value?
The MCTS implementation here works roughly like this (pseudocode):
def mcts(observation):
    root_predicted_value, stuff = model.initial_inference(observation)
    root = Node()
    root.expand(stuff)
    root.add_exploration_noise()
    for _ in range(num_simulations):
        leaf = find_unexpanded_leaf()  # here UCB formula depends on root.visit_count
        leaf_predicted_value, stuff = model.recurrent_inference(leaf.hidden_state)
        leaf.expand(stuff)
        value = leaf_predicted_value
        for node in reversed([root, ..., leaf]):
            node.value_sum += value
            node.visit_count += 1
            value = node.reward + discount * value

# ... later in store_search_statistics() ...
game_history.root_values.append(root.value_sum / root.visit_count)
Note that each call to `expand()` is followed by an update of the root value and root visit count, except for the very first one, on the root itself. There are two consequences to this:
- When searching for an unexpanded leaf for the first time, prior probabilities are discarded: the UCB formula has the root visit count in the numerator of the prior term, so with `root.visit_count == 0` that term is zero for every action: https://github.com/werner-duvaud/muzero-general/blob/23a1f6910e97d78475ccd29576cdd107c5afefd2/self_play.py#L391-L393 This makes the first MCTS simulation less effective than it could be, and may harm performance if the simulation budget is limited. The MuZero paper's first author commented here that this is indeed a problem (Ctrl+F "When selecting among actions of the root, the root's visit count should already be 1.").
- `root_predicted_value` does not affect the root value. The root value is eventually stored in `game_history.root_values` in `store_search_statistics()`, and it could be made more precise by taking `root_predicted_value` into account.
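
The first consequence can be checked numerically. The sketch below reproduces only the prior (exploration) term of the UCB score, in the form used by the MuZero pseudocode. `prior_score` is a hypothetical helper name, and `pb_c_base=19652` and `pb_c_init=1.25` are assumed constants, not values read from this repository's config.

```python
import math


def prior_score(parent_visits, child_visits, prior,
                pb_c_base=19652, pb_c_init=1.25):
    """Prior term of the pUCT score (the value term is left out)."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    # sqrt(parent_visits) is the factor that vanishes when the root is unvisited
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior


priors = [0.7, 0.2, 0.1]  # made-up policy output for three root actions

# Current behaviour: root.visit_count == 0 on the first simulation
scores_zero = [prior_score(0, 0, p) for p in priors]

# Behaviour if the root had already been counted once
scores_one = [prior_score(1, 0, p) for p in priors]

print(scores_zero)
print(scores_one)
```

With a parent visit count of 0 every action gets the same score, so the first selection at the root cannot be guided by the priors. With a visit count of 1 the scores are ordered by prior.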
A potential fix would be to include `self.backpropagate([root], root_predicted_value, min_max_stats)` right after https://github.com/werner-duvaud/muzero-general/blob/23a1f6910e97d78475ccd29576cdd107c5afefd2/self_play.py#L303-L309
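
To illustrate the effect of such a fix, here is a minimal single-player sketch. The `Node` class and `backpropagate` function are simplified stand-ins written for this example, not the repository's code: they leave out `min_max_stats` and the sign handling for two-player games, and the rewards, values and discount are made-up numbers.

```python
class Node:
    """Toy node holding only the statistics used in the pseudocode above."""

    def __init__(self, reward=0.0):
        self.reward = reward
        self.value_sum = 0.0
        self.visit_count = 0

    def value(self):
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count


def backpropagate(search_path, value, discount):
    """Same update as the inner loop of the pseudocode (single-player case)."""
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        value = node.reward + discount * value


root = Node()
root_predicted_value = 0.5

# Proposed extra step: count the root's own prediction as one visit
backpropagate([root], root_predicted_value, discount=1.0)
visits_before_search = root.visit_count

# One simulation: a leaf with reward 1.0 and predicted value 0.0
leaf = Node(reward=1.0)
backpropagate([root, leaf], 0.0, discount=1.0)

print(visits_before_search, root.visit_count, root.value())
```

In this sketch the root already has a visit count of 1 before the first selection, and the stored root value becomes the average of `root_predicted_value` and the simulation returns instead of the simulation returns alone.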