MuZero implementation using OpenSpiel
First of all, I want to thank the developers for this awesome project! It's simple and clean, yet powerful. I really enjoyed playing with it.
I'm currently studying at the University of Alberta under the supervision of Prof. Martin Mueller. My primary research focus is learning/planning with a model, and general game playing. MuZero was a big step in this direction, and I would like to implement an open-source version of it as a foundation for my project. I'm aware of other open-source implementations, but I would like to have a more efficient and robust implementation. (This is partially why I opened #592, since utilizing cloud computation power well is a must-have.) There's also #135, but unfortunately there was no follow-up.
My plan is to implement MuZero in a separate repo first, using both the C++ and Python interfaces of OpenSpiel. I'll use C++ for search (MuZero-flavored MCTS) and Python for everything else. I'll use JAX for neural-net-related things (I heard that's what you are using at DeepMind for MuZero). If the project works out, we can integrate it into OpenSpiel at some point in the future.
Here are some questions of mine:
- What are the caveats of writing a MuZero-flavored MCTS in C++, similar to OpenSpiel's own MCTS?
- I presume that, since the games' interfaces are quite unified, the algorithm should work on all the games without too much tweaking ("work" in the sense of running without errors, not necessarily performing well on the task). Is this correct?
- What do you do to display/visualize information/metrics? Right now I'm just reading console logs.
It would be great if the developers could help me with these questions, and any other tips regarding the project would be greatly appreciated! 😃
Hi @uduse,
This sounds great :)
I'm aware of other open-source implementations, but I would like to have a more efficient and robust implementation.
Do you know about muzero-general? They also recently added support for OpenSpiel games: https://github.com/werner-duvaud/muzero-general/commit/23a1f6910e97d78475ccd29576cdd107c5afefd2.
My plan is to implement MuZero in a separate repo first, using both the C++ and Python interfaces of OpenSpiel. I'll use C++ for search (MuZero-flavored MCTS) and Python for everything else. I'll use JAX for neural-net-related things (I heard that's what you are using at DeepMind for MuZero). If the project works out, we can integrate it into OpenSpiel at some point in the future.
Cool! Am I assuming correctly that you want to have the search/inference in C++ due to performance? (i.e. rather than a Python impl like muzero-general?) We don't have any examples of mixing C++ and JAX in the library yet (not even simple ones), but they'd be more than welcome!
1. What are the caveats of writing a MuZero-flavored MCTS in C++, similar to OpenSpiel's own MCTS?
Sorry, I don't really understand... can you elaborate? Is there something that you're concerned about in particular?
2. I presume that, since the games' interfaces are quite unified, the algorithm should work on all the games without too much tweaking ("work" in the sense of running without errors, not necessarily performing well on the task). Is this correct?
I don't immediately see why you wouldn't be able to keep it general (muzero-general has, AFAIK?)
3. What do you do to display/visualize information/metrics? Right now I'm just reading console logs.
@tewalds wrote an analysis tool described at the bottom of https://github.com/deepmind/open_spiel/blob/master/docs/alpha_zero.md. I think @christianjans has used it as well.
Do you know about muzero-general? They also recently added support for OpenSpiel games: werner-duvaud/muzero-general@23a1f69.
Yes, and my project is heavily inspired by that. My project will differ in two main ways: (1) I'll use JAX, and (2) I'll try to make the algorithm efficient and scalable.
Cool! Am I assuming correctly that you want to have the search/inference in C++ due to performance?
Yes, similar to the MCTS in KataGo, I will try to implement a multi-threaded C++ version for MuZero.
Sorry, I don't really understand... can you elaborate? Is there something that you're concerned about in particular?
I was wondering if there's anything off the top of your head that you would like to tell me. Other than that, my only concern is getting JAX to work with C++. If I can't easily get C++ to run inference on the network, I'll have to use a separate Python inference worker to handle batched inference requests from the C++ threads. That said, before doing any of that, I need to benchmark muzero-general to see where the Python self-play bottleneck is. Maybe it will be sufficiently fast if I implement a Python multi-threaded (GIL warning ⛔ ) MCTS.
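For what it's worth, here's a minimal sketch of what such a separate Python inference worker could look like (all names here are hypothetical and the network is a stub; in the real setup the requests would come from pybind11-wrapped C++ search threads rather than Python ones):

```python
import queue
import threading

import numpy as np

# Hypothetical sketch: search threads call `infer`, while one worker thread
# drains the queue into a batch and runs a single forward pass per batch.

class InferenceWorker:
    def __init__(self, forward_fn, max_batch_size=32, timeout=0.01):
        self._forward_fn = forward_fn      # e.g. a jitted JAX apply function
        self._queue = queue.Queue()
        self._max_batch_size = max_batch_size
        self._timeout = timeout

    def infer(self, observation):
        """Called from a search thread; blocks until the result is ready."""
        done = threading.Event()
        slot = {}
        self._queue.put((observation, slot, done))
        done.wait()
        return slot["output"]

    def run_forever(self):
        """Worker loop: batch pending requests, infer once, fan results out."""
        while True:
            requests = [self._queue.get()]  # block until the first request
            while len(requests) < self._max_batch_size:
                try:
                    requests.append(self._queue.get(timeout=self._timeout))
                except queue.Empty:
                    break  # batch whatever arrived within the timeout
            batch = np.stack([obs for obs, _, _ in requests])
            outputs = self._forward_fn(batch)  # one batched forward pass
            for (_, slot, done), output in zip(requests, outputs):
                slot["output"] = output
                done.set()

# Toy usage with a dummy "network" and a few fake search threads:
worker = InferenceWorker(forward_fn=lambda batch: batch * 2.0)
threading.Thread(target=worker.run_forever, daemon=True).start()
results = []
threads = [threading.Thread(target=lambda: results.append(worker.infer(np.ones(4))))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(results))  # 8
```

The timeout trades latency for batch size: a longer wait collects larger batches but stalls individual simulations.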
The analysis tool seems great! I will try to utilize it 👍
FYI, a GSoC student will also work on MuZero in Julia and use the OpenSpiel wrapper this summer.
FYI, a GSoC student will also work on MuZero in Julia and use the OpenSpiel wrapper this summer.
Thanks for the information! Do you happen to know this person's contact information? I cold-emailed someone with the same name, but I likely got the wrong person...
Try michal.lukomski21 through gmail. And @jonathan-laurent is the primary mentor.
Newly released MuZero implementation that might be of interest: https://github.com/google-research/google-research/tree/master/muzero
Here's the GSoC page for the Julia MuZero project: https://michelangelo21.github.io/gsoc/2021/08/23/gsoc-2021.html
I am also interested in implementing MCTS and a semi-grad TD algorithm, to learn the material better. But figuring out the batching is a problem for me. Is there a JAX-based implementation of MCTS (or any kind of tree search that uses a neural network to estimate some feature of the nodes) that I can use as a reference?
In general, is the idiomatic approach to use multi-threading and run several JAX JITed functions in parallel, or to create a suitable batch that is fed into a single JAX function?
Hi @NightMachinary,
You can find MCTS that uses neural nets in our AlphaZero implementations: https://github.com/deepmind/open_spiel/blob/master/docs/alpha_zero.md
None are JAX-based. It would be great to have a JAX one too! Would make a nice contribution.
I am also interested in implementing MCTS and a semi-grad TD algorithm, to learn the material better. But figuring out the batching is a problem for me. Is there a JAX-based implementation of MCTS (or any kind of tree search that uses a neural network to estimate some feature of the nodes) that I can use as a reference?
In general, is the idiomatic approach to use multi-threading and run several JAX JITed functions in parallel, or to create a suitable batch that is fed into a single JAX function?
My current implementation uses multiple copies of asynchronous MCTS (using asyncio) and batches their queries together for inference with the NN. Each individual MCTS still performs like a single-threaded MCTS, but combining multiple such MCTSs yields throughput similar to a multi-threaded MCTS.
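To make the pattern concrete, here's a toy sketch (my own illustration, not the actual repo code): each search is a coroutine that awaits its NN query, and a shared evaluator flushes all pending queries in one batched call.

```python
import asyncio

import numpy as np

class BatchedEvaluator:
    """Collects NN queries from many concurrent MCTS coroutines into one batch."""

    def __init__(self, forward_fn):
        self._forward_fn = forward_fn  # e.g. a jitted JAX policy/value network
        self._pending = []             # list of (observation, future) pairs

    async def evaluate(self, observation):
        future = asyncio.get_running_loop().create_future()
        self._pending.append((observation, future))
        return await future

    async def flush_loop(self):
        """Repeatedly run one forward pass over all pending queries."""
        while True:
            await asyncio.sleep(0)  # yield so the searches can enqueue queries
            if self._pending:
                observations, futures = zip(*self._pending)
                self._pending = []
                outputs = self._forward_fn(np.stack(observations))
                for future, output in zip(futures, outputs):
                    future.set_result(output)

async def one_search(evaluator, observation):
    # Stand-in for one asynchronous MCTS: each simulation awaits its NN
    # query, which lets the other searches fill the batch in the meantime.
    for _ in range(10):
        await evaluator.evaluate(observation)

async def main():
    evaluator = BatchedEvaluator(forward_fn=lambda batch: batch)  # dummy net
    flusher = asyncio.ensure_future(evaluator.flush_loop())
    await asyncio.gather(*(one_search(evaluator, np.zeros(4)) for _ in range(8)))
    flusher.cancel()

asyncio.run(main())
```

The batch size scales with the number of concurrent searches, which is where the throughput comes from, while each individual search still behaves like a sequential MCTS.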
@uduse Can you link your implementation?
From what I understand, the key points of your design are that:
- MCTS only depends on the value/policy functions defined before it is run (which are not updated during the run), and its output is just the selection count for each node, which we can trivially sum for multiple runs.
- JAX uses async dispatch, so using it with asyncio will result in effectively parallel execution.
@NightMachinary I just made my repo public, see here for my async MCTS.
It might not be what you want. My implementation focuses on increasing the throughput of multiple MCTSs but the latency of each individual MCTS is not reduced.
- MCTS only depends on the value/policy functions defined before it is run (which are not updated during the run), and its output is just the selection count for each node, which we can trivially sum for multiple runs.
Each MCTS outputs its own selection counts; multiple runs yield multiple selection counts, which means multiple data points.
- JAX uses async dispatch, so using it with asyncio will result in effectively parallel execution.
I don't think JAX's async dispatch has anything to do with concurrent MCTS.
Batching is indeed your core problem, but it's not that related to how JAX works. No matter which NN library you use, your batching logic will live somewhere separate from the NN inference. This means the design of your MCTS only affects how you batch multiple NN inference requests together. Once you have those batches, you can use any NN to do efficient inference. The NN in my project is JAX-based, but I could swap it for a PyTorch-based one without changing anything in my MCTS.
I have an implementation of the batching layer here. It's kinda verbose now and I'm planning to simplify it. Also see the test here for example usage.
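To illustrate that decoupling (a toy sketch of my own, not the repo's actual API): the batching and search code only need a callable from a stacked batch of observations to a batch of outputs, so the backend is interchangeable.

```python
import numpy as np

# The batching layer only depends on a callable with the signature
# `batch of observations -> batch of outputs`, so the NN backend can be
# swapped without touching the MCTS or the batching logic.

def make_jax_forward():
    import jax.numpy as jnp
    weights = jnp.ones((4, 2))  # stand-in for real network parameters
    def forward(batch):
        return np.asarray(jnp.tanh(jnp.asarray(batch) @ weights))
    return forward

def make_torch_forward():
    import torch
    weights = torch.ones(4, 2)  # same dummy "network" in PyTorch
    def forward(batch):
        with torch.no_grad():
            t = torch.as_tensor(np.asarray(batch, dtype=np.float32))
            return torch.tanh(t @ weights).numpy()
    return forward

# Either backend plugs into the same batching/search code unchanged:
forward_fn = make_jax_forward()  # or make_torch_forward()
print(forward_fn(np.zeros((8, 4))))
```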
Hi. I wrote an MCTS using TensorFlow tensors only, with batching (and XLA): https://gitlab.com/dia.group/tf-muzero So it's possible to have the GPU do the MCTS work (the whole MCTS in one GPU function call)! It could be great to rewrite it in JAX. Difficulties:
- the code is a bit complex to read (so JAX could help)
- some parts had to be redesigned (like MCTS backpropagation), but I've already done that work
- the tree size is fixed, and for a large tree you may spend time reading from and writing to it (TensorFlow's scatter/gather operators); maybe someone has a solution for this issue (I need some help)
- the tree size is limited by GPU RAM (I tested 1024 Connect4 games / 700 simulations / 128 features on an 11GB GPU)
Advantages:
- fast (more than 25k inferences per second)
- no C++ code
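For what it's worth, here's a rough sketch of the fixed-size-tree idea in JAX (my own illustration, not code from tf-muzero): the tree lives in preallocated device arrays, node reads are gathers, and statistic updates are scatters expressed with indexed `.at[...]` updates.

```python
import jax
import jax.numpy as jnp

# Sketch: a fixed-capacity tree stored as flat arrays so the whole search
# can stay on-device. Indexed reads compile to gathers, updates to scatters.

MAX_NODES, NUM_ACTIONS = 1024, 7  # e.g. Connect4

def empty_tree():
    return {
        "visit_counts": jnp.zeros((MAX_NODES, NUM_ACTIONS), dtype=jnp.int32),
        "value_sums": jnp.zeros((MAX_NODES, NUM_ACTIONS)),
        "children": jnp.full((MAX_NODES, NUM_ACTIONS), -1, dtype=jnp.int32),
        "priors": jnp.zeros((MAX_NODES, NUM_ACTIONS)),
    }

@jax.jit
def backup(tree, node, action, value):
    """One backpropagation step: scatter-update the stats of (node, action)."""
    tree = dict(tree)
    tree["visit_counts"] = tree["visit_counts"].at[node, action].add(1)
    tree["value_sums"] = tree["value_sums"].at[node, action].add(value)
    return tree

@jax.jit
def puct_scores(tree, node, c_puct=1.25):
    """Gather one node's stats and compute PUCT scores for all actions."""
    n = tree["visit_counts"][node]
    q = jnp.where(n > 0, tree["value_sums"][node] / n, 0.0)
    total = jnp.sum(n)
    return q + c_puct * tree["priors"][node] * jnp.sqrt(total + 1) / (1 + n)

tree = empty_tree()
tree = backup(tree, 0, 3, 0.5)
print(puct_scores(tree, 0))
```

The trade-offs are the ones listed above: capacity is fixed up front and large trees pay for the gathers/scatters, but the entire search loop can stay on the accelerator.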
@cmarlin There's already a JAX implementation here. They addressed the issues you mentioned above using a couple of tricks. It would be very interesting to make fully JAX-compatible agents in the future.
@cmarlin There's already a JAX implementation here. They addressed the issues you mentioned above using a couple of tricks. It would be very interesting to make fully JAX-compatible agents in the future.
Thanks for your link, I didn't know it!
Tagging @tuero based on a recent reddit comment.
Also, I guess this thread is rather outdated. Is anybody working on an implementation they plan to contribute? Just wondering whether we should still keep this issue open for discussion.
@uduse @cmarlin @NightMachinery
Hi @lanctot , a few months ago I did a full C++ implementation, more so as a learning exercise (and for possible extensions). I don't really have any future plans for the codebase, as I've moved on to slightly different methods for my research.
While the code is for the most part complete, I would have to make some changes for it to fit nicely with this repo, which might not be trivial, but I am more than happy to do so. I'm a bit busy at the moment, so this wouldn't be something for the near future, but it can be made into a separate issue/PR once it's reasonably close.
@tuero As always, your contributions are super welcome :) But I completely understand the time investment and trade-offs... especially as a grad student!
I'll leave this thread open for now, then, in case people still want to use it for discussion.
@lanctot I'm still working on the project. The project became more complicated than I envisioned, and it would be extremely difficult to port as part of OpenSpiel. That said, I am still using OpenSpiel games to train it, and it would be a nice example of using OpenSpiel in a complicated DRL system.
Ok, cool! That makes complete sense for an algorithm of this complexity, and especially since you are using it for research. Glad you could still use the game implementations. Please feel free to share any results if you want :)
Closing due to inactivity. Please re-open if you want to continue the discussion.