rajasekharporeddy
Hi @jbuckman @mattjj It looks like this issue is resolved in JAX version 0.4.26. I executed the mentioned code on Colab (T4 GPU) with JAX version 0.4.26 without adding the `jax.effects_barrier()`...
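The original repro isn't shown here, so the function below is a hypothetical illustration only. It shows where `jax.effects_barrier()` is usually placed: after calling jitted code that emits `jax.debug.print` output, so that execution blocks until pending effects have finished.

```python
import jax
import jax.numpy as jnp

# Hypothetical jitted function that emits a debug print as a side effect.
@jax.jit
def double_and_log(x):
    jax.debug.print("x = {}", x)
    return x * 2

y = double_and_log(jnp.arange(3))

# Block until all pending effects (such as the debug print) have completed.
jax.effects_barrier()
print(y)
```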
Hi @sh0416 Thanks for the question. JAX reuses NumPy's implementation of `set_printoptions` directly by assigning it to `jax.numpy.set_printoptions`. https://github.com/google/jax/blob/c4dea624cce95e2dcd288831eadd68d1aae4a05a/jax/_src/numpy/lax_numpy.py#L117C1-L117C39 From https://github.com/numpy/numpy/issues/21653#issuecomment-1145751948, it can be understood that NumPy displays the numbers...
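A minimal sketch of what that assignment means in practice, assuming the installed JAX version still matches the linked commit: `jnp.set_printoptions` is the very same function object as `np.set_printoptions`, and because a JAX array's `str` goes through NumPy's formatting, NumPy's print options (here via the `np.printoptions` context manager) apply to it.

```python
import numpy as np
import jax.numpy as jnp

# jax.numpy.set_printoptions is NumPy's own function, not a JAX reimplementation.
same_function = jnp.set_printoptions is np.set_printoptions
print(same_function)

# JAX arrays are formatted by NumPy, so NumPy's print options affect them too.
x = jnp.array([1.0 / 3.0], dtype=jnp.float32)
with np.printoptions(precision=3):
    formatted = str(x)
print(formatted)
```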
Hi @Findus23 It looks like the issue you mentioned has been resolved. I tried to reproduce it on Colab with JAX version 0.4.23. Now the operations...
Hi @Justin-Tan It looks like this issue has been resolved in later versions of JAX. I executed the provided repro code with JAX version 0.4.23 on Google Colab using GPU run...
Hi @nouiz IIUC, it looks like this issue has been resolved. I tried to reproduce the issue on Google Colab (both CPU and GPU) with JAX version 0.4.23. The code executed...
Hi @mathisgerdes It looks like this issue has been resolved in later versions of JAX. I executed the mentioned code on Colab (T4 GPU) with CUDA 12.3 and cuDNN 8.9.7 and...
Hi @mathisgerdes Please feel free to close the issue if it is resolved. Thank you.
Hi @gileshd It looks like this issue has been resolved with PR #16018. I tried to execute the mentioned code using JAX version 0.4.23 on Colab by importing `tree_reduce` and `tree_map`...
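The exact repro from that issue isn't included here, so the pytree below is hypothetical. It is only a minimal sketch of importing `tree_map` and `tree_reduce` from `jax.tree_util` and combining them: map a function over every leaf, then fold the leaves into a single value.

```python
import operator

from jax.tree_util import tree_map, tree_reduce

# Hypothetical pytree: a dict whose leaves are plain Python ints.
tree = {"a": 1, "b": (2, 3)}

# Apply a function to every leaf, preserving the tree structure.
doubled = tree_map(lambda leaf: leaf * 2, tree)

# Reduce all leaves to one value with a binary function.
total = tree_reduce(operator.add, doubled)
print(doubled, total)
```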
Hi @mganahl It looks like this has been resolved. I executed the mentioned code with JAX version 0.4.23 on Colab CPU and JAX version 0.3.25 on Colab TPU. Now the...
Hi @wbrenton I tested it on Colab TPU v2 with JAX version 0.4.26 and JAX nightly version 0.4.31.dev20240630, and it works fine. Attaching [gist](https://colab.sandbox.google.com/gist/rajasekharporeddy/01d55a78d45ece94f4b3e6df1034cf08/tpu_15368.ipynb) for reference. Thank you.