AlphaZero.jl
Refactoring MCTS without recursion gives a modest-to-significant speedup
According to `scripts/profile/inference.jl`, I got the png below, which shows that on my machine (2080ti GPU, 28-core CPU) a batch size of 512 is optimal, costing about 1.6e-5 s per example.
But setting `batch_size=512` and `num_workers=1024` in `games/connect-four/params.jl` just makes things worse: it is nearly 2x slower than `batch_size=64, num_workers=128`, and GPU utilization sits between 0-40%, instead of the 50% seen with `batch_size=64`.
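For concreteness, the measurement is roughly of this shape (a minimal sketch rather than the actual profiling script; the network architecture and input shape below are made up):

```julia
# Minimal sketch of measuring per-example inference cost vs. batch size.
# The network and input width are placeholders, not AlphaZero.jl's model.
using Flux, CUDA

model = Chain(Dense(128 => 256, relu), Dense(256 => 7)) |> gpu

for batch_size in (64, 128, 256, 512, 1024)
    x = CUDA.rand(Float32, 128, batch_size)
    model(x)                              # warm-up (compilation)
    t = @elapsed CUDA.@sync model(x)      # synchronize so GPU time is counted
    println("batch_size=$batch_size: $(t / batch_size) s per example")
end
```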
It took me several weeks to analyze the problem, and I finally got somewhere. After refactoring MCTS without recursion, everything seems fine. With `batch_size=64` it is 0-20% faster than the old MCTS, and with `batch_size=512` it is also 0-20% faster. I have no idea why; maybe somebody can help.
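To show what I mean by "without recursion", here is a minimal sketch of the idea (illustrative only, not the code from my branch; `Node`, `ucb`, and `rollout_value` are simplified placeholders): the path from the root is recorded explicitly, and the value is propagated back with a loop instead of returning up a call stack.

```julia
# Simplified non-recursive MCTS simulation: selection walks down the tree
# iteratively, and backpropagation loops over the recorded path.
mutable struct Node
    children::Vector{Node}
    visits::Int
    total_value::Float64
end

Node() = Node(Node[], 0, 0.0)

ucb(parent, child) =
    child.total_value / max(child.visits, 1) +
    sqrt(2 * log(max(parent.visits, 1)) / max(child.visits, 1))

function simulate!(root::Node, rollout_value::Function)
    path = Node[root]
    node = root
    # Selection: descend iteratively, recording the path from the root.
    while !isempty(node.children)
        node = argmax(c -> ucb(node, c), node.children)
        push!(path, node)
    end
    v = rollout_value(node)  # expansion/evaluation is left abstract here
    # Backpropagation: a loop over the path replaces the recursive call stack.
    for n in reverse(path)
        n.visits += 1
        n.total_value += v
    end
end
```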
I also found that 4 threads for the CPU workers works best; more threads just make things worse. I suspect this is related to Julia's thread/task scheduling algorithms.
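By "4 threads for the CPU workers" I mean roughly this (a simplified sketch; `run_simulations!` is a hypothetical stand-in for the per-worker self-play/MCTS loop):

```julia
# Cap the number of concurrent worker tasks at a fixed count, regardless of
# Threads.nthreads().
const NUM_CPU_WORKERS = 4

function run_workers(run_simulations!::Function)
    tasks = [Threads.@spawn run_simulations!(i) for i in 1:NUM_CPU_WORKERS]
    foreach(wait, tasks)
end
```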
Here is my code: https://github.com/magicly/AlphaZero.jl
After @jonathan-laurent has reviewed it, if it's ok, I would like to make a PR.
I tested on three machines (2080, 2080ti, 3080ti); below is some test data. Some data is missing because I didn't have enough time to finish the experiments. In fact, after several weeks of work, I got a little frustrated.
There are still some things to do:
- [ ] when `batch_size=512`, the GPU is only 60% utilized instead of 90%+, and the average inference time is 2.7e-5 s instead of 1.6e-5 s
- [ ] using fp16 might yield another 2x improvement (a rough sketch of the idea follows this list)
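For the fp16 item, the idea would be something like the following (untested sketch: recent Flux versions expose `f16` alongside `f32`/`f64`, the model here is a placeholder, and the actual speedup depends on the GPU's half-precision throughput):

```julia
using Flux, CUDA

# Placeholder model; the real network comes from the AlphaZero.jl params.
model = Chain(Dense(128 => 256, relu), Dense(256 => 7))
model16 = model |> f16 |> gpu        # convert weights to Float16, then move to GPU

x16 = CUDA.rand(Float16, 128, 512)   # Float16 inputs to match the weights
y = model16(x16)
```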
First of all, thank you very much for investigating this. Before we try and merge this work, I would like to understand better why you are observing these improvements.
My first reaction is that this is a bit counterintuitive to me. Indeed, I did not expect making MCTS nonrecursive to make a difference, especially since the typical call stack depth in the current implementation is pretty low.
Is this due to your new MCTS implementation being faster, or is it mostly due to reduced memory consumption? Does this indicate that my version is consuming too much memory by keeping some big objects on the stack for too long, or does it point to a batching inefficiency? Or is it simply a weird artifact of how the GC or the task scheduling algorithms work?
To understand this better, I suggest that you look at these things:
- Use this new visualization functionality I just developed to investigate the difference between both versions.
- Profile vanilla MCTS to compare both versions?
- Compare the memory footprint of both versions on a short run (e.g. along the lines of the sketch below).
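For the memory comparison, even something as simple as this would be informative (a rough sketch; `run_mcts!` is a placeholder for a short, fixed-seed run of either version):

```julia
# Compare allocation volume and GC time of two MCTS variants on the same run.
function measure(run_mcts!::Function)
    GC.gc()                     # start from a clean heap
    stats = @timed run_mcts!()
    println("allocated: $(stats.bytes / 1e6) MB, GC time: $(stats.gctime) s")
end
```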
Thanks for your contribution!
Ok, I'll have a try. It seems the URL is broken; maybe you mean this one: https://github.com/jonathan-laurent/AlphaZero.jl/issues/76