
Indexed assignment doesn't work with dynamically-shaped arrays

Open · dime10 opened this issue 1 year ago · 3 comments

The following program raises an error:

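```python
import jax.numpy as jnp
from catalyst import for_loop, qjit

@qjit
def f(n: int, m: int):
    # Both arrays have shapes that depend on the runtime arguments n and m.
    x = jnp.ones((n, m), dtype=float)
    y = jnp.ones((n, m), dtype=float)

    @for_loop(0, n, 1, experimental_preserve_dimensions=True)
    def sum_and_multiply(i, x, y):
        # Indexed read and assignment on the dynamically-shaped arrays.
        x[i] = x[i] + y[i]
        y[i] = x[i] * y[i]
        return x, y

    return sum_and_multiply(x, y)

f(2, 3)
```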
```
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py:2072, in non_negative_dim(d)
   2070 if is_constant_dim(d):
   2071   return max(0, d)
-> 2072 assert is_symbolic_dim(d)
   2073 try:
   2074   d_ge_0 = (d >= 0)

AssertionError:
```

It does not happen without the indexed assignment.

dime10, Jul 05 '24 15:07
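For reference, the result the program above is presumably meant to compute can be sketched with plain NumPy, whose arrays are mutable, so the row-wise indexed assignment works as written. This uses static shapes and no Catalyst at all, so it does not exercise the dynamic-shape path that triggers the `AssertionError`; the function name `reference` is hypothetical.

```python
import numpy as np

def reference(n: int, m: int):
    # Hypothetical helper: same loop body as f, but with mutable NumPy arrays.
    x = np.ones((n, m), dtype=float)
    y = np.ones((n, m), dtype=float)
    for i in range(n):
        x[i] = x[i] + y[i]  # row i of x becomes 1.0 + 1.0 = 2.0
        y[i] = x[i] * y[i]  # uses the updated row of x: 2.0 * 1.0 = 2.0
    return x, y

x, y = reference(2, 3)
print(x)
print(y)
```

Every row is visited once, so both returned arrays end up filled with `2.0`.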

Don't you need to use `x.at[i].set(x[i] + y[i])` here?

josh146, Jul 05 '24 18:07
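Because JAX arrays are immutable, the loop body would need functional updates of the kind suggested above. A minimal sketch in plain JAX with static shapes, assuming `jax.lax.fori_loop` as a stand-in for Catalyst's `for_loop` so that the example runs without Catalyst. It only illustrates the `.at[i].set(...)` pattern and does not work around the dynamic-shape bug, which according to the thread is hit earlier, when `x[i]` is read.

```python
import jax
import jax.numpy as jnp

def sum_and_multiply(n: int, m: int):
    # Static shapes here; the issue concerns shapes that depend on runtime values.
    x = jnp.ones((n, m), dtype=jnp.float32)
    y = jnp.ones((n, m), dtype=jnp.float32)

    def body(i, carry):
        x, y = carry
        # x.at[i].set(...) returns a new array rather than mutating x in place.
        x = x.at[i].set(x[i] + y[i])
        # Reads the already-updated x, matching the original loop body.
        y = y.at[i].set(x[i] * y[i])
        return x, y

    return jax.lax.fori_loop(0, n, body, (x, y))

x, y = sum_and_multiply(2, 3)
```

The loop carry `(x, y)` keeps the same shapes and dtypes on every iteration, as `fori_loop` requires.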

Good point 😅, but apparently the error happens before that issue kicks in.

dime10, Jul 05 '24 18:07

Ah, so it happens at the 'get value' stage, got it!

josh146, Jul 05 '24 19:07