Boris Yangel

20 comments by Boris Yangel

Hello! A gentle reminder that this issue needs a fix or a proper workaround.

How can I distinguish between different optimization passes? It looks like these flags only give me the HLO with and without all optimizations:
```
$ ls -l ./hlo_dump/ | grep decode_fn
-rw-r--r--...
```

Never mind, found it: `--xla_dump_hlo_pass_re=.*`
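A minimal sketch of how these flags can be combined, assuming the dump directory `./hlo_dump` and a stand-in `decode_fn` (both hypothetical here). `--xla_dump_to` chooses the output directory, and `--xla_dump_hlo_pass_re` additionally dumps the HLO before and after every optimization pass whose name matches the regex. The exact file names XLA writes may differ between versions.

```python
import os

# Output directory for the dumps (assumed path, matching the listing above).
DUMP_DIR = "./hlo_dump"

# Keep any XLA flags that are already set, then append the dump flags.
# --xla_dump_hlo_pass_re=.* matches every pass name, so HLO is dumped
# around each optimization pass, not just at the start and end.
os.environ["XLA_FLAGS"] = " ".join(
    part
    for part in (
        os.environ.get("XLA_FLAGS", ""),
        f"--xla_dump_to={DUMP_DIR}",
        "--xla_dump_hlo_pass_re=.*",
    )
    if part
)


def run_example():
    # Import JAX only after XLA_FLAGS is set so the flags take effect.
    import jax
    import jax.numpy as jnp

    @jax.jit
    def decode_fn(w, x):  # hypothetical stand-in for the real function
        return jnp.tanh(w @ x)

    w = jnp.ones((8, 8))
    x = jnp.ones((8,))
    # Compiling and running the function triggers the HLO dumps.
    return decode_fn(w, x).block_until_ready()
```

Calling `run_example()` should populate `./hlo_dump/` with one file per dumped pass, which can then be filtered with `grep decode_fn` as above.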

It turned out that, in my case, the undesired behavior can be turned off with `xla_gpu_enable_dot_strength_reduction=false`. It does come with a significant performance reduction, though.

As I understand it, the problem was that I was running inference with a batch size of 1, so some of my matmuls were matrix-vector products. One of the XLA...
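To illustrate why batch size 1 matters here, a rough sketch under stated assumptions: judging by the flag name, dot strength reduction rewrites small dots, and a matrix-vector product can be expressed either as a regular dot or as an elementwise multiply followed by a sum. The snippet below uses NumPy (not XLA) and arbitrary shapes of our choosing to show that the two forms compute the same values up to floating-point rounding; it does not reproduce the original issue or what the truncated sentence above describes, and XLA may compile the two forms into very different GPU kernels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed shapes: a weight matrix and a single input vector (batch size 1).
W = rng.standard_normal((512, 1024)).astype(np.float32)
x = rng.standard_normal(1024).astype(np.float32)

# Form 1: a regular matrix-vector product (a "dot").
y_dot = W @ x

# Form 2: the same math as a broadcast multiply followed by a reduction,
# i.e. the kind of rewrite a dot strength-reduction pass can produce.
y_reduce = (W * x[None, :]).sum(axis=-1)

# Both forms agree up to floating-point summation-order differences.
max_abs_diff = float(np.max(np.abs(y_dot - y_reduce)))
print(f"max |dot - multiply+reduce| = {max_abs_diff:.2e}")
```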

Hello! Running this MuZero implementation on an arbitrary gym environment is possible but will require some code changes. First, you would need to change the `make_env` function in `tools/agent.py` to instantiate...

I can confirm that the sharding-based solution works when using `jax_default_prng_impl=rbg`.
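For reference, a minimal sketch of switching the default PRNG implementation from Python via `jax.config.update`. The sharding-based solution itself is not shown in this comment, so it is not reproduced here; the key and shape below are arbitrary choices for illustration.

```python
import jax

# Select the "rbg" PRNG implementation before any keys are created.
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)          # raw key data now uses the rbg layout
x = jax.random.normal(key, (4,))     # sampling works as usual with this key
print(x)
```

The same setting can also be supplied through the `JAX_DEFAULT_PRNG_IMPL` environment variable instead of changing code.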

I think it might be the same issue as https://github.com/google/jax/issues/19893

@jakevdp Hey, sorry for mentioning you directly, but this issue hasn't received any attention for several weeks. Could someone from the JAX team please take a look? Thanks!