jax
segfault when using ensure_compile_time_eval
Description
Hi! I have a model that requires precomputing a large matrix that is a model constant. To avoid computing it on every function call, I thought I'd have it evaluated at compile time. However, this results in a segmentation fault. The same does not happen if I don't ask for compile-time evaluation. I'm not sure whether this is a bug or some deeper issue I don't fully understand. This is related to a question I posted in the Q&A discussions: https://github.com/google/jax/discussions/18830.
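For comparison, a common way to avoid recomputing a constant on every call, without compile-time evaluation, is to build it once, eagerly, outside the jitted function and pass it in as an ordinary argument. The sketch below uses hypothetical toy sizes and names (`NUM_SEGMENTS`, `PER_SEGMENT`, `NUM_VALS`, `apply_model` are not from the original model). It illustrates the pattern only and is not claimed to resolve the crash.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes; the issue uses 11544 segments with 30 entries each.
NUM_SEGMENTS = 4
PER_SEGMENT = 6
NUM_VALS = 3

segment_id = jnp.repeat(jnp.arange(NUM_SEGMENTS), PER_SEGMENT)
val_id = jnp.tile(jnp.arange(NUM_VALS), NUM_SEGMENTS * PER_SEGMENT // NUM_VALS)

# Build the constant once, eagerly, outside of jit.
agg_mat = (
    jnp.arange(NUM_SEGMENTS).reshape(-1, 1) == segment_id.reshape(1, -1)
).astype(jnp.float32)


@jax.jit
def apply_model(agg_mat, val_id, x):
    # agg_mat is a runtime input here, so it is neither rebuilt on each call
    # nor baked into the compiled program as a constant.
    return agg_mat @ x[val_id]


x = jnp.ones(NUM_VALS)
result = apply_model(agg_mat, val_id, x)
print(result)
```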
Here's a minimal (admittedly, somewhat absurdly minimal) reproducible example:
```python
import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}

@jax.jit
def segfault_version(x):
    with jax.ensure_compile_time_eval():
        segment_id = model_data['segment_id']
        val_id = model_data['val_id']
        num_segments = model_data['num_segments']
        # Dense boolean matrix of shape (11544, 346320), i.e. about 4e9 entries.
        agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(5)
print(segfault_version(x))
```
The output is simply a segmentation fault. I'm on a Mac Studio M2 (not using jax-metal, because too many things are broken there).
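One possible factor (an assumption, not confirmed in this report) is the sheer size of the constant: `agg_mat` is 11544 x 346320, about 4e9 entries. Since `(agg_mat @ y)[s]` just sums `y[j]` over every `j` with `segment_id[j] == s`, the same result can be computed without materializing the dense matrix by using `jax.ops.segment_sum`. This is a swapped-in technique, not the reporter's approach. It is shown at hypothetical toy sizes (`NUM_SEGMENTS`, `PER_SEGMENT`, `NUM_VALS`) and checked against the dense version.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes; the issue uses NUM_SEGMENTS = 11544.
NUM_SEGMENTS = 5
PER_SEGMENT = 30
NUM_VALS = 10

segment_id = jnp.repeat(jnp.arange(NUM_SEGMENTS), PER_SEGMENT)
val_id = jnp.tile(jnp.arange(NUM_VALS), 3 * NUM_SEGMENTS)


@jax.jit
def dense_version(x):
    # Same computation as the issue: a dense (NUM_SEGMENTS, len(segment_id)) mask.
    agg_mat = jnp.arange(NUM_SEGMENTS).reshape(-1, 1) == segment_id.reshape(1, -1)
    return agg_mat.astype(x.dtype) @ x[val_id]


@jax.jit
def segment_sum_version(x):
    # Sums x[val_id][j] into bucket segment_id[j]; memory is linear in
    # len(segment_id) instead of NUM_SEGMENTS * len(segment_id).
    return jax.ops.segment_sum(x[val_id], segment_id, num_segments=NUM_SEGMENTS)


x = jnp.arange(NUM_VALS, dtype=jnp.float32)
dense = dense_version(x)
sparse = segment_sum_version(x)
print(dense, sparse)
```

Each segment here covers three full cycles of `val_id` (0 through 9), so both versions should give 3 * 45 = 135 per segment.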
What jax/jaxlib version are you using?
jax 0.4.20, jaxlib 0.4.20
Which accelerator(s) are you using?
CPU
Additional system info?
1.26.2 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ] uname_result(system='Darwin', node='Bvatter-Studio-23', release='23.1.0', version='Darwin Kernel Version 23.1.0: Mon Oct 9 21:28:45 PDT 2023; root:xnu-10002.41.9~6/RELEASE_ARM64_T6020', machine='arm64')
NVIDIA GPU info
No response