`ensure_compile_time_eval` does not error out for traced arrays

Open YashasSamaga opened this issue 4 months ago • 3 comments

Description

The documentation for jax.ensure_compile_time_eval() states that:

This context manager ensures that JAX computations are evaluated eagerly. If eager evaluation is not possible, a ConcretizationTypeError is raised.

The following snippet uses a traced array inside the ensure_compile_time_eval context and runs without raising. The expected behaviour was a ConcretizationTypeError.

import jax
import jax.numpy as jnp

@jax.jit
def f(W):
  with jax.ensure_compile_time_eval():
    W = jnp.square(W.T)
  return W

_ = jax.jit(f)(jnp.ones((1024, 1024)))

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.3
python: 3.11.7 (stable, redacted, redacted) [Clang ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', , machine='x86_64')

YashasSamaga, Feb 28 '24 03:02

I think the best thing to do might be to change the documentation at this point. I definitely have code that relies on the current behaviour: do compile-time eval if possible, but defer to runtime if required.

patrick-kidger, Mar 02 '24 12:03

If an expression can be computed at compile time, it will definitely be computed at compile time, right? I am assuming that computations on concrete arrays always execute eagerly, and hence already behave like 'do at compile time if possible, else defer' without having to be enclosed in an ensure_compile_time_eval block.

YashasSamaga, Mar 02 '24 14:03

Nope, actually! When you are inside the dynamic context of a jax.jit, all JAX operations are staged out and left to whatever the XLA compiler feels like doing -- which might indeed be to perform constant propagation and evaluate those things at compile time, but in practice there are several edge cases where this isn't the case. This context serves as a useful escape hatch when that happens.

patrick-kidger, Mar 02 '24 14:03
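To make the last point concrete, here is a minimal sketch (not taken from the thread) of how a computation on a plain Python constant is treated under jax.jit, with and without the context manager. The function name `f` and the `seen` dictionary are hypothetical, introduced only for illustration; the check relies on `jax.core.Tracer` being the base class of traced values.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper dict: records, at trace time, whether each
# intermediate value is a tracer or a concrete array.
seen = {}

@jax.jit
def f(x):
    # No dependence on x, yet under jit this is staged out: `a` is a tracer,
    # and any constant folding is left to the XLA compiler.
    a = jnp.sin(3.0)
    seen["plain"] = isinstance(a, jax.core.Tracer)

    # Inside the context the same expression is evaluated eagerly while
    # tracing, so `b` is a concrete array.
    with jax.ensure_compile_time_eval():
        b = jnp.sin(3.0)
    seen["eager"] = isinstance(b, jax.core.Tracer)

    return x + a + b

f(jnp.ones(3))
print(seen)
```

Because `b` is concrete, it could be used in Python control flow (for example `if b > 0:`) while tracing, which is not possible with `a`. This sketch deliberately avoids using the traced argument `x` inside the context, since that is exactly the case whose behaviour this issue is about.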