einx
Example einx.add_at does not work
I am trying to get the following example to work:

```python
einx.add_at("b ([h w]) c, ([2] b) i, c i -> c [h w] b", image, coordinates, updates)
```

Could you please provide array definitions that make this example run without a stack trace? I am using Python 3.10. So far, I have been unsuccessful. The example can be found in the tutorial on operators.
This works on my end:
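```python
import jax.numpy as jnp
import einx

b = 4
h = 80
w = 128
c = 3
i = 10

image = jnp.zeros((b, h * w, c))                    # b ([h w]) c
coordinates = jnp.zeros((2 * b, i), dtype="int32")  # ([2] b) i
updates = jnp.zeros((c, i))                         # c i

# image only exposes the product h * w, so h must be passed explicitly
# (w then follows from h * w).
image = einx.add_at("b ([h w]) c, ([2] b) i, c i -> c [h w] b", image, coordinates, updates, h=h)
```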
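If JAX is not at hand, the plain NumPy sketch below reproduces what we assume the call computes. This reading of the pattern is our own interpretation, not taken from the einx documentation, and the helper `add_at_reference` is hypothetical. We read `([2] b)` as a composed axis with the coordinate index outer and the batch inner, and we assume `updates` (which has no `b` axis) is added identically for every batch element. `np.add.at` accumulates repeated indices, which matters here because all-zero coordinates send every update to position (0, 0).

```python
import numpy as np


def add_at_reference(image, coordinates, updates, h, w):
    """Hypothetical NumPy sketch of the assumed semantics of
    "b ([h w]) c, ([2] b) i, c i -> c [h w] b"."""
    b, _, c = image.shape
    x = image.reshape(b, h, w, c).copy()    # b h w c
    coords = coordinates.reshape(2, b, -1)  # assumed: ([2] b) i -> 2 b i
    rows, cols = coords[0], coords[1]       # each of shape (b, i)
    batch = np.arange(b)[:, None]           # (b, 1), broadcasts against (b, i)
    # x[batch, rows, cols] has shape (b, i, c); updates.T is (i, c) and broadcasts over b.
    # np.add.at accumulates when the same (batch, row, col) appears more than once.
    np.add.at(x, (batch, rows, cols), updates.T)
    return x.transpose(3, 1, 2, 0)          # c h w b


b, h, w, c, i = 4, 80, 128, 3, 10
image = np.zeros((b, h * w, c))
coordinates = np.zeros((2 * b, i), dtype=np.int32)
updates = np.arange(c * i, dtype=np.float64).reshape(c, i)

out = add_at_reference(image, coordinates, updates, h=h, w=w)
print(out.shape)  # (3, 80, 128, 4)
```

Here `w` must be supplied to the reshape just like `h`; the sketch simply takes both rather than deriving one from the other.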
The example was indeed missing an additional parameter: the image only fixes the product h * w, so h and w are not fully constrained on their own. Passing h=h resolves this.
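To illustrate why the extra argument is needed, the standalone sketch below (plain Python, not einx code; the helper `factor_pairs` is hypothetical) lists every (h, w) split of the flattened axis length 80 * 128 = 10240. Many splits are consistent with the image shape alone, while fixing h leaves exactly one.

```python
def factor_pairs(n):
    """Return all (h, w) with h * w == n and h, w >= 1."""
    return [(h, n // h) for h in range(1, n + 1) if n % h == 0]


flat = 80 * 128  # length of the ([h w]) axis of image
pairs = factor_pairs(flat)
print(len(pairs), "candidate (h, w) splits")  # 24 candidate (h, w) splits

# Fixing h (as h=h does in the call) leaves a single consistent w.
fixed = [(h, w) for h, w in pairs if h == 80]
print(fixed)  # [(80, 128)]
```

Since 10240 = 2^11 * 5 has 24 divisors, the shape of `image` alone is compatible with 24 layouts; one more constraint on either axis picks out a unique one.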