`ensure_compile_time_eval` does not error out for traced arrays
Description
The documentation for `jax.ensure_compile_time_eval()` states:

> This context manager ensures that JAX computations are evaluated eagerly. If eager evaluation is not possible, a `ConcretizationTypeError` is raised.
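For contrast with the report below, here is a minimal sketch of the case the quoted documentation does cover. Inside `jax.jit`, values computed only from constants within the context manager come out concrete, so they can drive ordinary Python control flow. The function name `pick_branch` is illustrative, and the sketch assumes a JAX version where this eager path behaves as documented.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pick_branch(x):
    # Everything in this block depends only on Python constants, so JAX
    # evaluates it eagerly at trace time instead of staging it out.
    with jax.ensure_compile_time_eval():
        z = jnp.sin(3.0)
        z_positive = z > 0  # a concrete boolean Array, not a tracer
    # Python control flow on z_positive works because it is concrete.
    if z_positive:
        return jnp.sin(x)
    return jnp.cos(x)

out = pick_branch(jnp.arange(3.0))
```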
The following snippet uses a traced array inside the `ensure_compile_time_eval` context and runs without raising. I expected it to raise a `ConcretizationTypeError`.
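```python
import jax
import jax.numpy as jnp

@jax.jit
def f(W):
    with jax.ensure_compile_time_eval():
        # W is a traced jit argument, so it cannot be evaluated eagerly here;
        # per the docs this should raise a ConcretizationTypeError, but it does not.
        W = jnp.square(W.T)
    return W

# `f` is already decorated with @jax.jit, so it can be called directly.
_ = f(jnp.ones((1024, 1024)))
```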
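To see what actually happens in the snippet above, one can record at trace time whether the result of the block is a tracer. This is a minimal sketch, assuming `jax.core.Tracer` is importable in the installed JAX version; the `observed` dict and the smaller input are illustrative additions, not part of the original report.

```python
import jax
import jax.numpy as jnp

observed = {}

@jax.jit
def f(W):
    with jax.ensure_compile_time_eval():
        # W is a jit argument, so it is a tracer and cannot be evaluated eagerly.
        W = jnp.square(W.T)
        # Record (once, at trace time) whether the result was staged out.
        observed["W_is_tracer"] = isinstance(W, jax.core.Tracer)
    return W

out = f(jnp.arange(6.0).reshape(2, 3))
```

On the versions discussed here, this runs without error and `observed["W_is_tracer"]` is `True`: the computation is silently staged out rather than evaluated eagerly.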
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.3
python: 3.11.7 (stable, redacted, redacted) [Clang
I think the best thing to do might be to change the documentation at this point. I definitely have code that relies on the current behaviour: do compile-time eval if possible, but defer to runtime if required.
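The behaviour relied on here, evaluating eagerly when the inputs are concrete and staging out otherwise, can be sketched as follows. `normalize`, `f_const`, `f_traced`, and `observed` are hypothetical names, and the sketch assumes the current non-erroring behaviour of `ensure_compile_time_eval` reported in this issue, plus the availability of `jax.core.Tracer`.

```python
import jax
import jax.numpy as jnp
import numpy as np

observed = {}

def normalize(x, scale):
    # Hypothetical helper: compute 1 / sqrt(scale) at trace time when `scale`
    # is concrete; otherwise JAX stages it out instead of raising.
    with jax.ensure_compile_time_eval():
        factor = 1.0 / jnp.sqrt(scale)
    return x * factor, factor

@jax.jit
def f_const(x):
    # NumPy constant: the factor is evaluated eagerly and is concrete.
    y, factor = normalize(x, np.float32(4.0))
    observed["const"] = isinstance(factor, jax.core.Tracer)
    return y

@jax.jit
def f_traced(x, scale):
    # Traced argument: the factor is deferred to runtime.
    y, factor = normalize(x, scale)
    observed["traced"] = isinstance(factor, jax.core.Tracer)
    return y

y_const = f_const(jnp.ones(3))
y_traced = f_traced(jnp.ones(3), 4.0)
```

Both calls compute the same values; only where the factor is computed differs.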
If an expression can be computed at compile time, it would definitely be computed at compile time, right? I am assuming that computations on concrete arrays always execute eagerly, and hence would behave like "do it at compile time if possible, else defer" without needing to be enclosed in an `ensure_compile_time_eval` block.
Nope, actually! When you are inside the dynamic context of a `jax.jit`, all JAX operations are deferred to whatever the XLA compiler decides to do. That might indeed be to perform constant propagation and evaluate them at compile time, but in practice there are several edge cases where this doesn't happen. This context manager serves as a useful escape hatch in those cases.
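A minimal sketch of this point: under `jax.jit`, even an expression built only from constants is staged out unless it is wrapped in `ensure_compile_time_eval`, which makes it concrete enough to convert to a static Python `int` (used here as an array shape). The function `f` and the `seen` dict are illustrative, and `jax.core.Tracer` is assumed to be available in the installed JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = {}

@jax.jit
def f(x):
    # Outside the context manager, jit stages out every jnp operation, even
    # one whose inputs are all constants, so this result is a tracer.
    plain = jnp.sum(np.array([1, 2, 3]))
    seen["plain_is_tracer"] = isinstance(plain, jax.core.Tracer)
    with jax.ensure_compile_time_eval():
        # The same expression, now evaluated eagerly to a concrete Array.
        n = jnp.sum(np.array([1, 2, 3]))
    seen["eager_is_tracer"] = isinstance(n, jax.core.Tracer)
    # A concrete value can become a static Python int, e.g. for a shape.
    return jnp.zeros(int(n)) + x

out = f(2.0)
```

Calling `int(plain)` instead would fail under `jit`, since `plain` is a tracer.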