Deep CFR JAX refactor
Hey! This is not the main PR, so it might be deleted later, but to summarise:
- Replaced Sonnet's initialisation with native PyTorch's. The JAX and torch implementations had different structures; now they're the same, i.e. a common `MLP` parametrisation and `LayerNorm` after the base layers.
- Rewrote the `jax` implementation using `flax.nnx` to match the torch and tf implementations.
- Deleted all the `tf` and `tensorflow_datasets` stuff to simplify the code.
However:
- `torch` and `jax` differ in their loss coefficients: in `torch`/`tf` it's $\sqrt{t}$, whereas in `jax` it is $\frac{t}{2T}$. They are not the same, so the implementations still differ here.
- The `jax` implementation used to use masking; however, as noted in the torch implementation, masking is not that valuable because only legal actions are selected during traversal.
- The `jax` implementation is initialised with PyTorch parameters to ensure reproducibility.
The goal is to make sure that both implementations converge to similar values at a similar rate; that will allow us to delete the tf implementation.
At first glance, I like your t/2T in the JAX implementation! I think this is what the original paper proposed.
Thank you! Will follow up with more.
@fuyuan-li, I found out that $\frac{t}{2T}$ is specific to Linear CFR, which shouldn't work with non-linear MLPs, so I decided to stick with $\sqrt{t}$, as in the original TF implementation.
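A small sketch contrasting the two loss coefficients discussed here. It assumes each buffered sample stores the iteration `t` at which it was added and that `T` is the current iteration; the function names and the per-action averaging are illustrative assumptions, not taken from either implementation.

```python
import math


def sqrt_weight(t, T):
  """sqrt(t) coefficient, as in the original TF implementation; T is unused."""
  return math.sqrt(t)


def linear_weight(t, T):
  """t / (2T) coefficient, as in the previous JAX implementation."""
  return t / (2 * T)


def weighted_mse(predictions, targets, iterations, T, weight_fn):
  """Iteration-weighted MSE over a batch.

  Each sample's squared error is averaged over actions, scaled by
  weight_fn(t, T) for the iteration t at which the sample was stored,
  and the weighted errors are then averaged over the batch.
  """
  total = 0.0
  for pred, target, t in zip(predictions, targets, iterations):
    mse = sum((p - y) ** 2 for p, y in zip(pred, target)) / len(pred)
    total += weight_fn(t, T) * mse
  return total / len(predictions)
```

With `T = 100`, the newest samples (`t = 100`) get 10 times the weight of the oldest (`t = 1`) under $\sqrt{t}$, versus 100 times under $\frac{t}{2T}$, so the square root flattens the emphasis on recent iterations.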
The only remaining difference is the masking used in the tf2 and jax implementations. I am not really sure it helps, but I have kept it for consistency.
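To illustrate what the mask changes, here is a minimal regret-matching sketch over predicted advantages. It assumes the convention from the Deep CFR paper of falling back to the highest-advantage legal action when no legal advantage is positive; the function name and list-based inputs are hypothetical. Looping only over `legal_actions` plays the role of the mask: an illegal action's advantage never enters the strategy, which matters most in the fallback branch, where an unmasked argmax could select an illegal action.

```python
def regret_matching(advantages, legal_actions):
  """Returns a strategy over all actions; illegal actions get probability 0."""
  strategy = [0.0] * len(advantages)
  for a in legal_actions:  # restricting to legal actions acts as the mask
    strategy[a] = max(advantages[a], 0.0)
  total = sum(strategy)
  if total > 0:
    return [p / total for p in strategy]
  # No positive legal advantage: play the best legal action deterministically.
  best = max(legal_actions, key=lambda a: advantages[a])
  strategy[best] = 1.0
  return strategy


print(regret_matching([1.0, -2.0, 3.0], legal_actions=[0, 1]))   # [1.0, 0.0, 0.0]
print(regret_matching([-1.0, -0.5, 5.0], legal_actions=[0, 1]))  # [0.0, 1.0, 0.0]
```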
@fuyuan-li, I added a `print_nash_convs` argument to `DeepCFRSolver` that prints the exploitability after each optimisation step.
It can hurt performance, but it helps with tracking progress.
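A rough sketch of how such a flag can work, and why it costs time: computing NashConv needs best responses over the whole game tree on every iteration, on top of training. All names below (`solve`, `run_iteration`, `evaluate_nash_conv`) are hypothetical stand-ins, not the `DeepCFRSolver` API; in OpenSpiel the evaluation itself would typically be `exploitability.nash_conv(game, policy)`.

```python
def solve(num_iterations, run_iteration, evaluate_nash_conv,
          print_nash_convs=False):
  """Runs training; optionally evaluates and prints NashConv per iteration.

  Returns the list of (iteration, nash_conv) pairs that were evaluated.
  """
  history = []
  for it in range(1, num_iterations + 1):
    run_iteration(it)  # traversals + network training for this iteration
    if print_nash_convs:
      # Extra full-game evaluation: this is where the slowdown comes from.
      nash_conv = evaluate_nash_conv()
      history.append((it, nash_conv))
      print(f"NashConv @ {it} = {nash_conv}")
  return history
```

With the flag off, the evaluation callback is never called, so the training loop pays nothing for it.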
Quick update @alexunderch: the refactored torch implementation does not converge like the original implementation when using the same hyperparameter set as in #1287 (https://github.com/google-deepmind/open_spiel/issues/1287).
Will dig in and keep you posted!
@fuyuan-li, I erroneously had an extra ReLU in the network. Sorry! Now both implementations should converge.
@lanctot, both versions now converge on Kuhn poker; see https://github.com/google-deepmind/open_spiel/pull/1406
However, without additional code improvements (e.g. the jitted buffer is constantly recompiled because its size keeps changing), the jax implementation is about 25 times slower than the pytorch one. Maybe we should tackle that as a separate issue.
When the results for Leduc are in, let me know if you want to merge, and I will clean up some stuff.
@alexunderch, probably a late update: yes, confirmed too, both implementations converged on Kuhn poker! Exploitability dropped to 0.05 and the policy value converged to the theoretical value (-0.06 for player 0), tested with several random seeds (for both the torch and jax versions).
The only reason I didn't run a simulation over 30+ seeds is that jax runs super slowly: 70 minutes for one simulation on Kuhn poker.
Another update: a Leduc simulation (pytorch implementation) is running:
- From one run, exploitability dropped to 0.6 and the policy value went to -0.12 for player 0.
Very happy to continue working on the jax implementation's performance issue too (otherwise I doubt simulations over multiple seeds are feasible). Do you think it's a good time (i.e., is the working code ready?) to start this new thread? Let me know!
@fuyuan-li, thank you for your testing. Good that the results are reproducible. The newest commit should cut the computation time roughly in half (on my Mac, at least). Further improvement, I think, would require some additional modifications to the ReservoirBuffer.
I reckon it's up to Dr. Lanctot whether he is okay with the slow but consistent-looking implementation, or thinks that we need comparable running times.
We can have the whole discussion in this thread, if you're comfortable with that.
@fuyuan-li, check it out now: the jax version should be only 4-6 times slower than pytorch.
This is partly because of https://github.com/jax-ml/jax/issues/16587 (in the append_to_reservoir function) and partly because I allocate the whole buffer right away.
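A plain-Python sketch of the reservoir logic with the whole buffer allocated up front, so the storage never changes size (the property that keeps a jitted version from recompiling). The function names roughly follow the ones mentioned in this thread, but their signatures and the dict-based state are assumptions; a JAX version would return updated arrays (e.g. via `data.at[j].set(sample)`) instead of mutating a list.

```python
import random


def init_reservoir(capacity, feature_size):
  """Allocates the whole buffer up front; its length never changes afterwards."""
  return {"data": [[0.0] * feature_size for _ in range(capacity)], "count": 0}


def append_to_reservoir(state, sample, rng):
  """Reservoir sampling (Algorithm R).

  After n appends, each of the n samples is in the buffer with probability
  min(1, capacity / n).
  """
  capacity = len(state["data"])
  n = state["count"]
  if n < capacity:
    state["data"][n] = sample    # still filling the preallocated slots
  else:
    j = rng.randint(0, n)        # uniform over 0..n, inclusive
    if j < capacity:
      state["data"][j] = sample  # overwrite a random slot
  state["count"] = n + 1
  return state


def sample_batch(state, batch_size, rng):
  """Draws a batch (with replacement) from the filled part of the buffer."""
  filled = state["data"][:min(state["count"], len(state["data"]))]
  return rng.choices(filled, k=batch_size)


rng = random.Random(0)
buf = init_reservoir(capacity=3, feature_size=2)
for i in range(10):
  append_to_reservoir(buf, [float(i), float(i)], rng)
print(len(buf["data"]), buf["count"])  # 3 10
```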
Hi guys, great work on this! I'm super impressed to see the community collaboration here!
> I reckon it's up to Dr. Lanctot whether he is okay with the slow but consistent-looking implementation, or thinks that we need comparable running times.
No strong preferences here; I'm mostly happy to see that we can retain these implementations thanks to both of you working on this.
I'm OK with slightly different running times if one of them is just faster / more efficient. How cryptic is it? I'd still want it to be readable. So as long as you have enough comments explaining any non-obvious code, I think it's OK to have a more efficient version that is slightly inconsistent with the other one.
The PyTorch implementation hasn't really changed in terms of readability.
In the jax implementation, I replaced the buffer with a set of functions and made the training loop jittable. It should still be readable.
If @fuyuan-li soon reports that their testing is fine, I can clean up the tf and reference implementations and add a couple of comments, and we should be good to go.
P.S. If we continue refactoring, I will replace the networks and buffers with imports of the corresponding utilities.
Thank you @alexunderch and @lanctot! Quick comments for us:
- Convergence and consistency on kuhn in both pytorch and jax are confirmed.
- Convergence on Leduc in pytorch, based on 40 simulations: exploitability trained down to 0.67 on average. The policy value (for player 0) reached -0.14 (std 0.01) across the 40 simulations. (I'd consider that converged.)
- Convergence on Leduc in jax is still running. Re @lanctot's question ("How cryptic is it?"): about 65 minutes per simulation with the default parameters. (By "default hyperparameters" I mean the hyperparameter set confirmed to converge with pytorch on Kuhn.)
- @alexunderch: a small glitch in your jax implementation. For the jax decorator on init_reservoir(), do you want to change `@jax.jit(static_argnames=("capacity",))` to `@partial(jax.jit, static_argnames=("capacity",))`? I updated it locally to get the simulation running, but I don't think I need to submit a PR for this; I'll defer to you to update it in the branch HEAD to keep things simple.
- Given the running time on Leduc in jax, I expect to get results this weekend (it's running now), but feel free to go ahead if we are all happy with the Kuhn convergence results. I'll come back to log the simulation results when they're ready (regardless of whether this PR has already been merged).
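For reference, a minimal sketch of the `functools.partial` form of the decorator. The `example` argument and the returned `(data, count)` pair are assumptions for illustration, not the PR's actual `init_reservoir` signature. `partial` binds `static_argnames` first, so `jax.jit` still receives the decorated function as its first positional argument. Because `capacity` is static, it is a plain Python int at trace time and can be used as an array shape; each distinct `capacity` triggers one compilation, which is fine since it is set once.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical signature: `example` is one sample whose shape/dtype the buffer copies.
@partial(jax.jit, static_argnames=("capacity",))
def init_reservoir(capacity, example):
  # `capacity` is static, so it is a concrete int here and can size the array.
  # `example` is traced, but its shape and dtype are known at trace time.
  data = jnp.zeros((capacity,) + example.shape, dtype=example.dtype)
  count = jnp.zeros((), dtype=jnp.int32)  # number of samples seen so far
  return data, count


data, count = init_reservoir(4, jnp.ones(3))
print(data.shape, int(count))  # (4, 3) 0
```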
@fuyuan-li, thank you for your updates. Yes, I will update it locally. When everything is confirmed, we can merge.
Just out of interest, could you compare performance on GPU? As I mentioned, because of the array copying, jax performance may be worse than numpy's, although on paper it should be much faster. No code modifications are needed; just install the CUDA versions.
Quick update: @alexunderch, it looks like for Leduc poker the pytorch and jax implementations don't converge to the same result. On the same set of 25 seeds (a subset of the 40 seeds above), the pytorch implementation gave an average exploitability of 0.66, while jax gave 1.13. In case you want to reproduce the experiment, here are the seeds I used and the exploitability scores I cached after each run (hyperparameters are the defaults):
policy_network_layers=(64,),
advantage_network_layers=(64,),
num_iterations=101,
num_traversals=375,
reinitialize_advantage_networks=True,
learning_rate=1e-3,
batch_size_advantage=256,
batch_size_strategy=256,
memory_capacity=1000000,
policy_network_train_steps=2500,
advantage_network_train_steps=375,
sim_results.json
Let me know your thoughts.
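If it helps with reproducing this, a small sketch for summarising per-seed results and comparing the two implementations on the seeds they share. The `{seed: final_exploitability}` layout is an assumption about sim_results.json, not its actual format; JSON object keys load as strings, which is fine as long as both dicts use the same key type. Looking at paired per-seed differences, rather than only the two averages, separates the implementation gap from seed-to-seed variance.

```python
import statistics


def summarize(scores):
  """Mean and sample standard deviation of {seed: final exploitability}."""
  values = list(scores.values())
  return statistics.mean(values), statistics.stdev(values)


def paired_gap(baseline, candidate):
  """Mean and std of (candidate - baseline) over the seeds both runs share."""
  seeds = sorted(set(baseline) & set(candidate))
  diffs = [candidate[s] - baseline[s] for s in seeds]
  return statistics.mean(diffs), statistics.stdev(diffs), len(seeds)
```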
@fuyuan-li, thank you, I will be back tomorrow with the fixed results. Let's say that 0.66 is what we're looking for. Thank you for the testing; I couldn't have done it without you.
My final results with the latest commit for leduc_poker and default hyperparameters:
- w/ pytorch
I1221 22:48:52.097683 8456724608 deep_cfr_pytorch.py:76] Final policy loss: 4.595390796661377
I1221 22:48:52.189984 8456724608 deep_cfr_pytorch.py:82] Deep CFR in leduc_poker - NashConv: 0.693684191700022
I1221 22:48:52.215540 8456724608 deep_cfr_pytorch.py:91] Computed player 0 value: -0.14
I1221 22:48:52.215598 8456724608 deep_cfr_pytorch.py:92] Computed player 1 value: 0.14
- w/ flax
I1222 01:51:04.677639 8456724608 deep_cfr_jax.py:76] Final policy loss: 1.7874398231506348
I1222 01:51:05.447016 8456724608 deep_cfr_jax.py:82] Deep CFR in leduc_poker - NashConv: 0.47039074715624
I1222 01:51:05.478743 8456724608 deep_cfr_jax.py:92] Computed player 0 value: -0.09
I1222 01:51:05.478972 8456724608 deep_cfr_jax.py:93] Computed player 1 value: 0.09
I think they are quite close; I don't know what I've changed since, sorry.
Thank you once more, @fuyuan-li. @lanctot, if you're fine with the code, it can now be merged. I found out that the jax slowdowns mostly come from the replay memory. I moderately optimised it but don't see much room for further improvement so far. I'd say that an efficient replay memory is the next step, for both jax and torch.
Specifically, the exploitability of the jax version at each iteration:
NashConv @ 1 = 7.176929528083735 | Policy loss = None
NashConv @ 2 = 6.059597804504729 | Policy loss = 0.0026580938138067722
NashConv @ 3 = 4.05204699111699 | Policy loss = 0.040816642343997955
NashConv @ 4 = 3.4582866766464115 | Policy loss = 0.05231575295329094
NashConv @ 5 = 2.7447041333669633 | Policy loss = 0.07301361858844757
NashConv @ 6 = 2.0905215311051366 | Policy loss = 0.07009975612163544
NashConv @ 7 = 1.8484036125688335 | Policy loss = 0.08286646008491516
NashConv @ 8 = 1.7528570048692564 | Policy loss = 0.07966998219490051
NashConv @ 9 = 1.675157232381963 | Policy loss = 0.09868694841861725
NashConv @ 10 = 1.6924227446555857 | Policy loss = 0.16678816080093384
NashConv @ 11 = 1.5571462999466403 | Policy loss = 0.12358181923627853
NashConv @ 12 = 1.31621344313262 | Policy loss = 0.12477826327085495
NashConv @ 13 = 1.311415699359722 | Policy loss = 0.165201336145401
NashConv @ 14 = 1.249570358418782 | Policy loss = 0.18017971515655518
NashConv @ 15 = 1.145830898694582 | Policy loss = 0.23028936982154846
NashConv @ 16 = 1.2878010787582053 | Policy loss = 0.21455836296081543
NashConv @ 17 = 1.0336522741427303 | Policy loss = 0.19856809079647064
NashConv @ 18 = 0.9682536868153124 | Policy loss = 0.25924623012542725
NashConv @ 19 = 0.93623810843395 | Policy loss = 0.313173770904541
NashConv @ 20 = 0.849828143530562 | Policy loss = 0.28085824847221375
NashConv @ 21 = 0.8859937342419513 | Policy loss = 0.32723239064216614
NashConv @ 22 = 0.8829375033254149 | Policy loss = 0.3264906406402588
NashConv @ 23 = 0.7521178043661657 | Policy loss = 0.34091615676879883
NashConv @ 24 = 0.7626105944881375 | Policy loss = 0.37323853373527527
NashConv @ 25 = 0.9214537400011841 | Policy loss = 0.421398401260376
NashConv @ 26 = 0.7860146239756216 | Policy loss = 0.35111019015312195
NashConv @ 27 = 0.830563657479827 | Policy loss = 0.5253749489784241
NashConv @ 28 = 0.8740263555608351 | Policy loss = 0.44992372393608093
NashConv @ 29 = 0.8226630705413742 | Policy loss = 0.40230244398117065
NashConv @ 30 = 0.7982843048268536 | Policy loss = 0.4356726109981537
NashConv @ 31 = 0.7934473296428135 | Policy loss = 0.5469026565551758
NashConv @ 32 = 0.801172086376167 | Policy loss = 0.564400851726532
NashConv @ 33 = 0.6803939588046612 | Policy loss = 0.41089868545532227
NashConv @ 34 = 0.8693795656571353 | Policy loss = 0.4600907266139984
NashConv @ 35 = 0.7602445578842363 | Policy loss = 0.478179931640625
NashConv @ 36 = 0.8809427256569731 | Policy loss = 0.6630244255065918
NashConv @ 37 = 0.7707125211120526 | Policy loss = 0.43674033880233765
NashConv @ 38 = 0.8268448174525914 | Policy loss = 0.4923495054244995
NashConv @ 39 = 0.7583308000547349 | Policy loss = 0.5110579133033752
NashConv @ 40 = 0.7221252921735284 | Policy loss = 0.5919511318206787
NashConv @ 41 = 0.7893044355216443 | Policy loss = 0.5734891295433044
NashConv @ 42 = 0.8104256787314067 | Policy loss = 0.4891538619995117
NashConv @ 43 = 0.642608689898516 | Policy loss = 0.7405254244804382
NashConv @ 44 = 0.7279538834251799 | Policy loss = 0.716678261756897
NashConv @ 45 = 0.7183854687766063 | Policy loss = 0.7193781733512878
NashConv @ 46 = 0.7345949389975195 | Policy loss = 0.6201626658439636
NashConv @ 47 = 0.6685243737297071 | Policy loss = 0.5693798065185547
NashConv @ 48 = 0.6647768386686698 | Policy loss = 0.7520891427993774
NashConv @ 49 = 0.6416351106482392 | Policy loss = 0.7402688264846802
NashConv @ 50 = 0.6077754061790797 | Policy loss = 0.7785935401916504
NashConv @ 51 = 0.6438597397284539 | Policy loss = 0.8324920535087585
NashConv @ 52 = 0.696250522203123 | Policy loss = 0.8035646677017212
NashConv @ 53 = 0.6386856524114637 | Policy loss = 0.796439528465271
NashConv @ 54 = 0.624961073247444 | Policy loss = 0.6281327605247498
NashConv @ 55 = 0.6291509241876219 | Policy loss = 0.7602921724319458
NashConv @ 56 = 0.6066405379426694 | Policy loss = 0.8809983730316162
NashConv @ 57 = 0.6886520588118219 | Policy loss = 0.7726151943206787
NashConv @ 58 = 0.5729062846305042 | Policy loss = 0.8566088676452637
NashConv @ 59 = 0.5581417690589261 | Policy loss = 1.066760778427124
NashConv @ 60 = 0.6166203526994547 | Policy loss = 1.0109937191009521
NashConv @ 61 = 0.6045377538782075 | Policy loss = 0.8873990774154663
NashConv @ 62 = 0.6155058788687606 | Policy loss = 0.9575991630554199
NashConv @ 63 = 0.5814761545081377 | Policy loss = 0.9077833890914917
NashConv @ 64 = 0.6235539459516768 | Policy loss = 0.775898814201355
NashConv @ 65 = 0.5899268499148093 | Policy loss = 0.9223687052726746
NashConv @ 66 = 0.5913993869151155 | Policy loss = 0.9850955009460449
NashConv @ 67 = 0.5802760540063157 | Policy loss = 0.8636703491210938
NashConv @ 68 = 0.6808774213732282 | Policy loss = 0.9804381132125854
NashConv @ 69 = 0.6355856582392478 | Policy loss = 1.120728611946106
NashConv @ 70 = 0.5376710559386174 | Policy loss = 0.9073600769042969
NashConv @ 71 = 0.539856767610367 | Policy loss = 1.2031302452087402
NashConv @ 72 = 0.5639326686370347 | Policy loss = 0.9085475206375122
NashConv @ 73 = 0.5361189948755027 | Policy loss = 1.0187501907348633
NashConv @ 74 = 0.4998302567575119 | Policy loss = 1.0333696603775024
NashConv @ 75 = 0.5470106390725342 | Policy loss = 0.9921380877494812
NashConv @ 76 = 0.5477533590439871 | Policy loss = 1.114292860031128
NashConv @ 77 = 0.4748788827052619 | Policy loss = 1.1480873823165894
NashConv @ 78 = 0.4853924413111935 | Policy loss = 1.1924254894256592
NashConv @ 79 = 0.5551257667112115 | Policy loss = 1.3514446020126343
NashConv @ 80 = 0.5137493154633238 | Policy loss = 1.3134557008743286
NashConv @ 81 = 0.4726230282709643 | Policy loss = 1.015550136566162
NashConv @ 82 = 0.5037982490028234 | Policy loss = 1.3141393661499023
NashConv @ 83 = 0.46563420264355326 | Policy loss = 1.2847996950149536
NashConv @ 84 = 0.48749649254221794 | Policy loss = 1.2550193071365356
NashConv @ 85 = 0.4991398160372466 | Policy loss = 1.155588150024414
NashConv @ 86 = 0.4737138874514377 | Policy loss = 1.2591278553009033
NashConv @ 87 = 0.45924743564025483 | Policy loss = 1.177165150642395
NashConv @ 88 = 0.5759111933768976 | Policy loss = 1.3281080722808838
NashConv @ 89 = 0.4721997271043664 | Policy loss = 1.3457883596420288
NashConv @ 90 = 0.4780611930786257 | Policy loss = 1.6052868366241455
NashConv @ 91 = 0.42898678506625776 | Policy loss = 1.625423789024353
NashConv @ 92 = 0.5042931687374642 | Policy loss = 1.4520492553710938
NashConv @ 93 = 0.4920611608070192 | Policy loss = 1.299553632736206
NashConv @ 94 = 0.4816235766682106 | Policy loss = 1.526358962059021
NashConv @ 95 = 0.5420048979028065 | Policy loss = 1.4343509674072266
NashConv @ 96 = 0.5358024467849981 | Policy loss = 1.548895239830017
NashConv @ 97 = 0.5830988476624244 | Policy loss = 1.3108930587768555
NashConv @ 98 = 0.4521362535108876 | Policy loss = 1.1970348358154297
NashConv @ 99 = 0.5270631176306322 | Policy loss = 1.2043466567993164
NashConv @ 100 = 0.4157276186322467 | Policy loss = 1.3255438804626465
NashConv @ 101 = 0.5831885321318206 | Policy loss = 1.37606942653656
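Since single-iteration values are noisy (in the log above, NashConv is about 0.42 at iteration 100 but about 0.58 at iteration 101), comparing runs by the final value alone can be misleading. One option, which is an assumption and not what was done in this thread, is to parse the log and average the last few iterations. The regex assumes exactly the `NashConv @ <iter> = <value>` line format shown above.

```python
import re
import statistics

# Matches lines such as "NashConv @ 7 = 1.84 | Policy loss = 0.08".
_NASH_CONV_LINE = re.compile(r"NashConv @ (\d+) = ([0-9.eE+-]+)")


def parse_nash_convs(log_text):
  """Returns [(iteration, nash_conv), ...] in the order they appear."""
  return [(int(m.group(1)), float(m.group(2)))
          for m in _NASH_CONV_LINE.finditer(log_text)]


def tail_mean(nash_convs, k=10):
  """Average NashConv over the last k parsed iterations."""
  return statistics.fmean(value for _, value in nash_convs[-k:])
```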
@alexunderch Thanks for the update! I realized I was about 5 commits behind your latest HEAD (508192f, 2 days ago) when I first started the simulation. I have now rebased onto your current HEAD and rerun the experiments. I can confirm very similar per-step convergence behavior in JAX (using the default seed, 42 I believe), reaching ~0.62 at step 101, though not exactly the same decimal value. I also double-checked the results with the previous seed (8150) and confirmed that exploitability converges to around 0.53. Thanks for this refactoring! (By the time I wrote this, I saw one more commit, but it seems to be mainly code cleanup, so it shouldn't impact the results?)
Yes, I just trimmed a couple of lines; it doesn't affect the results.
@lanctot, I am sorry: I checked the tests and they didn't fail, but I (apparently) forgot to rerun them before the commit. Also, I am not sure they will pass on Python 3.10, because that environment uses jax and flax versions from 3 years ago. There was no nnx API back then, or was there?
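One way to keep the suite from failing in environments whose Flax predates `flax.nnx` is to skip the JAX tests when the module cannot be imported. A sketch using only `unittest`; the test class and method names are hypothetical, and whether CI should skip these tests or pin newer versions is a separate decision.

```python
import unittest

try:
  from flax import nnx  # only present in newer Flax releases
except ImportError:  # Flax missing, or too old to provide nnx
  nnx = None


@unittest.skipIf(nnx is None, "flax.nnx is not available in this environment")
class DeepCFRJaxTest(unittest.TestCase):  # hypothetical test class

  def test_nnx_is_importable(self):
    self.assertTrue(hasattr(nnx, "Module"))
```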