Carlos Martin
@fabianp @vroulet I've suggested some related changes at #1110.
The functions in [`optax.projections`](https://optax.readthedocs.io/en/latest/api/projections.html#available-projections) operate on entire general pytrees, not arrays. How would axes be defined in that case? Perhaps we need to create versions of these functions that operate...
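One possible reading of "versions that operate on arrays" is a projection applied independently to every 1-D slice along a chosen axis. The sketch below illustrates that reading using the probability simplex as the example set. `project_simplex_1d` and `project_simplex_along_axis` are hypothetical names, not part of the optax API. The axis handling is an assumption: move `axis` to the end, flatten the rest, and `jax.vmap` the 1-D projection.

```python
import jax
import jax.numpy as jnp


def project_simplex_1d(v):
    """Euclidean projection of a 1-D array onto the probability simplex (sort-based)."""
    u = jnp.sort(v)[::-1]                   # entries in decreasing order
    css = jnp.cumsum(u)
    ind = jnp.arange(1, v.shape[0] + 1)
    cond = u - (css - 1) / ind > 0          # always true at ind == 1
    rho = jnp.max(jnp.where(cond, ind, 0))  # largest 1-based index where cond holds
    theta = (css[rho - 1] - 1) / rho        # shift that makes the result sum to 1
    return jnp.maximum(v - theta, 0)


def project_simplex_along_axis(x, axis=-1):
    """Hypothetical axis-wise variant: project every 1-D slice of x along `axis`."""
    moved = jnp.moveaxis(x, axis, -1)             # put the projection axis last
    flat = moved.reshape(-1, moved.shape[-1])     # one row per slice
    out = jax.vmap(project_simplex_1d)(flat)      # project each row independently
    return jnp.moveaxis(out.reshape(moved.shape), -1, axis)
```

With this design the caller picks the axis explicitly, much like a NumPy reduction, while the pytree-level functions keep treating their whole input as a single point.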
@fabianp Done.
@fabianp 
@Transurgeon Converting the JAX array to a NumPy array also works, and is what I'm doing right now. I was wondering if it might be more efficient to avoid that...
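The comment doesn't say where the conversion happens. The sketch below assumes it is done on a concrete array outside `jax.jit`, where `np.asarray` works. Inside a jitted function the argument is a tracer, and the same call raises `jax.errors.TracerArrayConversionError`, which is the situation a callback mechanism is meant to handle. `to_numpy_inside_jit` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3.0)
x_np = np.asarray(x)  # outside jit: x is concrete, so this yields a NumPy array


@jax.jit
def to_numpy_inside_jit(y):
    # Under jit, y is an abstract tracer with no concrete values to hand to NumPy.
    return np.asarray(y)


try:
    to_numpy_inside_jit(x)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
```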
@Transurgeon [jax.pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) allows a JIT-compiled JAX program to call a pure Python function. Here's a minimal example:
```python3
import cvxpy as cp
import jax
import numpy as np
from jax...
```
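The original example is truncated. The sketch below shows the same `jax.pure_callback` pattern with a plain NumPy routine standing in for the cvxpy solve, so it runs without cvxpy. The host function here is a sort-based projection onto the probability simplex, chosen only for illustration. `project_simplex_np` and `project_simplex` are hypothetical names. The key requirement is declaring the callback's output shape and dtype with `jax.ShapeDtypeStruct` so JAX can trace around the opaque Python call.

```python
import jax
import jax.numpy as jnp
import numpy as np


def project_simplex_np(v):
    """Host-side NumPy projection of a 1-D vector onto the probability simplex."""
    v = np.asarray(v)
    out_dtype = v.dtype                        # must match the declared output dtype
    w = v.astype(np.float64)
    u = np.sort(w)[::-1]                       # entries in decreasing order
    css = np.cumsum(u)
    ind = np.arange(1, u.size + 1)
    rho = ind[u - (css - 1.0) / ind > 0][-1]   # largest index satisfying the condition
    theta = (css[rho - 1] - 1.0) / rho         # shift that makes the result sum to 1
    return np.maximum(w - theta, 0.0).astype(out_dtype)


@jax.jit
def project_simplex(x):
    # Tell JAX what the callback will return; the callback itself runs on the host.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(project_simplex_np, out_spec, x)
```

Because the callback is opaque to JAX, it is not differentiable as written; gradients would need a `jax.custom_vjp` or similar around it.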
@fabianp Done.
@fabianp 
@jakevdp @mattjj Looks right:
```python
from jax.numpy import array
from jax.lax import scan

def iterate_1(n, f, x):
    for _ in range(n):
        x = f(x)
    return x

def iterate_2(n, f, x):...
```
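The body of `iterate_2` is cut off, so the sketch below does not reproduce it. It shows two common ways to express "apply `f` to `x` `n` times" with lax primitives and checks them against the Python loop. `iterate_scan` and `iterate_fori` are hypothetical names. Both assume `n` is a concrete Python integer and that `f` preserves the shape and dtype of `x`, as lax loops require.

```python
import jax.numpy as jnp
from jax import lax


def iterate_loop(n, f, x):
    # Plain Python loop: unrolled into n copies of f when traced under jit.
    for _ in range(n):
        x = f(x)
    return x


def iterate_scan(n, f, x):
    # scan threads x through n steps; emitting None means nothing is stacked.
    x, _ = lax.scan(lambda carry, _: (f(carry), None), x, None, length=n)
    return x


def iterate_fori(n, f, x):
    # fori_loop passes (index, carry) to the body; the index is unused here.
    return lax.fori_loop(0, n, lambda _, carry: f(carry), x)
```

Unlike the Python loop, the `scan` and `fori_loop` versions compile a single loop body regardless of `n`.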
Out of curiosity, can this function be implemented in terms of lax primitives?
```python
from jax.numpy import array

# orbit_while : ∀a. (a → 𝔹) → (a → a) →...
```
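The type signature is truncated, so the semantics here are an assumption: collect the orbit `x, f(x), f(f(x)), ...` for as long as the predicate holds. A list of data-dependent length can't be returned from traced JAX code, because array shapes must be static. The sketch therefore assumes a caller-supplied upper bound `max_len` and writes into a fixed-size buffer with `lax.while_loop`, returning the buffer and the number of valid entries. `orbit_while_bounded` and `max_len` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def orbit_while_bounded(p, f, x, max_len):
    """Collect x, f(x), f(f(x)), ... while p holds, up to max_len elements."""
    x = jnp.asarray(x)
    buf = jnp.zeros((max_len,) + x.shape, dtype=x.dtype)  # static-size output buffer

    def cond(state):
        i, x, _ = state
        return (i < max_len) & p(x)          # stop when full or the predicate fails

    def body(state):
        i, x, buf = state
        buf = buf.at[i].set(x)               # record the current orbit element
        return i + 1, f(x), buf

    count, _, buf = lax.while_loop(cond, body, (jnp.int32(0), x, buf))
    return buf, count                        # only buf[:count] is meaningful
```

Under `jax.jit`, `p`, `f`, and `max_len` would need to be static arguments, since `max_len` determines the buffer shape.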