
Extension of the neural network to predict score variance.

Open • lukaszlew opened this issue 3 years ago • 5 comments

It is possible to make the neural network predict not only a real number but also the uncertainty around it. This is done by replacing the typical L2 loss (i.e. the log of a normal distribution with a constant variance) with the log of a normal distribution whose variance is another output of the neural network. In that setup, variance in the data will be captured in the NN's predictions - for instance, when the data has several rows with the same or similar inputs but varied labels.
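As a concrete illustration of that loss (a minimal sketch in PyTorch, not code from this repo; the names are hypothetical):

```python
import torch

def heteroscedastic_loss(pred_mean, pred_log_var, target):
    # Negative log-likelihood of N(pred_mean, sigma^2), constants dropped:
    #   0.5 * (log sigma^2 + (target - mean)^2 / sigma^2)
    # Predicting log-variance keeps sigma^2 positive without an explicit constraint.
    inv_var = torch.exp(-pred_log_var)
    return 0.5 * (pred_log_var + (target - pred_mean) ** 2 * inv_var).mean()
```

The log sigma^2 term is what stops the network from inflating the variance everywhere just to shrink the squared-error term.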

There are several advantages to this approach.

  • The obvious one is that we can use the variance prediction as a prior in the search tree.
  • AFAIR we train the network on data from the search tree. The tree has components that the NN does not see directly (e.g. varying search times, tactical complexity). The variance prediction will capture that variance in the training data. This is exactly what we would want to direct future search; it might be more effective than the probability distribution over actual moves.
  • Even training of the mean itself will be more efficient. The training will push less gradient into examples with high predicted variance, because the squared error is divided by the variance: (pred - x)^2 / sigma^2 (see the sanity check after this list).
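A quick sanity check of the last point (a toy example under the loss sketched above, not KataGo code): the gradient with respect to the mean is (mean - x) / sigma^2, so high-variance examples push proportionally less gradient.

```python
import torch

x = torch.tensor(3.0)                          # label
for log_var in (0.0, 2.0):                     # sigma^2 = 1 vs sigma^2 = e^2
    mean = torch.tensor(0.0, requires_grad=True)
    lv = torch.tensor(log_var)
    loss = 0.5 * (lv + (x - mean) ** 2 * torch.exp(-lv))
    loss.backward()
    print(float(lv.exp()), float(mean.grad))   # gradient shrinks as 1/sigma^2
```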

There are some disadvantages:

  • Doubling the network's output count increases the size of the last layer.
  • It might require some MCTS tuning that in the end might not work out - wasted effort.
  • The variance 'head' might make the policy head obsolete, and removing the policy head could break some GUIs.

I think it is worth a try, given the conceptual simplicity and clear action items, and I'm interested in your thoughts. I can provide more details on the math.

lukaszlew avatar Apr 27 '22 02:04 lukaszlew

KataGo does predict the score variance of the final game result. This is accomplished by training an auxiliary head to predict the full distribution of final scores (note: the empirical distribution of final scores is often not a normal distribution, so this head is not constrained to predict only normal distributions). We then compute the variance of the predicted distribution and penalize a second auxiliary head via an L2 loss to agree with that variance. The second auxiliary head has no loss term coming from the data itself, only the loss term between it and the first head.
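For readers following along, here is a sketch of that two-head arrangement as I understand the description above (an illustration, not KataGo's actual training code; the names, the score-bin discretization, and the detach are my assumptions):

```python
import torch
import torch.nn.functional as F

def two_head_score_loss(score_logits, pred_variance, target_bins, bin_scores):
    # score_logits:  (batch, n_bins) logits over discrete final-score bins
    # pred_variance: (batch,) output of the second auxiliary head
    # target_bins:   (batch,) index of the bin each game's final score fell in
    # bin_scores:    (n_bins,) score value represented by each bin
    dist_loss = F.cross_entropy(score_logits, target_bins)

    # Variance implied by the *predicted* distribution: E[s^2] - (E[s])^2.
    probs = F.softmax(score_logits, dim=1)
    mean = probs @ bin_scores
    var = probs @ bin_scores.pow(2) - mean.pow(2)

    # The second head is tied to the first head's implied variance, not to the
    # data. detach() (my assumption) keeps this term from training the
    # distribution head itself.
    var_loss = F.mse_loss(pred_variance, var.detach())
    return dist_loss + var_loss
```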

I believe this formulation is theoretically sound - i.e. in the limit of infinite data, minimum loss is achieved when the prediction does in fact equal the score variance of the final game that would result from the distribution of play in training. This is because the first auxiliary head would achieve minimum loss if it exactly matched the empirical distribution of scores that would follow from a given gamestate, and the second auxiliary head would achieve zero loss due to this L2 term if it agreed in each instance exactly with the variance of the distribution predicted by the first head.

By manual inspection in various positions, this head definitely predicts pretty much what it was trained to, but I have not yet found many good uses for it. Can you describe your proposal? Given that this head already exists and already makes what seem to be decent predictions, I assume you have some ideas for how to use it that may not have been tried yet, or perhaps that by "score variance" you mean something other than what this existing head already predicts.

lightvector avatar Apr 27 '22 12:04 lightvector

By the way, if you'd like to play around with this prediction yourself to try to test how it behaves, it is the field scoreStdev documented at https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md

You can also see what other fields KataGo predicts. There are a few of them.

For a 1-visit analysis request, scoreStdev will be the approximate standard deviation of the predicted final score distribution from the raw neural net. For a greater-than-1-visit request, it will be the standard deviation of scores implied by picking a random visit-weighted path down the MCTS tree and then taking scoreStdev into account when hitting a leaf. Due to the way MCTS explores, this tends to bias the value higher than the actual final score distribution's standard deviation, but holding visits constant and comparing relative values still meaningfully distinguishes positions with higher final score uncertainty from ones with lower.
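For example, a minimal 1-visit query against the analysis engine might look like this (a sketch: the binary, model, and config paths are placeholders, and it assumes scoreStdev appears in rootInfo as per the doc linked above):

```python
import json
import subprocess

# Start the analysis engine; paths here are placeholders.
proc = subprocess.Popen(
    ["katago", "analysis", "-model", "model.bin.gz", "-config", "analysis.cfg"],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)

query = {
    "id": "scorestdev-test",
    "moves": [["B", "Q16"], ["W", "D4"]],
    "rules": "tromp-taylor",
    "komi": 7.5,
    "boardXSize": 19,
    "boardYSize": 19,
    "maxVisits": 1,  # 1 visit: scoreStdev comes from the raw net, per the above
}
proc.stdin.write(json.dumps(query) + "\n")
proc.stdin.flush()
response = json.loads(proc.stdout.readline())
print(response["rootInfo"]["scoreStdev"])
```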

lightvector avatar Apr 27 '22 12:04 lightvector

I'm not sure how the head outputs are used in the tree search today, but assuming it is some variant of UCB, it makes sense to add the NN-predicted standard deviation to the mean and other terms (with some coefficient) to bias the search toward high-variance, uncertain options.
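Concretely, something like this PUCT-style selection rule (a hypothetical sketch of the proposal, not KataGo's actual formula; the coefficient beta would need tuning):

```python
import math

def select_child(children, c_puct=1.5, beta=0.5):
    # children: dicts with prior P, visit count n, mean value q, and the
    # NN-predicted standard deviation sigma for the child position.
    total_n = sum(ch["n"] for ch in children)
    def ucb(ch):
        explore = c_puct * ch["P"] * math.sqrt(total_n) / (1 + ch["n"])
        return ch["q"] + explore + beta * ch["sigma"]  # variance bonus term
    return max(children, key=ucb)
```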

lukaszlew avatar Apr 30 '22 21:04 lukaszlew

Right now the predicted score variance is basically unused in the tree search, because it doesn't seem to be useful. For example, "high score variance" often does not indicate that a position is worth extra search effort - it partly correlates with that, but not ideally well.

Or to put it mathematically, we have different notions of variance:

  • The variance from MCTS's perspective within the UCB algorithm: statistically, how do the score estimates vary between successive playouts within a search? (note: this is nonstationary and changes over time during the search!)
  • The variance of the difference between the current score estimate and the score that would be achieved by optimal play, or by the play of an agent with 10000x as much compute as you (i.e. significantly closer to optimal than you).
  • The variance of the final score: i.e. statistically, how would the score vary if you in theory did whole-game rollouts all the way to the end of the game.

All three notions are distinct. The second notion of variance is arguably what you care about, but there is no fast estimator for it. And the third notion of variance correlates poorly enough with the first two that I haven't found a good way to use it for anything, even though the net predicts it. Do you have any ideas?

Additionally, the contribution of score to KataGo's utility function is only moderate - it isn't too important for KataGo's playing quality in even games anyway. The strongest play is achieved when the dominant component of the utility function is winrate, not score. So maybe one can't hope for too much from using score; you need alternative notions that work more closely with winrate.

One other major obstacle: if you make any adjustment to the tree search exploration, you have to make a counterbalancing change to MCTS's value averaging and to the final action selection method. If you don't counterbalance, you introduce a negative bias into the value estimates that the parent nodes recursively depend on. For example, if you search high-variance moves more, you will tend to put higher weight on low-expected-value but high-uncertainty moves, biasing downward the expected value you report to the parent and damaging the quality of the search higher up in the tree. Similarly, you also need to counterbalance the formulas that determine the final policy for which move to choose.
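A toy illustration of that bias (made-up numbers): two children with true values 0.6 and 0.4, where the second is high-variance. If the variance bonus shifts visits toward the second child without any counterbalancing in the value average, the value reported to the parent drops even though nothing about the position changed:

```python
def backed_up_value(visits, values):
    # Standard MCTS value averaging: visit-weighted mean of child values.
    return sum(n * q for n, q in zip(visits, values)) / sum(visits)

values = [0.6, 0.4]                        # true child values
print(backed_up_value([90, 10], values))   # plain search:      0.58
print(backed_up_value([60, 40], values))   # variance-boosted:  0.52 (biased low)
```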

Do you have thoughts on how you would handle the counterbalancing? I'm not bringing up this last obstacle to be contrarian; it is a recurring issue that repeatedly gets in the way of my and many other people's attempts to improve MCTS. Many simple, intuitive ideas like "search more uncertain moves more" or "search forcing moves more", etc. become very hard to implement because of this issue, and I don't know of a general and cheap way to handle it.

Even training of the mean itself will be more efficient. The training will push less gradient into examples with high predicted variance, because the squared error is divided by the variance: (pred - x)^2 / sigma^2.

Can you elaborate on this? It seems like you have some ideas about how to construct the training differently than KataGo does now, which I would consider testing out since right now I'm actually trying to train a new generation of neural nets.

lightvector avatar May 01 '22 06:05 lightvector

As a note on

Do you have thoughts on how you would handle the counterbalancing? I'm not bringing up this last obstacle to be contrarian; it is a recurring issue that repeatedly gets in the way of my and many other people's attempts to improve MCTS. Many simple, intuitive ideas like "search more uncertain moves more" or "search forcing moves more", etc. become very hard to implement because of this issue, and I don't know of a general and cheap way to handle it.

One promising thing I know of is https://arxiv.org/abs/2007.12509, but there are some issues to think about regarding how this kind of method scales to high playouts (it doesn't handle correlated bias in Q values at high playouts as well as I would like, nor take into account the uncertainty of the Q-value estimates), plus it's computationally a bit expensive CPU-wise to run this computation a lot.
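For reference, the core step in that paper (Grill et al., "Monte-Carlo Tree Search as Regularized Policy Optimization") is solving for the policy maximizing pi·Q - lambda_N * KL(prior, pi); the solution has the form pi(a) = lambda_N * prior(a) / (alpha - Q(a)), with alpha found by a 1-D search so that pi sums to 1. A sketch of that normalization step (my implementation; lam stands in for the paper's visit-dependent lambda_N):

```python
def regularized_policy(q, prior, lam, iters=60):
    # Solve pi(a) = lam * prior(a) / (alpha - q(a)) with alpha > max(q)
    # chosen so that pi sums to 1. The sum is strictly decreasing in alpha,
    # so a simple bisection converges.
    lo, hi = max(q) + 1e-12, max(q) + lam   # sum(pi) straddles 1 on [lo, hi]
    for _ in range(iters):
        alpha = (lo + hi) / 2
        total = sum(lam * p / (alpha - qa) for p, qa in zip(prior, q))
        lo, hi = (alpha, hi) if total > 1 else (lo, alpha)
    return [lam * p / (alpha - qa) for p, qa in zip(prior, q)]

# e.g. regularized_policy(q=[0.5, 0.3], prior=[0.6, 0.4], lam=0.2)
```

This bisection is the per-node CPU cost mentioned above: it has to run at every selection step of the search.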

lightvector avatar May 01 '22 06:05 lightvector