`jit`ing large models for inference has bad compilation performance
```python
from typing import Any, Generic, TypeVar
import equinox as eqx
import jax
import numpy as np
from jax.numpy import bfloat16
from numpy import ndarray

Float = TypeVar("Float")

class Test(eqx.Module, Generic[Float]):
    test: eqx.nn.Linear

    def __init__(self, *, key: jax.Array, dtype: type[Float], in_features: int, out_features: int):
        self.test = eqx.nn.Linear(
            in_features=in_features, out_features=out_features, use_bias=False, key=key, dtype=dtype
        )

    def __call__(self, x: ndarray[Any, Float]) -> ndarray[Any, Float]:
        return self.test(x)

for d in [1_000, 2_000, 4_000, 8_000, 16_000, 32_000, 64_000]:
    t = Test(key=jax.random.PRNGKey(0), dtype=bfloat16, in_features=d, out_features=d)
    print(d)
    with jax.log_compiles():
        eqx.filter_jit(t.__call__)(np.ones(d))
```
```
Finished tracing + transforming matmul for pjit in 0.0008780956268310547 sec
Finished tracing + transforming __call__ for pjit in 0.0028340816497802734 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[1000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.05241560935974121 sec
1000
Finished XLA compilation of jit(__call__) in 1.2187151908874512 sec
Finished tracing + transforming matmul for pjit in 0.0005924701690673828 sec
Finished tracing + transforming __call__ for pjit in 0.002209901809692383 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[2000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.01014566421508789 sec
2000
Finished XLA compilation of jit(__call__) in 1.414226770401001 sec
Finished tracing + transforming matmul for pjit in 0.0005934238433837891 sec
Finished tracing + transforming __call__ for pjit in 0.002273082733154297 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[4000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.03254580497741699 sec
4000
Finished XLA compilation of jit(__call__) in 1.6822988986968994 sec
Finished tracing + transforming matmul for pjit in 0.0006303787231445312 sec
Finished tracing + transforming __call__ for pjit in 0.002046346664428711 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[8000])]. Argument mapping: [UnspecifiedValue].
Finished jaxpr to MLIR module conversion jit(__call__) in 0.11545777320861816 sec
8000
Finished XLA compilation of jit(__call__) in 3.7346856594085693 sec
Finished tracing + transforming matmul for pjit in 0.0006031990051269531 sec
Finished tracing + transforming __call__ for pjit in 0.0021157264709472656 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[16000])]. Argument mapping: [UnspecifiedValue].
16000
Finished jaxpr to MLIR module conversion jit(__call__) in 0.615095853805542 sec
Finished XLA compilation of jit(__call__) in 11.609867572784424 sec
Finished tracing + transforming matmul for pjit in 0.0005524158477783203 sec
Finished tracing + transforming __call__ for pjit in 0.0017201900482177734 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[32000])]. Argument mapping: [UnspecifiedValue].
32000
Finished jaxpr to MLIR module conversion jit(__call__) in 2.4231040477752686 sec
Finished XLA compilation of jit(__call__) in 40.67563080787659 sec
Finished tracing + transforming matmul for pjit in 0.0005431175231933594 sec
Finished tracing + transforming __call__ for pjit in 0.0017380714416503906 sec
Compiling __call__ with global shapes and types [ShapedArray(float32[64000])]. Argument mapping: [UnspecifiedValue].
64000
<crash here>
```
As you can see from the output, the jaxpr to MLIR and XLA compilation steps take longer and longer as the array dimension increases until it finally crashes during compilation. I believe this is because we're effectively closing over larger and larger values and JAX is doing work that scales with the size of the closed-over values (https://github.com/google/jax/issues/16278 may be related).
Flax avoids this issue because it passes the parameters/weights directly as arguments to the function. That seems like the best approach at the moment. Is there a reasonable way to achieve behavior like that in Equinox?
(Unless I'm missing something, this is a pretty significant limitation for e.g. doing inference with language models where you'd want to JIT the sampling for a fixed model.)
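For context, the closure-vs-argument distinction can be reproduced with plain JAX. This is a minimal sketch with a small toy size (`apply_closed`/`apply_arg` are illustrative names, not from the report above):

```python
import jax
import jax.numpy as jnp

w = jnp.ones((128, 128), dtype=jnp.float32)

@jax.jit
def apply_closed(x):
    # `w` is closed over, so it is embedded in the traced computation as a
    # constant; tracing/lowering work scales with its size.
    return w @ x

@jax.jit
def apply_arg(w, x):
    # `w` is a traced argument, so only its shape/dtype enter the jaxpr.
    return w @ x

x = jnp.ones(128)
y_closed = apply_closed(x)
y_arg = apply_arg(w, x)
```

Both variants compute the same result; only how the weights enter the compiled computation differs.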
```
jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.26.1
python: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nld09d8wce', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')
```
```
$ nvidia-smi
Tue Jun 25 00:52:43 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04    Driver Version: 525.116.04    CUDA Version: 12.4   |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   25C    P2    16W / 300W |    265MiB / 49140MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
```
Equinox version 0.11.4
Ah, I realized this is a workable solution:
```python
def call(x: ndarray[Any, Float], dynamic, static) -> ndarray[Any, Float]:
    test = eqx.combine(dynamic, static)
    return test.__call__(x)

dynamic, static = eqx.partition(t, eqx.is_array)
jax.jit(ft.partial(call, static=static))(jnp.ones(d), dynamic)
```
I think this is happening because you're grabbing `__call__`, which, as a magic method, isn't subject to the same bound-methods-are-PyTrees treatment as regular methods. This is why `t` is being closed over rather than provided as an input.

Can you try doing just `eqx.filter_jit(t)(np.ones(d))` instead?
Ahh, that does make a big difference. I had gotten into the habit of writing `__call__` explicitly so that jump-to-definition in my editor would be more useful, and hadn't thought of it as anything more than a trivial syntactic transformation.