Jake Vanderplas
Interesting idea - if I understand your suggestion correctly, I think the implementations would look like this: ```python def iterate(n, f, x): return lax.scan(lambda x, _: (f(x), x), x, None,...
A further question for discussion: are these new APIs necessary if they can be implemented via a single call to `scan`? One might argue that `iterate` and `orbit` already exist in...
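For concreteness, here is a minimal sketch of what "a single call to `scan`" could look like for both functions. It is an illustration only, not a proposed implementation: the names `iterate` and `orbit` and the `(n, f, x)` argument order are taken from this thread, and the choice that `orbit` returns the first `n` points starting at `x` is my assumption.

```python
import jax.numpy as jnp
from jax import lax


def iterate(n, f, x):
    """Apply f to x n times and return only the final value (sketch)."""
    # The carry holds the current value; nothing is stacked per step.
    final, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=n)
    return final


def orbit(n, f, x):
    """Return the stacked values x, f(x), ..., f^(n-1)(x) (sketch)."""
    # Emit the carry *before* applying f, so the orbit starts at x itself.
    _, xs = lax.scan(lambda c, _: (f(c), c), x, None, length=n)
    return xs
```

Both rely on `n` being a static Python integer, because `scan` needs its `length` at trace time.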
Sure, but adding new APIs does not come without maintenance costs. It's true that `while_loop` and `fori_loop` can be implemented in terms of `scan`, but their implementations are far more...
`while_loop` doesn't lower to `scan`, actually, since `scan` requires a static number of iterations.
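To illustrate the point about static iteration counts, here is a small sketch of a loop whose trip count depends on runtime values. The function name `halve_until_below` is made up for this example. Under `jit` the number of iterations is unknown at trace time, which `while_loop` accepts and `scan` cannot express without a fixed `length`.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def halve_until_below(x, threshold):
    # The number of iterations depends on the runtime values of x and
    # threshold, so it cannot be supplied to scan as a static length.
    return lax.while_loop(lambda v: v >= threshold, lambda v: v / 2, x)


result = halve_until_below(jnp.float32(16.0), jnp.float32(3.0))
```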
And `orbit_while` is not currently possible to express in JIT-compatible JAX, because it returns an array of dynamic length.
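One possible workaround, sketched here under assumptions that were not proposed in this thread, is to pad to a static maximum length and return a validity mask alongside the values. The name `orbit_while_padded`, the `max_len` parameter, and the zero padding are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def orbit_while_padded(cond, f, x, max_len):
    """Fixed-length stand-in for a dynamic-length orbit (hypothetical sketch).

    Returns (values, mask): values has static length max_len, and mask marks
    the leading entries for which cond held. Entries past that are zero.
    """
    def step(carry, _):
        val, active = carry
        # Once cond fails, stay inactive for all remaining steps.
        active = active & cond(val)
        out = jnp.where(active, val, jnp.zeros_like(val))
        next_val = jnp.where(active, f(val), val)
        return (next_val, active), (out, active)

    init = (x, jnp.array(True))
    _, (values, mask) = lax.scan(step, init, None, length=max_len)
    return values, mask
```

The output shape is static, so this traces under `jit`, at the cost of always running `max_len` steps and of silently truncating orbits longer than `max_len`.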
I think that could be an improvement – I'd want to hear opinions from other folks on the team
JAX is only compliant with the Array API dtype semantics when `jax_enable_x64` is set to `True`. Any testing would have to take that into account.
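As a quick illustration of why the flag matters, here is a minimal sketch that enables it programmatically and inspects the resulting default dtypes. It assumes the flag is set at startup, before any arrays are created.

```python
import jax

# Enable 64-bit types before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With the flag on, Python scalars and integer ranges default to 64-bit
# dtypes; with it off, JAX would produce float32 and int32 here.
default_float = jnp.asarray(1.0).dtype
default_int = jnp.arange(3).dtype
```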
> Thanks Jake! So for completeness, the stanza to locally run a test from the test suite is > > ``` > $ JAX_ENABLE_FLOAT64=True ARRAY_API_TESTS_VERSION="2024.12" ARRAY_API_TESTS_MODULE=jax.numpy pytest path/to/test > ```...
xref #15358 for the unorderable keys issue.
Thanks for the contribution! You'll need to sign the CLA before we can take a look in detail, but a couple broad points: - this PR combines logic changes with...