AlphaZero refactor (testing and polishing)
Hello @lanctot!
We (@parvxh, @harmanagrawal, and I) present the first part of the AlphaZero refactor. With major help from the guys, we have rewritten the models using flax.linen and flax.nnx (nnx support isn't complete yet, but we'll finish it in the near future). We have also added a replay buffer compatible with GPU execution.
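A minimal sketch of what a GPU-compatible replay buffer can look like, assuming a functional design: fixed-size, preallocated jnp arrays updated with `.at[...].set` and sampled with `jax.random`, so both operations can be jitted and the data stays on the accelerator. The `ReplayBuffer` fields and the `init_buffer`/`add`/`sample` names are hypothetical and not the PR's API.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ReplayBuffer(NamedTuple):
    observations: jax.Array  # [capacity, obs_dim]
    policies: jax.Array      # [capacity, num_actions] (MCTS visit distributions)
    values: jax.Array        # [capacity] (game outcomes)
    insert_pos: jax.Array    # scalar int32: next slot to overwrite
    size: jax.Array          # scalar int32: number of valid entries


def init_buffer(capacity: int, obs_dim: int, num_actions: int) -> ReplayBuffer:
    return ReplayBuffer(
        observations=jnp.zeros((capacity, obs_dim), jnp.float32),
        policies=jnp.zeros((capacity, num_actions), jnp.float32),
        values=jnp.zeros((capacity,), jnp.float32),
        insert_pos=jnp.array(0, jnp.int32),
        size=jnp.array(0, jnp.int32),
    )


@jax.jit
def add(buf: ReplayBuffer, obs, policy, value) -> ReplayBuffer:
    """Write one transition into the ring buffer, overwriting the oldest entry."""
    capacity = buf.values.shape[0]  # shapes are static under jit
    i = buf.insert_pos
    return ReplayBuffer(
        observations=buf.observations.at[i].set(obs),
        policies=buf.policies.at[i].set(policy),
        values=buf.values.at[i].set(value),
        insert_pos=(i + 1) % capacity,
        size=jnp.minimum(buf.size + 1, capacity),
    )


@partial(jax.jit, static_argnames="batch_size")
def sample(buf: ReplayBuffer, key: jax.Array, batch_size: int):
    """Uniformly sample a batch from the valid entries (call only when size > 0)."""
    idx = jax.random.randint(key, (batch_size,), 0, buf.size)
    return buf.observations[idx], buf.policies[idx], buf.values[idx]
```

Fixed shapes are the key design choice: `add` compiles once, and `sample` compiles once per batch size, because the capacity never changes.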
Problems we've faced:
multiprocessing vs jax combo: since jax runs in multithreaded mode, it doesn't work with the fork start method, so we had to override it. However, there are still spontaneous idle sections in the execution, which may be caused by unclear usage of synchronisation primitives.
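A minimal sketch of the workaround, assuming actors are plain `multiprocessing.Process` workers (`actor_loop`, `get_mp_context`, and `run_actors` are hypothetical names, not the PR's code). The "spawn" start method gives each child a fresh interpreter instead of forking a parent whose JAX runtime already has threads running. Draining the queue before `join()` also avoids a documented multiprocessing deadlock, which may be worth ruling out as a source of the idle sections.

```python
import multiprocessing as mp
from typing import List, Tuple


def actor_loop(actor_id: int, queue) -> None:
    """Hypothetical actor body. A real actor would import JAX here, inside the
    child, then play self-play games and push trajectories to the queue."""
    queue.put(("trajectory", actor_id))


def get_mp_context():
    # "spawn" starts each child as a fresh interpreter rather than copying
    # the parent process (and its running JAX threads) the way "fork" does.
    return mp.get_context("spawn")


def run_actors(num_actors: int = 2) -> List[Tuple[str, int]]:
    ctx = get_mp_context()
    queue = ctx.Queue()
    procs = [ctx.Process(target=actor_loop, args=(i, queue)) for i in range(num_actors)]
    for p in procs:
        p.start()
    # Drain the queue *before* join(): a child that has put items on a queue
    # does not exit until they are flushed, so joining first can deadlock.
    results = [queue.get() for _ in procs]
    for p in procs:
        p.join()
    return sorted(results, key=lambda r: r[1])
```

With "spawn", `run_actors()` has to be called from under an `if __name__ == "__main__":` guard, because each child re-imports the main module.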
With this PR, we want to find out whether we're heading in the right direction, and to work on this long-standing problem in the open rather than behind closed doors.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
Awesome, thanks!
Hi @alexunderch @parvxh @harmanagrawal,
Just a quick question: why is it labeled "the first major part"? Because flax.nnx isn't fully working?
What's missing? Does this currently work with flax.linen? I'd really like to see a few graphs of it convincingly learning a small game. Would that be possible to show? How long would it take on, e.g., Tic-Tac-Toe?
Yes, exactly.
flax.nnx is not fully working yet; there are some minor fixes left (e.g. the conv dimension calculation and tests). The implementation does work, but there are some problems due to multiprocessing (one of the processes stalls).
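Presumably the conv dimension issue comes from the fact that, unlike linen's `nn.Dense`, which infers its input size at init, `nnx.Linear` takes `in_features` explicitly, so the flattened size of the conv trunk has to be computed up front. A minimal sketch of that calculation, assuming "SAME"/"VALID" padding semantics as in XLA convolutions with no dilation; the helper names are hypothetical and not from the PR.

```python
import math
from typing import Sequence, Tuple

# (kernel_size, stride, padding) for each conv layer; padding is "SAME" or "VALID".
ConvSpec = Tuple[int, int, str]


def conv_out_size(size: int, kernel: int, stride: int, padding: str) -> int:
    """Spatial output size of one conv layer, using XLA/TF padding semantics."""
    if padding == "SAME":
        return math.ceil(size / stride)
    if padding == "VALID":
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError(f"unsupported padding: {padding!r}")


def flattened_conv_features(height: int, width: int,
                            layers: Sequence[ConvSpec], out_channels: int) -> int:
    """Size of the flattened conv output, i.e. in_features of the first dense layer."""
    for kernel, stride, padding in layers:
        height = conv_out_size(height, kernel, stride, padding)
        width = conv_out_size(width, kernel, stride, padding)
    return height * width * out_channels
```

For a 3x3 board with two stride-1 "SAME" 3x3 convs and 64 output channels this gives 3 * 3 * 64 = 576, the value that would be passed as `in_features` to the first `nnx.Linear`.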
The lads had their own priorities for submitting the PR now, which is why we did it, but we'll finish all the fixes in the next week or two. The same goes for benchmarking on Tic-Tac-Toe.
Appreciate your patience.
I see. Ok, I'll wait until it's fully working to take a deeper look. Feel free to use this PR branch to collaborate amongst yourselves (i.e. update as necessary) and then just let me know when you think it's ready. Thanks for doing this!
Sure, we will notify you!
@lanctot, there are now tests for both APIs (linen and nnx), and they pass. The only minor things left on the development side are ~~model export and the changelog~~ benchmarks.
So far I've run only one benchmark, on CPU, with a ridiculously small replay buffer/batch size of 256/2.
P.S. I need to fix rendering as well xD.
The code runs fine, so if you're able to run a test or two on a GPU, that would be fire. I hope to run them myself by the weekend.
@lanctot, I ran a TTT experiment for much longer.
Doesn't look good, does it? Do the plots suggest where I should look for bugs?
I'm not sure. I will solicit the help of @tewalds : Timo, does this look like what you expect?
The top-left one looks OK...? I'm not sure whether they should move closer to 0, though. Maybe not for a small MLP? So far I'm not too concerned about that one.
I'm not sure yet what the top two are showing (or what the differences between 0 and 6 are), but I would expect accuracy to go up over time.
The top-right one I would expect to go up over time, but it doesn't seem to...? (The one from two days ago does, though; maybe it's because it's learning to draw against MCTS-100 really fast, which is possible and would be good.)
Can you explain what you did differently from the graph two days ago? The run from two days ago seemed to use a very small (too small) replay buffer and batch size. Did you increase those in the second run? Also, how many MCTS simulations are you using?
Also, you said you let it run for longer, but I see roughly the same number of steps on the x-axis.
The first step would be to play, every epoch or training step, 10 games against a random player as P0 and 10 games against random as P1, dump those, and inspect the games. We can also track the values versus random over time. If those don't go up, then something is very seriously wrong, but I roughly know what that graph should look like.
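A minimal sketch of that check, using a self-contained Tic-Tac-Toe so it runs without OpenSpiel. The `agent` argument stands in for the trained MCTS + network player; `eval_vs_random`, `play_game`, and the `(board, player, rng) -> action` agent signature are hypothetical, not part of the PR.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Board = List[Optional[int]]  # 9 cells: None = empty, 0 = P0's mark, 1 = P1's mark
Agent = Callable[[Board, int, random.Random], int]  # (board, player, rng) -> cell index
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]


def winner(board: Board) -> Optional[int]:
    """Return the player who completed a line, or None."""
    for a, b, c in LINES:
        if board[a] is not None and board[a] == board[b] == board[c]:
            return board[a]
    return None


def legal_actions(board: Board) -> List[int]:
    return [i for i, cell in enumerate(board) if cell is None]


def random_agent(board: Board, player: int, rng: random.Random) -> int:
    return rng.choice(legal_actions(board))


def play_game(agents: List[Agent], rng: random.Random) -> Tuple[Optional[int], List[int]]:
    """Play one game; agents[p] moves for player p. Returns (winner, move list)."""
    board: Board = [None] * 9
    moves: List[int] = []
    player = 0
    while winner(board) is None and legal_actions(board):
        action = agents[player](board, player, rng)
        board[action] = player
        moves.append(action)
        player = 1 - player
    return winner(board), moves


def eval_vs_random(agent: Agent, num_games: int = 10, seed: int = 0):
    """Play num_games as P0 and num_games as P1 against a uniform random opponent."""
    rng = random.Random(seed)
    stats: Dict[str, Tuple[int, int, int]] = {}
    dump: List[Tuple[int, Optional[int], List[int]]] = []
    for seat in (0, 1):
        wins = draws = losses = 0
        for _ in range(num_games):
            agents: List[Agent] = [random_agent, random_agent]
            agents[seat] = agent
            result, moves = play_game(agents, rng)
            dump.append((seat, result, moves))  # keep full games for inspection
            if result is None:
                draws += 1
            elif result == seat:
                wins += 1
            else:
                losses += 1
        stats[f"as_p{seat}"] = (wins, draws, losses)
    return stats, dump
```

Logging the `(wins, draws, losses)` tuples at every training step gives a curve versus random over time, and `dump` holds the move lists to inspect by hand.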
@lanctot, the main difference between the graphs is the buffer/batch size: 2 ** 16 and 2 ** 10, which are the default values for the model.
I use the default evaluation setup, averaging over 100 games (evaluation_window), with 300 MCTS simulations per move (max_simulations):
{
"actors": 2,
"checkpoint_freq": 20,
"device": "cpu",
"eval_levels": 7,
"evaluation_window": 100,
"evaluators": 1,
"game": "tic_tac_toe",
"learning_rate": 0.001,
"max_simulations": 300,
"max_steps": 0,
"nn_api_version": "linen",
"nn_depth": 10,
"nn_model": "mlp",
"nn_width": 128,
"observation_shape": [
3,
3,
3
],
"output_size": 9,
"path": "checkpoints",
"policy_alpha": 1.0,
"policy_epsilon": 0.25,
"quiet": true,
"replay_buffer_reuse": 3,
"replay_buffer_size": 65536,
"temperature": 1.0,
"temperature_drop": 10,
"train_batch_size": 1024,
"uct_c": 2,
"weight_decay": 0.0001
}
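The step gap on the x-axis can be estimated from this config. A minimal sketch, assuming the learner keeps the scheme of OpenSpiel's original alpha_zero.py (wait for replay_buffer_size // replay_buffer_reuse new states, then run len(buffer) // train_batch_size updates) and assuming the 256/2 run also used replay_buffer_reuse = 3; the helper name is hypothetical.

```python
import math


def learner_step_budget(replay_buffer_size: int, replay_buffer_reuse: int,
                        train_batch_size: int, max_game_length: int = 9) -> dict:
    """Rough per-step cost of the learner loop (assumed OpenSpiel-style scheme)."""
    # Assumption: before each learner step, wait for this many *new* states.
    new_states = replay_buffer_size // replay_buffer_reuse
    # Assumption: each step runs len(buffer) // train_batch_size gradient
    # updates; the buffer is taken to be full here.
    updates = replay_buffer_size // train_batch_size
    # A Tic-Tac-Toe game yields at most 9 states, so this is a lower bound.
    min_games = math.ceil(new_states / max_game_length)
    return {"new_states": new_states, "updates_per_step": updates,
            "min_games_per_step": min_games}
```

Under these assumptions, the 2 ** 16 / 2 ** 10 run needs at least 2428 self-play games (21845 new states) from the two actors before every learner step, versus about 10 games for the 256/2 run, so a similar wall-clock budget yields far fewer steps on the x-axis.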
I will share some progress tomorrow; you can approve the checks later <3
I guess we're making slight progress, aren't we?
Give it a look, @lanctot.
Hey @alexunderch, I'm a bit swamped at the moment. I will need the help of @tewalds. I have emailed him, and he says he can take a look at some point but may need a call / catch-up. I will be in touch. This will be slow on our side, sorry.
The latest plots (minor tweaks and fixes here and there). Maybe, with much more compute (I used a toy config), there is something here:
I found an example set of hyperparameters for Tic-Tac-Toe, and the results look somewhat more intuitive (although I had to reduce the batch size fourfold due to resource constraints).
@lanctot, the TTT example now works much better (see the graphs; the win rate is pretty good for an example config). However, I want to run Connect Four for more steps to make sure everything fully works; that should take a couple more days, literally.
P.S. I am not sure L2 regularisation was implemented correctly in the original implementation: adding an L2 penalty to the loss is equivalent to weight decay with SGD, but not with the Adam optimiser...
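To illustrate the difference, here is a minimal pure-Python sketch with a single scalar weight, a hand-written Adam update, and the task gradient set to zero so that only the regulariser acts (all names are hypothetical; this is neither the OpenSpiel nor the optax code). With the L2 term in the loss, Adam normalises the penalty's gradient, so the weight shrinks by roughly the learning rate per step whatever `weight_decay` is; with decoupled (AdamW-style) decay, the shrinkage scales with `weight_decay`.

```python
import math


def adam_step(w, g, m, v, t, lr, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update of a scalar weight w given gradient g at step t (1-based)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def train(weight_decay, decoupled, lr=1e-3, steps=100, w0=1.0):
    """Return the final weight when only the regulariser acts (task gradient = 0)."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        task_grad = 0.0
        if decoupled:
            # AdamW-style: Adam only sees the task gradient; decay is applied directly.
            w_adam, m, v = adam_step(w, task_grad, m, v, t, lr)
            w = w_adam - lr * weight_decay * w
        else:
            # L2 in the loss: d/dw (weight_decay / 2 * w**2) = weight_decay * w,
            # which Adam then rescales by 1 / sqrt(v_hat).
            w, m, v = adam_step(w, task_grad + weight_decay * w, m, v, t, lr)
    return w
```

In this sketch, `train(1e-4, False)` and `train(1e-2, False)` end up almost identical (around 0.9), while the decoupled runs stay near 1 and differ by `weight_decay`. In optax, the decoupled variant corresponds to `optax.adamw`.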
I remain confused: I would expect value prediction accuracy to go up over time, but it only does so for two of the runs?
Yes, this is pretty much the only thing that still puzzles me: why value accuracy goes up only for a few non-consecutive runs... I'm now trying different hyperparameters to isolate the issue.
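One thing that may be worth ruling out is the metric itself. A minimal sketch, assuming value accuracy is measured as sign agreement between the predicted value and the final outcome, bucketed by game stage (the function names, the 7 stages, and the `tol` dead zone are hypothetical, not taken from the PR). With a strict sign test, a small nonzero prediction for a drawn game never counts as correct, so if Tic-Tac-Toe self-play drifts toward draws, such a metric can stall even while the value head improves.

```python
import math
from typing import List, Sequence


def sign(x: float, tol: float = 0.0) -> int:
    """Sign with an optional dead zone: |x| <= tol is treated as 0 (a draw)."""
    if x > tol:
        return 1
    if x < -tol:
        return -1
    return 0


def staged_value_accuracy(predictions: Sequence[Sequence[float]],
                          outcomes: Sequence[float],
                          num_stages: int = 7, tol: float = 0.0) -> List[float]:
    """Per-stage fraction of states whose predicted value sign matches the outcome.

    predictions[g][i] is the value predicted at move i of game g and outcomes[g]
    the final return, both from the same (hypothetical) player's perspective.
    """
    hits = [0] * num_stages
    counts = [0] * num_stages
    for preds, outcome in zip(predictions, outcomes):
        for i, value in enumerate(preds):
            stage = i * num_stages // len(preds)  # bucket by relative game progress
            counts[stage] += 1
            hits[stage] += int(sign(value, tol) == sign(outcome))
    return [h / c if c else math.nan for h, c in zip(hits, counts)]
```

For a drawn game with predictions `[0.05, -0.02, 0.01]`, the strict version scores 0 at every stage, while `tol=0.1` scores 1 everywhere.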