muzero-general

Why is root.visit_count initialized to 0 and root_predicted_value not included in root node value?

Open dniku opened this issue 2 years ago • 0 comments

The MCTS implementation here works roughly like this (pseudocode):

def mcts(observation):
    root_predicted_value, stuff = model.initial_inference(observation)
    root = Node()
    root.expand(stuff)  # no backpropagation follows, so root.visit_count stays 0
    root.add_exploration_noise()

    for _ in range(num_simulations):
        leaf = find_unexpanded_leaf()  # here UCB formula depends on root.visit_count
        leaf_predicted_value, stuff = model.recurrent_inference(leaf.hidden_state)
        leaf.expand(stuff)

        value = leaf_predicted_value
        for node in reversed([root, ..., leaf]):
            node.value_sum += value
            node.visit_count += 1
            value = node.reward + discount * value

# ... later in store_search_statistics() ...
game_history.root_values.append(root.value_sum / root.visit_count)

Note that every expansion is followed by a backpropagation pass that updates the root's value sum and visit count, except for the very first expansion, that of the root itself. There are two consequences to this:

  • When searching for an unexpanded leaf for the first time, prior probabilities are effectively discarded, because the exploration term of the UCB formula has the parent's visit count in the numerator and is therefore zero while root.visit_count is 0: https://github.com/werner-duvaud/muzero-general/blob/23a1f6910e97d78475ccd29576cdd107c5afefd2/self_play.py#L391-L393 This makes the first MCTS simulation less effective than it could be, and may harm performance if the simulation budget is limited. The MuZero paper's first author commented here that this is indeed a problem (Ctrl+F "When selecting among actions of the root, the root's visit count should already be 1.").
  • root_predicted_value does not affect the root value. That value is eventually stored in game_history.root_values by store_search_statistics(), and it could be made more precise by taking root_predicted_value into account.
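
To illustrate the first point, here is a minimal standalone sketch of the prior (exploration) term of a pUCT-style score. It is not copied from the repository: the function name `prior_score` is hypothetical, and the formula and the constants `pb_c_base=19652` and `pb_c_init=1.25` are assumed to follow the pseudocode published with the MuZero paper.

```python
import math


def prior_score(parent_visits, child_visits, prior,
                pb_c_base=19652, pb_c_init=1.25):
    """Exploration term of a pUCT-style score (assumed form, see lead-in)."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    # The parent's visit count sits in the numerator, under a square root.
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior


priors = [0.7, 0.2, 0.1]  # made-up policy priors for three root actions

# Root visit count 0: every action gets the same exploration term (zero),
# so the priors cannot influence the first selection.
scores_at_zero = [prior_score(0, 0, p) for p in priors]

# Root visit count 1: the exploration terms are ordered by the priors.
scores_at_one = [prior_score(1, 0, p) for p in priors]

print(scores_at_zero)
print(scores_at_one)
```

With unvisited children the value term is the same for all of them as well, so under these assumptions the first selection at the root cannot prefer the action with the highest prior.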

A potential fix would be to include

self.backpropagate([root], root_predicted_value, min_max_stats)

right after https://github.com/werner-duvaud/muzero-general/blob/23a1f6910e97d78475ccd29576cdd107c5afefd2/self_play.py#L303-L309
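
The effect of that extra call can be sketched with a toy example. This is only an illustration under simplifying assumptions: the `Node` class and `backpropagate` function below are hypothetical stand-ins that mirror the pseudocode above, and they omit `min_max_stats` and everything else the real implementation tracks.

```python
class Node:
    """Toy search node holding only the statistics used in this sketch."""

    def __init__(self, reward=0.0):
        self.reward = reward
        self.value_sum = 0.0
        self.visit_count = 0

    def value(self):
        # Mean of the backed-up values; 0.0 for an unvisited node.
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def backpropagate(search_path, value, discount):
    """Back up `value` from the end of the path to the root."""
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        value = node.reward + discount * value


discount = 0.997
root = Node()
root_predicted_value = 0.5  # made-up output of initial_inference

# Proposed fix: seed the root with its own predicted value.
backpropagate([root], root_predicted_value, discount)
visits_after_seed = root.visit_count  # 1, so the UCB prior term is non-zero
value_after_seed = root.value()       # 0.5

# One simulation reaching a leaf with reward 1.0 and predicted value 0.0.
leaf = Node(reward=1.0)
backpropagate([root, leaf], 0.0, discount)

print(root.visit_count, root.value())  # 2 0.75
```

In this toy run the root value stored after one simulation is the average of the network's prediction (0.5) and the simulated return (1.0), rather than the simulated return alone.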

dniku, Feb 20 '22 16:02