Extremely slow execution of CPU JAX on cluster vs. laptop (reraise_with_filtered_traceback?)
Description
The same code runs extremely slowly on my cluster compared with my laptop. In both cases I'm using JAX on CPU only.
Perfetto UI screenshots: [laptop trace] [cluster trace]
One major difference I spot between the two traces is that the reraise_with_filtered_traceback and scan bars appear prominently and repeatedly on the cluster, but not on the laptop.
Any idea what could be causing this slowdown, and how to fix it?
This issue might be related: #13407.
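One way to narrow this down is to check whether the cluster is spending its time re-tracing and recompiling, or in steady-state execution. The following is a minimal sketch, not the code from this report: `cumulative_sum`, `trace_count`, and the input size are hypothetical stand-ins for the real workload. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so a counter that keeps growing across calls with identical input shapes would suggest repeated tracing. Whether that is what happens on the cluster is an assumption to test, not a conclusion.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0  # hypothetical counter; grows only when JAX traces the function


def cumulative_sum(xs):
    """Stand-in workload built on lax.scan (not the code from the report)."""
    global trace_count
    trace_count += 1  # Python side effect: executed at trace time only

    def step(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


jitted = jax.jit(cumulative_sum)
xs = jnp.arange(1000, dtype=jnp.float32)

# First call: includes tracing and compilation.
t0 = time.perf_counter()
jitted(xs).block_until_ready()
first_call_s = time.perf_counter() - t0

# Later calls with the same shape and dtype should reuse the compiled code.
n_repeats = 10
t0 = time.perf_counter()
for _ in range(n_repeats):
    result = jitted(xs).block_until_ready()
steady_state_s = (time.perf_counter() - t0) / n_repeats

print(f"first call:   {first_call_s:.6f} s")
print(f"steady state: {steady_state_s:.6f} s per call")
print(f"traces:       {trace_count}")
```

Running the same script on both machines and comparing the two timings separately shows whether the gap comes from compilation or from execution. `block_until_ready()` matters here because JAX dispatches work asynchronously, so timing without it can measure only the dispatch.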
System info (python version, jaxlib version, accelerator, etc.)
Laptop:
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [Clang 15.0.0 (clang-1500.1.0.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Carloss-MacBook-Pro-2.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:54:05 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6031', machine='arm64')
Cluster:
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='marvel-0-15.eth', release='3.10.0-957.1.3.el7.x86_64', version='#1 SMP Thu Nov 29 14:49:43 UTC 2018', machine='x86_64')
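The two listings above match the format printed by `jax.print_environment_info()`. They do not show how many CPU cores the process is actually allowed to use, which can differ between a laptop and a scheduled cluster job. The sketch below prints that alongside the standard report; `visible_cpu_count` is a hypothetical helper, and whether CPU allocation has anything to do with this slowdown is an assumption, not something the report establishes.

```python
import os
import platform

import jax


def visible_cpu_count():
    """Number of CPUs this process may run on (hypothetical helper)."""
    # os.sched_getaffinity is available on Linux only; fall back elsewhere.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


jax.print_environment_info()
print("machine:", platform.machine())
print("logical CPUs:", os.cpu_count())
print("CPUs usable by this process:", visible_cpu_count())
```

If the last two numbers differ on the cluster, the job is restricted to a subset of the node's cores, which is worth knowing before comparing timings across machines.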
Same here.