
Fix Deep CFR PyTorch convergence on Kuhn Poker

Open fuyuan-li opened this issue 4 weeks ago • 8 comments

This PR fixes the PyTorch Deep CFR convergence issue on Kuhn Poker (Issue #1376). The DeepCFRSolver implementation is correct, but the default hyperparameters in this example caused unstable advantage estimation due to high regret variance and insufficient policy network training.

This PR:
• sets reinitialize_advantage_networks=False
• increases policy_network_train_steps
• increases num_traversals

Diagnostics confirm that all components train correctly; the issue was variance-driven. With the updated hyperparameters, Deep CFR now converges reliably (40 runs → mean value –0.058 ≈ theoretical –1/18 ≈ –0.056). Attached: reproducible_experiment.html
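For reference, the theoretical value of about -0.056 is Kuhn poker's game value for player 0, exactly -1/18. The sketch below is a plain-Python vanilla tabular CFR written only for illustration. It is not OpenSpiel's `CFRSolver`, and all function names are hypothetical. Its average policy reproduces that value, so it can serve as the kind of tabular reference that Deep CFR results can be compared against.

```python
import itertools

# Kuhn poker scored from player 0's point of view. Cards: 0 = J, 1 = Q, 2 = K.
# Actions: "p" = pass (check/fold), "b" = bet (bet/call).
ACTIONS = "pb"
DEALS = list(itertools.permutations(range(3), 2))  # 6 equally likely (card0, card1) deals


def terminal_payoff(cards, history):
    """Player 0's payoff if `history` is terminal, otherwise None."""
    showdown = 1 if cards[0] > cards[1] else -1
    if history == "pp":
        return showdown          # check-check: antes only
    if history in ("bb", "pbb"):
        return 2 * showdown      # bet-call: ante + bet
    if history == "bp":
        return 1                 # player 1 folds to a bet
    if history == "pbp":
        return -1                # player 0 folds to a bet
    return None


def regret_matching(regrets):
    """Positive regrets normalized to a distribution; uniform if none are positive."""
    positive = [max(r, 0.0) for r in regrets]
    total = sum(positive)
    return [p / total for p in positive] if total > 0 else [0.5, 0.5]


def cfr(cards, history, reach, regret_sum, strategy_sum, pending):
    """One vanilla CFR pass; returns player 0's expected payoff at this node."""
    payoff = terminal_payoff(cards, history)
    if payoff is not None:
        return payoff
    player = len(history) % 2
    infoset = f"{cards[player]}{history}"
    strategy = regret_matching(regret_sum.get(infoset, [0.0, 0.0]))
    sums = strategy_sum.setdefault(infoset, [0.0, 0.0])
    for a in range(2):
        sums[a] += reach[player] * strategy[a]  # average policy weighted by own reach
    values = []
    for a, action in enumerate(ACTIONS):
        child_reach = list(reach)
        child_reach[player] *= strategy[a]
        values.append(cfr(cards, history + action, child_reach,
                          regret_sum, strategy_sum, pending))
    node_value = sum(s * v for s, v in zip(strategy, values))
    sign = 1 if player == 0 else -1  # zero-sum: player 1's payoff is the negation
    inc = pending.setdefault(infoset, [0.0, 0.0])
    for a in range(2):
        # Counterfactual regret, weighted by the opponent's reach probability.
        inc[a] += reach[1 - player] * sign * (values[a] - node_value)
    return node_value


def solve(iterations):
    """Run vanilla CFR over all deals; return the normalized average policy."""
    regret_sum, strategy_sum = {}, {}
    for _ in range(iterations):
        pending = {}
        for cards in DEALS:
            cfr(cards, "", [1.0, 1.0], regret_sum, strategy_sum, pending)
        for infoset, inc in pending.items():  # simultaneous update after the full pass
            r = regret_sum.setdefault(infoset, [0.0, 0.0])
            r[0] += inc[0]
            r[1] += inc[1]
    return {k: [x / sum(v) for x in v] for k, v in strategy_sum.items()}


def expected_value(policy):
    """Player 0's expected payoff under `policy`, averaged over the 6 deals."""
    def value(cards, history):
        payoff = terminal_payoff(cards, history)
        if payoff is not None:
            return payoff
        probs = policy[f"{cards[len(history) % 2]}{history}"]
        return sum(p * value(cards, history + a) for p, a in zip(probs, ACTIONS))
    return sum(value(cards, "") for cards in DEALS) / len(DEALS)


avg_policy = solve(5000)
game_value = expected_value(avg_policy)
print(f"value {game_value:.4f} vs -1/18 = {-1 / 18:.4f}")
```

Since the average policy's value can deviate from -1/18 by at most its exploitability, printing this value next to -1/18 is a cheap sanity check. The same comparison could be run on Deep CFR's output.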

Fixes #1376.

fuyuan-li, Dec 11 '25 12:12

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot], Dec 11 '25 12:12

As discussed in #1376, is it easy enough to change the params in the JAX version as well and check that it works on Kuhn poker in that setting?

lanctot, Dec 11 '25 21:12

@lanctot, please hold off on this one until Monday. I want to check against the original paper, because it isn't intuitively clear to me why Kuhn poker should need so many traversals.

alexunderch, Dec 12 '25 18:12

Thank you @lanctot. I realize the JAX implementation has a different issue: as @alexunderch pointed out, it is another warning from the tf dataset, not a convergence issue. I'm happy to take another look at that; maybe we open a separate PR? @alexunderch, thank you so much for taking a second look! (I'd be very happy if this turns into a discussion on method/implementation, beyond a bug fix.)

Happy to share more about what led me to this PR (see the attachment for details). In short, I observed three things:

  1. With the original hyperparams, the learned policy is close to random (50-50 at player 0's infosets). Using tabular CFR as a reference, the policy trained with the updated hyperparams moves toward the reference results.
  2. The policy network seems under-trained: continuing training on the same data reduces the loss by more than 85%.
  3. Advantage targets have very large variance, likely inherent to traversal sampling; very open to hearing your thoughts. Based on these observations, I tuned the hyperparams in the PR.
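Item 3 fits a generic Monte Carlo fact. If one sampled regret has variance σ², averaging N independent samples gives variance σ²/N, which is one way to read why more traversals help. The toy simulation below shows only that statistical effect. `noisy_regret_sample` is a hypothetical stand-in with an arbitrary mean (0.1) and noise level (σ = 2), not OpenSpiel code. Real Deep CFR targets are also fit by a network rather than averaged directly.

```python
import random
import statistics


def noisy_regret_sample(rng):
    """Hypothetical stand-in for one sampled regret: mean 0.1, standard deviation 2."""
    return 0.1 + rng.gauss(0.0, 2.0)


def estimate(num_traversals, rng):
    """Average `num_traversals` independent samples into one regret estimate."""
    return statistics.fmean(noisy_regret_sample(rng) for _ in range(num_traversals))


def spread_of_estimates(num_traversals, repeats=2000, seed=0):
    """Empirical variance of the estimate across many independent repeats."""
    rng = random.Random(seed)
    return statistics.pvariance([estimate(num_traversals, rng) for _ in range(repeats)])


for n in (1, 10, 100):
    print(n, round(spread_of_estimates(n), 4))
```

The printed variances should fall roughly by a factor of 10 for each tenfold increase in samples, from about σ² = 4 at N = 1.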

Also, while reading the original paper, I noticed some differences, though I think they are minor (and very reasonable practical changes): (i) Algorithm 1 in the paper initializes the advantage network so that its outputs are all 0; (ii) the paper uses mini-batch SGD with gradient clipping, unlike the current implementation.
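On point (ii), the sketch below shows mini-batch SGD with gradient-norm clipping on a toy linear regression in plain Python. It is illustrative only. The model, data, learning rate, batch size, and clip threshold of 1.0 are arbitrary choices of this sketch, not values from the paper or from either OpenSpiel implementation. In PyTorch, the clipping step would normally use `torch.nn.utils.clip_grad_norm_`.

```python
import math
import random


def clip_by_global_norm(grads, max_norm):
    """Rescale the gradient vector so its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        return [g * max_norm / norm for g in grads]
    return grads


def minibatch_sgd(data, lr=0.05, batch_size=32, steps=2000, max_norm=1.0, seed=0):
    """Fit y ~ w * x + b by mini-batch SGD on squared error, with gradient clipping."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    for _ in range(steps):
        batch = rng.sample(data, batch_size)
        # Batch-mean gradients of 0.5 * (w * x + b - y) ** 2.
        gw = sum((w * x + b - y) * x for x, y in batch) / batch_size
        gb = sum((w * x + b - y) for x, y in batch) / batch_size
        gw, gb = clip_by_global_norm([gw, gb], max_norm)
        w -= lr * gw
        b -= lr * gb
    return w, b


rng = random.Random(1)
xs = [rng.uniform(-1, 1) for _ in range(500)]
data = [(x, 3.0 * x - 0.5 + rng.gauss(0, 0.1)) for x in xs]
w, b = minibatch_sgd(data)
print(round(w, 2), round(b, 2))
```

Clipping bounds the size of any single update. This limits the damage from rare, very large targets, which is one plausible reason it pairs well with high-variance sampled regrets.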

Very happy to keep iterating and hear your thoughts! Attached: case_study.html

fuyuan-li, Dec 14 '25 04:12

Hey @fuyuan-li, nice findings. The code is in fact quite different from the paper. I decided it might be a good idea to work on both the JAX and torch implementations together, which is why I came up with a temporary PR that should make testing easier. Give it a look when you have time: https://github.com/google-deepmind/open_spiel/pull/1408. For example, you could work on one version and give me feedback on the other.

Unfortunately, my hardware is not good, so I can't run tests the way you do. If you can help, we can fix things faster. Maybe we really should switch back to SGD.

Also, if you decide to work on both implementations at the same time, note that torch doesn't work with NumPy 2.x until, I think, 2.4, whereas JAX switched to it quite early. This may be an issue for some Python versions (like 3.12; tagging @lanctot), but it's unlikely.
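To see which combination is installed before running either implementation, you can query package metadata with the standard library. This sketch only reports versions. It deliberately encodes no compatibility rule, since the torch/NumPy cutoff above is a recollection rather than a checked fact. The helper names are hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("numpy", "torch", "jax", "jaxlib")):
    """Map each package name to its installed version string, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def major_version(version):
    """Leading integer of a version string such as '2.1.3'; None stays None."""
    return None if version is None else int(version.split(".")[0])


versions = installed_versions()
print(versions)
if major_version(versions["numpy"]) == 2 and versions["torch"] is not None:
    print("NumPy 2.x with torch", versions["torch"], "- check this torch build supports NumPy 2")
```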

Also, I don't expect the PyTorch code to work on GPU straight away. Tell me what you think.

Also, there might be speed issues with the JAX implementation, but don't pay attention to them for now. I want working code first.

alexunderch, Dec 14 '25 16:12

Also, this might be a duplicate of https://github.com/google-deepmind/open_spiel/issues/1287

alexunderch, Dec 15 '25 14:12

Hey @alexunderch, thanks a lot for the detailed explanation and plan! This is great, and I’m really glad we’re aligning both implementations and revisiting the paper.

I’m happy to take this on for both the torch and JAX sides. To keep things concrete, I can start from the torch implementation and use it as a reference point, while giving feedback and test results on the JAX side as we go.

I'll take a closer look at #1408 and get back to you with concrete findings and suggestions. Thanks again!

fuyuan-li, Dec 15 '25 16:12

Okay, I'll push the latest fixes for your experiments in the next couple of hours.

alexunderch, Dec 15 '25 16:12