jax
Add jet rules for `scatter_mul` and `copy`, and fix typos in `scatter*` docstrings.
Add jet rules and tests for `scatter_mul` and `copy`, and fix typos in `scatter*` docstrings.
The jet rule for `scatter_mul` takes inspiration from its jvp rule.
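The sketch below shows one way a Taylor-mode rule can follow from the jvp rule. It is not the code from this change. `scatter_mul_jet` is a hypothetical helper that works on plain lists of coefficients instead of registering a rule with `jax.experimental.jet`. It assumes unique indices, because JAX's `scatter_mul` jvp rule only supports differentiating with respect to the updates when `unique_indices=True`. It also uses jet's convention that the k-th series term is the k-th derivative coefficient. At order 1 it reproduces the two jvp terms, `scatter_mul(g, i, y)` for the operand tangent and `x * scatter_add(zeros, i, g)` for the updates tangent. Higher orders add the remaining Leibniz-rule products.

```python
import math

import jax.numpy as jnp


def scatter_mul_jet(x_series, idx, y_series):
    """Hypothetical Taylor-mode rule for x.at[idx].multiply(y).

    x_series[k] / y_series[k] are the k-th derivative coefficients of the
    operand and the updates (jet convention: x(t) = sum_k x_k * t**k / k!).
    Assumes `idx` holds unique indices, as the scatter_mul jvp rule does.
    """
    out = []
    for k in range(len(x_series)):
        # k-th operand coefficient times the primal updates at idx;
        # generalizes the jvp term scatter_mul(g, i, y).
        z_k = x_series[k].at[idx].multiply(y_series[0])
        # Remaining Leibniz terms C(k, j) * x_j * y_{k-j} at idx;
        # generalizes the jvp term x * scatter_add(zeros_like(x), i, g).
        for j in range(k):
            scattered = jnp.zeros_like(x_series[j]).at[idx].add(y_series[k - j])
            z_k = z_k + math.comb(k, j) * x_series[j] * scattered
        out.append(z_k)
    return out


# Usage: x(t) = a + t and y(t) = b + t, truncated at order 3.
a = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 2])
b = jnp.array([4.0, 5.0])
x_series = [a, jnp.ones(3), jnp.zeros(3), jnp.zeros(3)]
y_series = [b, jnp.ones(2), jnp.zeros(2), jnp.zeros(2)]
z = scatter_mul_jet(x_series, idx, y_series)
```

The scattered entries equal (a + t)(b + t), so their derivative coefficients are ab, a + b, 2, 0. The untouched entry keeps the operand's own series.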