nla2021
Lecture 1: `index_update` is deprecated in `jax`
`jax.ops.index_update` is deprecated, so in the Summation section the line

```python
x = jax.ops.index_update(x, [0], 1.)
```

needs to be changed to

```python
x = x.at[0].set(1.)
```
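A minimal sketch of how the replacement behaves. The lecture's surrounding code is not shown here, so the array below is a hypothetical stand-in for `x`. JAX arrays are immutable: `.at[...].set(...)` does not modify `x` in place but returns an updated copy. That is why the lecture line reassigns the result back to `x`. The `index_add` → `.at[...].add(...)` mapping is included only to show that the same `.at` syntax covers the other `jax.ops` index helpers.

```python
import jax.numpy as jnp

# Hypothetical stand-in for the array used in the Summation section.
x = jnp.zeros(5)

# Deprecated form: x = jax.ops.index_update(x, [0], 1.)
# Replacement: .at[...].set(...) returns a new array; x itself is unchanged.
y = x.at[0].set(1.)

print(x)  # still all zeros, since JAX arrays are immutable
print(y)  # first entry is now 1.0

# The same .at syntax replaces other jax.ops index helpers,
# e.g. jax.ops.index_add(...) becomes .at[...].add(...).
z = y.at[1:3].add(2.)
print(z)
```

To keep the lecture's original semantics, assign the result back, as in `x = x.at[0].set(1.)`.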