catalyst
Indexed assignment doesn't work with dynamically-shaped arrays
The following program raises an error:
```python
import jax.numpy as jnp
from catalyst import *

@qjit
def f(n: int, m: int):
    x = jnp.ones((n, m), dtype=float)
    y = jnp.ones((n, m), dtype=float)

    @for_loop(0, n, 1, experimental_preserve_dimensions=True)
    def sum_and_multiply(i, x, y):
        x[i] = x[i] + y[i]
        y[i] = x[i] * y[i]
        return x, y

    return sum_and_multiply(x, y)

f(2, 3)
```
```
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py:2072, in non_negative_dim(d)
   2070 if is_constant_dim(d):
   2071   return max(0, d)
-> 2072 assert is_symbolic_dim(d)
   2073 try:
   2074   d_ge_0 = (d >= 0)

AssertionError:
```
The error does not occur when the indexed assignments (`x[i] = ...` and `y[i] = ...`) are removed from the loop body.
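For comparison: in JAX, arrays are immutable, and `x[i] = v` is normally written functionally as `x = x.at[i].set(v)`. The sketch below expresses the same loop that way using plain `jax.lax.fori_loop` with static shapes (`n` and `m` are ordinary Python ints here, not dynamic dimensions). It is only meant to show the intended semantics; whether Catalyst's `for_loop` with `experimental_preserve_dimensions=True` accepts the `.at[...]` form on dynamically-shaped arrays is an assumption not verified here, and the name `sum_and_multiply_static` is hypothetical.

```python
import jax
import jax.numpy as jnp


def sum_and_multiply_static(n: int, m: int):
    # Hypothetical helper: static shapes only, no Catalyst involved.
    x = jnp.ones((n, m), dtype=jnp.float32)
    y = jnp.ones((n, m), dtype=jnp.float32)

    def body(i, carry):
        x, y = carry
        # .at[i].set(...) returns an updated copy instead of mutating in place.
        x = x.at[i].set(x[i] + y[i])
        # Uses the already-updated row of x, matching the original loop order.
        y = y.at[i].set(x[i] * y[i])
        return x, y

    return jax.lax.fori_loop(0, n, body, (x, y))


x, y = sum_and_multiply_static(2, 3)
```

With all-ones inputs, each row of `x` becomes `1 + 1 = 2` and each row of `y` becomes `2 * 1 = 2`, so both results are 2x3 arrays filled with 2.0.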