
`jnp.arange` does not permit dynamic shaped arrays

Open · josh146 opened this issue 1 year ago · 3 comments

Description

I've noticed that, when JAX dynamic shape support is enabled via

```python
jax.config.update("jax_dynamic_shapes", True)
```

both jnp.arange and jnp.linspace raise an error if passed dynamic (traced) values, i.e. values that would produce a dynamically shaped array.

I'm wondering whether this is a bug: in the source code, it looks like a concreteness error is raised in both the dynamic_shapes=False and dynamic_shapes=True cases.

https://github.com/google/jax/blob/1e01fa7b0f1355c522f8420569dc778f2633c629/jax/_src/numpy/lax_numpy.py#L3143-L3154

Otherwise, if this is intentional, I can update this bug report to instead be a feature request.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.9 (main, Jan 11 2023, 09:18:18) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MBA-HJXGDYDXH7', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:13:04 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6020', machine='arm64')

josh146, Sep 03 '24 23:09

Note: for reference, my use case does not involve XLA (which does not support dynamically shaped arrays). Instead, I am compiling the generated MHLO via LLVM.

josh146, Sep 03 '24 23:09

Dynamic shapes are still very experimental and don't have much support in JAX APIs. Assigning to @mattjj because he may know whether it's expected to work here.

jakevdp, Sep 03 '24 23:09

We haven't prioritized dynamic shapes work for a while, and so the only things available are bits from our past experiments. That said, it is often easy to make specific things work, and so I'm happy to hear specific feature requests like this (e.g. "make my jnp.arange call work with dynamic shapes"). (I'm calling it a feature request rather than a bug because the docs don't say this should work, i.e. this is "intentional" as you say.)

In this case, we only made the jnp.arange function work in its single-argument form:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)

jaxpr = jax.make_jaxpr(jnp.arange)(5)
print(jaxpr)
```

which prints:

```
{ lambda ; a:i32[]. let
    b:i32[a] = iota[dimension=0 dtype=int32 shape=(None,)] a
  in (b,) }
```

In the code you linked, you can see that we only check that start is concrete if dynamic_shapes is False. But we always check that stop and step are None or concrete.
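For contrast, here is a minimal sketch of the default behavior, assuming jax_dynamic_shapes is left at False. In that case the start check applies as well, so even the single-argument form fails to trace with an abstract argument, while marking the argument static turns it back into a concrete Python int. trace_arange is a hypothetical helper for illustration only.

```python
import jax
import jax.numpy as jnp

# Assumes the default configuration: jax_dynamic_shapes is NOT enabled.

def trace_arange(*args):
    """Hypothetical helper: trace jnp.arange with abstract args, report the outcome."""
    try:
        jax.make_jaxpr(jnp.arange)(*args)
        return "traced"
    except jax.errors.ConcretizationTypeError:
        return "ConcretizationTypeError"

# With static shapes the output length must be known at trace time,
# so both the one-argument and two-argument forms fail here.
print(trace_arange(5))     # ConcretizationTypeError
print(trace_arange(1, 5))  # ConcretizationTypeError

# Treating the argument as static makes it concrete again, so tracing succeeds.
closed = jax.make_jaxpr(jnp.arange, static_argnums=0)(5)
print(closed.out_avals[0].shape)  # (5,)
```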

What signature of jnp.arange did you need?

mattjj, Sep 10 '24 13:09

Thanks @mattjj, that is really helpful! I don't think we realized that jnp.arange(start) works with dynamic shapes.

For the moment, jnp.arange(start) should actually unblock us (since we can probably always transform it to include a step/offset).
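A minimal sketch of that transformation, assuming that only the single-argument call needs to produce a dynamic length: compute the element count as ceil((stop - start) / step), clamp it at zero, call jnp.arange on the count, then scale by step and shift by start. arange_via_iota is a hypothetical helper, not a JAX API. It is only checked here with concrete values against NumPy; whether the elementwise ops after the jnp.arange call also trace under jax_dynamic_shapes is untested.

```python
import jax.numpy as jnp
import numpy as np

def arange_via_iota(start, stop, step=1):
    """Hypothetical helper emulating jnp.arange(start, stop, step).

    Only jnp.arange(n), the single-argument form, determines the output length.
    """
    start = jnp.asarray(start)
    stop = jnp.asarray(stop)
    step = jnp.asarray(step)
    # NumPy's length rule: ceil((stop - start) / step), clamped at zero
    # so that empty ranges such as arange(5, 0) give length 0.
    n = jnp.maximum(jnp.ceil((stop - start) / step), 0).astype(jnp.int32)
    # Offset and scale an iota of length n.
    return start + jnp.arange(n) * step

print(arange_via_iota(2, 10, 3))  # [2 5 8]
```

Usage: arange_via_iota(5, 0, -1) gives the same values as np.arange(5, 0, -1), i.e. [5 4 3 2 1].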

But it would be great if jnp.arange also supported dynamic stop and step arguments, as we often work with researchers who expect NumPy semantics to 'just work' even with compilation happening under the hood.
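The same idea could in principle cover jnp.linspace from the original report. A hedged sketch, assuming endpoint=True and float32 output: the spacing is (stop - start) / (num - 1), with the denominator clamped to 1 so that num == 1 yields just [start]. linspace_via_iota is hypothetical, and, as above, it is only checked with concrete values against NumPy, not under jax_dynamic_shapes.

```python
import jax.numpy as jnp
import numpy as np

def linspace_via_iota(start, stop, num):
    """Hypothetical helper emulating jnp.linspace(start, stop, num), endpoint=True.

    Only jnp.arange(num), the single-argument form, determines the output length.
    """
    start = jnp.asarray(start, dtype=jnp.float32)
    stop = jnp.asarray(stop, dtype=jnp.float32)
    num = jnp.asarray(num, dtype=jnp.int32)
    # Clamp the denominator so num == 1 (or 0) does not divide by zero.
    step = (stop - start) / jnp.maximum(num - 1, 1)
    return start + jnp.arange(num) * step

print(linspace_via_iota(0.0, 1.0, 5))  # [0.   0.25 0.5  0.75 1.  ]
```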

josh146, Oct 07 '24 18:10