aesara
Limited implementation of ARange Op in JAX
It looks like the jax.numpy version of arange doesn't accept symbolic inputs (the same problem as reshape in #43). As a result, the test for that Op is currently marked as an expected failure.
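A minimal sketch of the failure in plain JAX, assuming the JAX backend ends up calling jnp.arange on traced values: under jax.jit the arguments become abstract tracers, and jnp.arange cannot determine its output length from them. Marking the arguments as static (static_argnums) makes them concrete Python values again. The function name arange_fn is hypothetical.

```python
import jax
import jax.numpy as jnp


def arange_fn(start, stop, step):
    # jnp.arange needs concrete start/stop/step to know the output length
    return jnp.arange(start, stop, step)


# Under jax.jit the arguments are tracers, so jnp.arange raises
try:
    jax.jit(arange_fn)(0, 5, 1)
    traced_call_failed = False
except jax.errors.ConcretizationTypeError:
    traced_call_failed = True

# With all three arguments marked static, they stay concrete during tracing
static_result = jax.jit(arange_fn, static_argnums=(0, 1, 2))(0, 5, 1)
```

Static arguments trigger a recompilation for every new combination of values, so this only helps when those values are known when the graph is compiled.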
Perhaps we can put together an implementation using jax.lax?
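One possible jax.lax route, sketched under the assumption that only the length has to be known at trace time: lax.iota builds 0, 1, ..., length - 1, which is then scaled and shifted. start and step can remain traced, but length still has to be a concrete Python int because it fixes the output shape. So this relaxes the restriction without removing it. arange_lax is a hypothetical name, not an existing Aesara or JAX function.

```python
import jax
import jax.numpy as jnp
from jax import lax


def arange_lax(start, step, length):
    # length must be a concrete Python int: it determines the output shape.
    # start and step may be traced values.
    return start + step * lax.iota(jnp.float32, length)


# Only the length is static; start and step are traced under jit
arange_lax_jit = jax.jit(arange_lax, static_argnums=(2,))
result = arange_lax_jit(1.0, 0.5, 4)
```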
jax.numpy.arange and jax.numpy.reshape only accept concrete values for their start, stop and step parameters and their shape parameter, respectively. Concrete values are constants, outputs of a Shape operator, or combinations of such outputs built with Python operators. This is being refactored in #1338.
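The JAX analogue of "outputs of a Shape operator" is a tensor's .shape, which stays a tuple of Python ints even while the tensor itself is traced. Values derived from it with Python arithmetic are therefore concrete and can be passed to jnp.arange and reshape inside jax.jit. This sketch treats that correspondence with Aesara's Shape Op as an assumption. The function name pairs is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pairs(x):
    # x is traced, but x.shape[0] is a plain Python int at trace time
    n = x.shape[0]
    idx = jnp.arange(n)                  # concrete stop derived from the shape
    return (x + idx).reshape(n // 2, 2)  # concrete new shape built with Python ops


result = pairs(jnp.zeros(4))
```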