
segfault when using ensure_compile_time_eval

Open benjaminvatterj opened this issue 1 year ago • 2 comments

Description

Hi! I have a model that requires precomputing a large matrix that is a model constant. To avoid computing it at every function call, I thought I'd have it evaluated at compile time. However, this results in a segmentation fault, which does not happen if I don't ask for compile-time evaluation. I'm not sure whether this is a bug or some deeper issue I don't fully understand. This is related to a question I posted in the Q&A: https://github.com/google/jax/discussions/18830.
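For context, here is a minimal sketch of the pattern described above, using hypothetical small sizes and illustrative names (`build_agg_mat`, `with_compile_time_eval`, `with_argument` are not from the report). Inside `jax.jit`, `jax.ensure_compile_time_eval()` evaluates the matrix eagerly on concrete values, so it ends up baked into the compiled program as a constant. An alternative that also builds the matrix only once is to compute it outside `jit` and pass it in as an argument.

```python
import jax
import jax.numpy as jnp

# Hypothetical small sizes, chosen only so the sketch runs quickly.
NUM_SEGMENTS = 4
SEGMENT_ID = jnp.repeat(jnp.arange(NUM_SEGMENTS), 3)  # shape (12,)


def build_agg_mat(segment_id, num_segments):
    # Boolean aggregation matrix: row s is True where segment_id == s.
    return jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)


@jax.jit
def with_compile_time_eval(v):
    # Under ensure_compile_time_eval these ops run eagerly on concrete values,
    # so agg_mat becomes a constant embedded in the compiled program.
    with jax.ensure_compile_time_eval():
        agg_mat = build_agg_mat(SEGMENT_ID, NUM_SEGMENTS)
    return agg_mat @ v


# Alternative: build the matrix once, eagerly, outside jit ...
AGG_MAT = build_agg_mat(SEGMENT_ID, NUM_SEGMENTS)


@jax.jit
def with_argument(agg_mat, v):
    # ... and pass it in, so it is an input buffer rather than a constant.
    return agg_mat @ v


v = jnp.arange(12.0)
print(with_compile_time_eval(v))
print(with_argument(AGG_MAT, v))
```

Both functions return the per-segment sums of `v`; the difference is only where the matrix lives (an embedded constant versus an argument passed at call time).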

Here's a minimal (admittedly, absurdly minimal) reproducible example:

import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}


@jax.jit
def segfault_version(x):
    with jax.ensure_compile_time_eval():
        segment_id = model_data['segment_id']
        val_id = model_data['val_id']
        num_segments = model_data['num_segments']
        agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(5)
print(segfault_version(x))
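With the sizes in this repro, `agg_mat` has 11544 rows and 11544 × 30 = 346,320 columns, i.e. 3,997,918,080 boolean entries (roughly 4 GB at one byte each), and `ensure_compile_time_eval` turns that into a constant embedded in the compiled program. It is not confirmed here that this size is what triggers the crash. Since `agg_mat @ x[val_id]` only computes per-segment sums, the sketch below (scaled-down sizes, hypothetical function names) checks that `jax.ops.segment_sum` gives the same result without materializing the dense matrix.

```python
import jax
import jax.numpy as jnp

# Hypothetical scaled-down sizes; the report uses 11544 segments of 30 entries.
num_segments = 6
segment_id = jnp.repeat(jnp.arange(num_segments), 30)  # shape (180,)
val_id = jnp.tile(jnp.arange(10), 3 * num_segments)    # shape (180,)


@jax.jit
def dense_version(x):
    # Dense boolean aggregation matrix, as in the report (harmless at this size).
    agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)
    return agg_mat @ x[val_id]


@jax.jit
def segment_sum_version(x):
    # Same per-segment sums without building the
    # (num_segments, len(segment_id)) matrix at all.
    return jax.ops.segment_sum(x[val_id], segment_id, num_segments=num_segments)


x = jnp.arange(10.0)
print(dense_version(x))
print(segment_sum_version(x))
```

Here each segment covers three full copies of `0..9`, so both versions should print six values of 135.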

The output is simply a segmentation fault. I'm on a Mac Studio M2 (not using jax-metal, because too many things are broken there).

What jax/jaxlib version are you using?

jax 0.4.20, jaxlib 0.4.20

Which accelerator(s) are you using?

CPU

Additional system info?

numpy 1.26.2
Python 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]
uname_result(system='Darwin', node='Bvatter-Studio-23', release='23.1.0', version='Darwin Kernel Version 23.1.0: Mon Oct 9 21:28:45 PDT 2023; root:xnu-10002.41.9~6/RELEASE_ARM64_T6020', machine='arm64')

NVIDIA GPU info

No response

benjaminvatterj, Dec 05 '23 21:12