
Alpha zero refactor (testing and polishing)

Open alexunderch opened this issue 4 months ago • 16 comments

Hello @lanctot !

We (@parvxh, @harmanagrawal, and I) present the first part of the AlphaZero refactor. We have rewritten the models using flax.linen and flax.nnx (not fully supported yet, but we'll complete it in the near future). Moreover, we added a replay buffer compatible with GPU execution.
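To give an idea of the buffer design, here is a minimal sketch under my assumptions about shapes and field names (illustrative, not the PR code verbatim): keep everything in preallocated device arrays and update in place, so both writes and sampling compose with jitted training steps.

```python
# Hypothetical sketch of a device-resident circular replay buffer;
# shapes and field names are illustrative, not the PR's actual code.
from functools import partial

import jax
import jax.numpy as jnp

def make_buffer(capacity, obs_shape, num_actions):
    return {
        "obs": jnp.zeros((capacity,) + obs_shape),
        "policy": jnp.zeros((capacity, num_actions)),
        "value": jnp.zeros((capacity,)),
        "ptr": jnp.array(0, dtype=jnp.int32),
        "size": jnp.array(0, dtype=jnp.int32),
    }

@jax.jit
def add(buf, obs, policy, value):
    i = buf["ptr"]
    capacity = buf["obs"].shape[0]
    return {
        "obs": buf["obs"].at[i].set(obs),
        "policy": buf["policy"].at[i].set(policy),
        "value": buf["value"].at[i].set(value),
        "ptr": (i + 1) % capacity,
        "size": jnp.minimum(buf["size"] + 1, capacity),
    }

@partial(jax.jit, static_argnames="batch_size")
def sample(buf, key, batch_size):
    # assumes the buffer is non-empty; sampling stays on-device
    idx = jax.random.randint(key, (batch_size,), 0, buf["size"])
    return buf["obs"][idx], buf["policy"][idx], buf["value"][idx]
```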

Problems we've faced:

  • multiprocessing vs. JAX: since JAX runs in multithreaded mode, it doesn't work with the fork start method, so we had to override it (see the sketch below). However, there are still spontaneous idle sections in the execution, which may be connected with our still-unclear usage of synchronisation primitives.
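For reference, the standard workaround is to force the spawn start method before any worker processes are created; a minimal self-contained sketch (not necessarily the exact code in the PR):

```python
# Minimal sketch: JAX is not fork-safe (its internal threads don't survive
# a fork), so force the "spawn" start method for worker processes.
import multiprocessing as mp

def actor_loop(queue):
    import jax.numpy as jnp  # import JAX inside the child, post-spawn
    queue.put(float(jnp.sum(jnp.arange(10))))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # "fork" can deadlock with JAX
    queue = ctx.Queue()
    p = ctx.Process(target=actor_loop, args=(queue,))
    p.start()
    print(queue.get())
    p.join()
```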

With this PR, we want to know whether we're headed in the right direction; we want to contribute to this long-standing problem rather than keep the solution process behind closed doors.

alexunderch avatar Aug 17 '25 11:08 alexunderch

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] avatar Aug 17 '25 11:08 google-cla[bot]

Awesome, thanks!

lanctot avatar Aug 18 '25 12:08 lanctot

Hi @alexunderch @parvxh @harmanagrawal ,

Just a quick question: why is it labeled "the first major part"? Because flax.nnx isn't fully working?

What's missing? Does this currently work with flax.linen? I'd really like to see a few graphs of it convincingly learning in a small game. Would that be possible to show? How long would it take on, e.g., Tic-Tac-Toe?

lanctot avatar Aug 19 '25 11:08 lanctot

Yes, exactly.

  1. flax.nnx is not fully working; some minor fixes (e.g. conv dimension calculation and tests) are left, see the note below this list
  2. the implementation does work; however, some problems arise due to multiprocessing (one of the processes stalls)
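For context on item 1 (my reading of the remaining issue, not necessarily exact): flax.linen.Conv infers the input channel count lazily on first call, while flax.nnx.Conv is eagerly initialised and needs in_features up front, so the dimension bookkeeping has to be done by hand:

```python
# Illustrative comparison (not the PR code): linen infers input channels
# lazily at init time, nnx needs them supplied explicitly.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

x = jnp.zeros((1, 3, 3, 3))  # NHWC, e.g. a tic_tac_toe observation

# linen: input channel dim is read off `x` during lazy init
linen_conv = nn.Conv(features=32, kernel_size=(2, 2))
params = linen_conv.init(jax.random.PRNGKey(0), x)

# nnx: in_features must be given, and the kernel is created eagerly
nnx_conv = nnx.Conv(in_features=3, out_features=32, kernel_size=(2, 2),
                    rngs=nnx.Rngs(0))
y = nnx_conv(x)
```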

The guys had their own priorities for submitting the PR, which is why we opened it now, but we'll finish all the fixes in the next week or two. Same with the Tic-Tac-Toe benchmarking.

Appreciate your patience.

alexunderch avatar Aug 19 '25 11:08 alexunderch

I see. Ok, I'll wait until it's fully working to take a deeper look. Feel free to use this PR branch to collaborate amongst yourselves (i.e. update as necessary) and then just let me know when you think it's ready. Thanks for doing this!

lanctot avatar Aug 19 '25 12:08 lanctot

Surely, we will notify you!

alexunderch avatar Aug 19 '25 12:08 alexunderch

@lanctot, for both APIs, linen and nnx, there are now tests, and they're passing. The only minor things left on the development side are ~~model export and the changelog~~ benchmarks.

I've run only one benchmark, on CPU, with a ridiculously small replay buffer/batch size of 256/2.

[figure: Figure_1]

P.S. Need to fix the rendering as well xD.

The code runs fine, so if you have the ability to run a test or two on a GPU, that would be fire. I hope to run them by the weekend.

alexunderch avatar Aug 27 '25 15:08 alexunderch

@lanctot, I ran the TTT experiment for a much longer time:

[figure: TTT experiment plots]

Doesn't look good, does it? Does the picture tell you what I should look at to find the bugs?

alexunderch avatar Aug 29 '25 10:08 alexunderch

I'm not sure. I will solicit the help of @tewalds : Timo, does this look like what you expect?

The top-left one looks ok...? Should they move closer to 0, though? I'm not sure. Maybe not for a small MLP? So far I'm not too concerned about that one.

I'm not sure what the top two are showing yet (or what the differences between 0 and 6 are), but I would expect accuracy to go up over time.

The top-right one I would expect to go up over time, but it doesn't seem to..? (But the one from two days ago does; maybe it's because it's learning to draw against MCTS-100 really fast, which is possible and would be good.)

Can you explain what you did differently from the graph two days ago? The one from two days ago seemed to use a very small (too small) replay buffer and batch size. Did you increase those in the second run? Also, how many MCTS simulations are you using?

Also, you said you let it run for longer, but I see roughly the same number of steps on the x-axis.

First step would be to, every epoch or training step, play 10 games against random as P0 and 10 games against random as P1, and dump those.. and let's inspect the games. We can also track the values versus random over time. If those don't go up, then something is very seriously wrong, but I roughly know what that graph should look like.
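Something like the following sketch via the pyspiel API (self-contained; `agent_action` is a hypothetical placeholder for whatever wraps the current network/MCTS, here playing randomly so the snippet runs as-is):

```python
# Sketch of the suggested sanity check using OpenSpiel's Python API.
# `agent_action` is a hypothetical stand-in for the trained agent.
import numpy as np
import pyspiel

rng = np.random.default_rng(0)

def agent_action(state):  # replace with network/MCTS action selection
    return rng.choice(state.legal_actions())

def play_vs_random(game, agent_player):
    state = game.new_initial_state()
    while not state.is_terminal():
        if state.current_player() == agent_player:
            action = agent_action(state)
        else:
            action = rng.choice(state.legal_actions())
        state.apply_action(int(action))
    return state.returns()[agent_player], str(state)

game = pyspiel.load_game("tic_tac_toe")  # no chance nodes to handle here
for agent_player in (0, 1):
    for _ in range(10):
        ret, board = play_vs_random(game, agent_player)
        print(f"P{agent_player} return: {ret}\n{board}\n")
```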

lanctot avatar Aug 29 '25 11:08 lanctot

@lanctot The main difference between the graphs is the buffer/batch size: 2 ** 16 and 2 ** 10, which were the default values for the model.

I use the default of averaging over 100 games, with 300 simulations each:

```json
{
  "actors": 2,
  "checkpoint_freq": 20,
  "device": "cpu",
  "eval_levels": 7,
  "evaluation_window": 100,
  "evaluators": 1,
  "game": "tic_tac_toe",
  "learning_rate": 0.001,
  "max_simulations": 300,
  "max_steps": 0,
  "nn_api_version": "linen",
  "nn_depth": 10,
  "nn_model": "mlp",
  "nn_width": 128,
  "observation_shape": [
    3,
    3,
    3
  ],
  "output_size": 9,
  "path": "checkpoints",
  "policy_alpha": 1.0,
  "policy_epsilon": 0.25,
  "quiet": true,
  "replay_buffer_reuse": 3,
  "replay_buffer_size": 65536,
  "temperature": 1.0,
  "temperature_drop": 10,
  "train_batch_size": 1024,
  "uct_c": 2,
  "weight_decay": 0.0001
}
```

alexunderch avatar Aug 29 '25 11:08 alexunderch

Will share some progress tomorrow; you may approve the checks later <3

alexunderch avatar Sep 01 '25 20:09 alexunderch

I guess we're making slight progress, aren't we?

[figure: updated training plots]

Give it a look, @lanctot.

alexunderch avatar Sep 03 '25 19:09 alexunderch

Hey @alexunderch, a bit swamped at the moment. I will need the help of @tewalds. I have emailed him and he says he can take a look at some point, but it may require a call / catch-up. I will be in touch. This will be slow on our side, sorry.

lanctot avatar Sep 03 '25 23:09 lanctot

The latest plots (minor tweaks and fixes here and there). Maybe, with much more resources (I used a toy config), there is something here:

[figure: latest plots]

alexunderch avatar Sep 27 '25 21:09 alexunderch

I found an example with hyperparameters for tic-tac-toe, and the results look somewhat more intuitive (although I had to reduce the batch size fourfold due to resource constraints):

[figures: Figure_linen and Figure_nnx]

alexunderch avatar Oct 07 '25 20:10 alexunderch

@lanctot, the TTT example now works much better (see the graphs; the win rate is pretty nice for an example config). However, I want to test with more steps on connect4 to make sure that everything fully works; just a couple more days, literally.

[figure: TTT results]

P.S. I am not sure L2 regularisation was implemented correctly in the original implementation: an L2 penalty added to the loss only matches weight decay under plain SGD; with the Adam optimiser the penalty's gradient gets rescaled by the adaptive moments, which is why decoupled weight decay (AdamW) exists...
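A quick sketch of the distinction with optax (illustrative; `model_loss` is a hypothetical placeholder for the network's loss function):

```python
# Two ways to regularise under Adam (sketch; not the PR's exact code).
# L2-in-the-loss gets rescaled by Adam's adaptive moments; decoupled
# weight decay (AdamW) is applied outside that scaling.
import jax
import jax.numpy as jnp
import optax

weight_decay = 1e-4

# Option A: L2 penalty added to the loss (interacts with Adam's scaling)
def loss_with_l2(params, batch):
    base = model_loss(params, batch)  # hypothetical loss function
    l2 = sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))
    return base + 0.5 * weight_decay * l2

opt_a = optax.adam(1e-3)

# Option B: decoupled weight decay, the usual meaning of "weight_decay"
opt_b = optax.adamw(1e-3, weight_decay=weight_decay)
```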

alexunderch avatar Dec 12 '25 18:12 alexunderch

I remain confused: I would expect value prediction accuracy to go up over time, but it only does for two of the runs?

lanctot avatar Dec 24 '25 11:12 lanctot

Yes, this is basically the only thing that still puzzles me: why value accuracy goes up only for a few non-consecutive runs... I'm now trying different hyperparameters to isolate the issue.

alexunderch avatar Dec 24 '25 11:12 alexunderch