
Single player AlphaZero (python)

Open FlynnDowey opened this issue 6 months ago • 8 comments

Is there a way to make AlphaZero work with single player games? If so, how could I implement it?

Could I turn my single-player game into a two-player game where the current player never changes and the returns are duplicated?
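Roughly something like this? (A hypothetical wrapper just to illustrate the idea; `TwoPlayerView` and its delegation are made-up names, not actual OpenSpiel API.)

    class TwoPlayerView:
        """Makes a single-player state look like a 2-player game."""

        def __init__(self, inner_state):
            self._inner = inner_state          # the real single-player state

        def num_players(self):
            return 2

        def current_player(self):
            return 0                           # player 0 always acts

        def returns(self):
            r = self._inner.returns()[0]
            return [r, r]                      # duplicate the single-player return

        def __getattr__(self, name):
            return getattr(self._inner, name)  # delegate everything else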

FlynnDowey avatar Jun 08 '25 23:06 FlynnDowey

Yes, it should work.

However, the Python AlphaZero implementation has been deprecated because we no longer support TensorFlow. We have a small team of external contributors porting it to JAX.

When it is ready, we can check/ensure that it works with single player games.

lanctot avatar Jun 09 '25 08:06 lanctot

@lanctot does the Python version of AlphaZero support GPU usage?

FlynnDowey avatar Jun 10 '25 00:06 FlynnDowey

The old version did, yes... but as I said, it's been deprecated because it no longer works. See https://github.com/google-deepmind/open_spiel/issues/1326. The files have been moved here: https://github.com/google-deepmind/open_spiel/tree/lanctot-patch-67/open_spiel/python/algorithms/tf

We'll see if the new version supports GPUs. It should at least be easier in JAX.

@alexunderch do you know if you'll test the GPU use case?

lanctot avatar Jun 10 '25 01:06 lanctot

Hey! @FlynnDowey

JAX is a device-agnostic framework (the same code runs on different devices: CPU, GPU, etc.), so the implementation works on GPU as long as you have the corresponding JAX version installed.

However, note that the algorithm uses the GPU only to train the neural networks (policy, value, etc.). The MCTS implementation is naive and uses multiprocessing on the CPU. We will bring you some numbers on memory consumption and time overhead in a while; hopefully that will be enough for your use case.
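(A quick way to double-check which backend your JAX install actually picked up:)

    import jax

    print(jax.devices())          # e.g. [CudaDevice(id=0)] with a CUDA build of jaxlib
    print(jax.default_backend())  # "gpu", "tpu", or "cpu"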

Do you have time to wait (a couple of weeks), or do you want to run the code ASAP? For some turn-based games there is a simple third-party implementation: link

Moreover, there is a native DeepMind implementation of MCTS, mctx. Even though that code is mainly for educational purposes, we could consider making it work with the library, @lanctot? It may take some intricate work to get it done, though.

alexunderch avatar Jun 10 '25 05:06 alexunderch

Hi @alexunderch, I need it for educational purposes. I see that mctx is written in JAX, as you said. Does mctx work with the game types in open_spiel, or will I need to write a new environment class?

FlynnDowey avatar Jun 20 '25 00:06 FlynnDowey

Hey! It should work with the games; you need to rewrite an MCTS bot. I can write it up in more detail if you want.

alexunderch avatar Jun 20 '25 06:06 alexunderch

@alexunderch please provide a more detailed explanation. Thanks.

FlynnDowey avatar Jun 20 '25 14:06 FlynnDowey

Hey @FlynnDowey, I am back, embarrassingly late. I did find a pipeline for pairing the library with mctx, but I also found some problems that @lanctot may (or rather may not) find interesting to solve.

First, let's write the implementation out explicitly, without wrapping it in any bot classes.

You basically need to implement the following steps:

1. How the root of the tree is computed:

    def get_root_fn(bot):
      # `bot` is the network's apply function, captured in a closure so that
      # root_fn(params, state) has the signature mctx expects.
      def root_fn(params, state):
        obs = batch(state.observation_tensor())
        # Both policy and value should be batched [B, ...],
        # but we work with the simplest setting, so B = 1.
        policy_logits, values = bot({"params": params}, obs.astype(float))
        return mctx.RootFnOutput(
          prior_logits=policy_logits,
          value=values,
          embedding=[state],  # carry the OpenSpiel state as the "embedding"
        )
      return root_fn
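(Here `batch` is assumed to be a tiny helper that adds the leading batch dimension, something like:)

    import jax.numpy as jnp

    def batch(x):
        # [obs_dim] -> [1, obs_dim], i.e. B = 1.
        return jnp.asarray(x)[None, ...]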

And here is the first problem: it would be great if games in the library were batchable, i.e. if you could create more than one copy of the environment and run them in parallel. Right now we only have one.

2. How tree traversal is computed:

    def get_recurrent_fn(bot, state, discount):

      def recurrent_fn(params, key, action, embedding):
        state = embedding[0]

        # The state transition has to leave jit-compiled code via a host callback
        # (see the second problem below).
        new_state = jax.pure_callback(state.apply_action, action[0])

        obs = batch(new_state.observation_tensor())

        # 1.0 while the game continues, 0.0 once it is terminal.
        not_done = 1.0 - jnp.array(state.is_terminal()).reshape(-1)
        # Both policy and value should be batched [B, ...],
        # but we work with the simplest setting, so B = 1.
        policy_logits, values = bot({"params": params}, obs.astype(float))

        recurrent_fn_output = mctx.RecurrentFnOutput(
          reward=jnp.array(new_state.rewards()[new_state.current_player()]).reshape(-1),
          discount=not_done * discount,
          prior_logits=policy_logits,
          value=values * not_done,
        )
        return recurrent_fn_output, [new_state]

      return recurrent_fn

Here we have the second problem: this function gets jit-compiled, so the state transition has to be done through a callback to convert to a format OpenSpiel accepts: new_state = jax.pure_callback(state.apply_action, action[0]). But that part is fine. **The main thing is: to make the tree traversable, the game state needs to be pytree-compatible. It is not right now, and that's why you can't make this implementation fully work.** Sounds like an interesting feature request, @lanctot?
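(For context, "pytree-compatible" just means the state is a container of arrays that JAX can flatten and carry through jit-compiled code. A toy sketch of what such a state could look like, not an existing OpenSpiel type:)

    from typing import NamedTuple
    import jax
    import jax.numpy as jnp

    class ToyState(NamedTuple):      # NamedTuples are pytrees automatically
        board: jnp.ndarray           # e.g. shape [B, 9] for tic-tac-toe
        current_player: jnp.ndarray  # shape [B]
        terminated: jnp.ndarray      # shape [B]

    # Such a state can live inside mctx's search tree as the `embedding`:
    s = ToyState(jnp.zeros((1, 9)), jnp.zeros((1,)), jnp.zeros((1,)))
    print(jax.tree_util.tree_map(lambda x: x.shape, s))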

3. How the tree policy is computed (this fragment sits inside the acting loop; `temperature_drop`, `use_mixed_value`, `value_scale`, `discount`, and `rng` are defined in the surrounding setup):

    bot = bots[state.current_player()]

    root_fn = get_root_fn(bot.apply_fn)
    root = root_fn(bot.params, state)

    recurrent_fn = get_recurrent_fn(bot.apply_fn, state, discount)

    policy_output = mctx.gumbel_muzero_policy(
      params=bot.params,
      rng_key=rng,
      root=root,
      recurrent_fn=recurrent_fn,
      num_simulations=40,  # toy version
      max_num_considered_actions=temperature_drop,
      gumbel_scale=1.0,  # perfect-information games
      max_depth=10,  # fixed for the toy example
      qtransform=functools.partial(
        mctx.qtransform_completed_by_mix_value,
        use_mixed_value=use_mixed_value,
        value_scale=value_scale,
      ),
    )

    search_policy = policy_output.action_weights
    search_value = policy_output.search_tree.node_values[:, policy_output.search_tree.ROOT_INDEX]
    action = policy_output.action
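(For completeness, one way the search outputs above could then feed an AlphaZero-style update; a rough sketch under assumptions, where `outcome` is the final return of the finished game and `bot` is the same apply function as before:)

    import optax

    def az_loss(params, obs, search_policy, outcome):
        policy_logits, value = bot({"params": params}, obs)
        # Distill the search policy into the network; regress the value on the outcome.
        policy_loss = optax.softmax_cross_entropy(policy_logits, search_policy).mean()
        value_loss = jnp.mean((value.squeeze() - outcome) ** 2)
        return policy_loss + value_loss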

That's it. If you figure out how to convert the environment implementation into pytree (container-of-arrays) form, the code should work fine. Sorry for being so slow and disappointing. You can try the method using simpler versions of environments like:

  • https://github.com/kenjyoung/mctx_learning_demo/tree/main
  • https://github.com/sotetsuk/pgx/tree/main/examples/alphazero

Sincerely, Sacha.

P.S. here is the whole gist if it helps: https://gist.github.com/alexunderch/df5ad14b85f9fe749318a4dafbfba064

alexunderch avatar Jun 30 '25 19:06 alexunderch

New effort to support AlphaZero in JAX: https://github.com/google-deepmind/open_spiel/pull/1362

lanctot avatar Aug 20 '25 11:08 lanctot