lecture-jax
In-place operations in JAX
@mmcky
Can you please search for this line:
In line with immutability, JAX does not support inplace operations:
It looks like JAX is sorting in place now. Could you please evaluate and propose a fix?
It's here: https://github.com/QuantEcon/lecture-jax/blob/1a1be32fd03b11a4bfd960565818cddfa2070176/lectures/jax_intro.md?plain=1#L161
For me, it's still not in place:
>>> import jax
>>> jax.__version__
'0.4.23'
>>> import jax.numpy as jnp
>>> a = jnp.array((2, 1))
>>> a.sort()
Array([1, 2], dtype=int32)
>>> a
Array([2, 1], dtype=int32)
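The check can be extended into a small script. This is a minimal sketch that assumes a recent JAX release; only 0.4.23 was tested in the transcript above. It shows that `Array.sort()` and `jnp.sort` both return a new sorted array and leave the original unchanged, that item assignment on a JAX array raises `TypeError`, and that the functional `.at[...].set(...)` update returns a modified copy. The variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array((2, 1))

# Array.sort() returns a new, sorted array; `a` itself is left untouched
b = a.sort()
print(b.tolist(), a.tolist())

# jnp.sort is the functional equivalent and behaves the same way
c = jnp.sort(a)

# Item assignment on a JAX array raises TypeError because arrays are immutable
assignment_rejected = False
try:
    a[0] = 10
except TypeError:
    assignment_rejected = True

# The supported pattern is a functional update via the .at property,
# which returns a modified copy instead of mutating `a`
d = a.at[0].set(10)
print(d.tolist(), a.tolist())
```

If this runs without errors and `a` still holds `[2, 1]` at the end, the lecture's statement that JAX does not support in-place operations is still accurate for that version.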