Extremely slow execution of CPU JAX on cluster vs. laptop (reraise_with_filtered_traceback?)

Open. carlosgmartin opened this issue 2 years ago • 1 comment

Description

I'm observing extremely slow execution on my cluster compared to my laptop. Both runs use JAX on CPU only.

Perfetto UI screenshots:

Laptop: image

Cluster: image

One major difference I spot between these charts is the prominent and repeated appearance of the reraise_with_filtered_traceback and scan bars on the cluster, but not the laptop.

Any idea what could be causing this slowdown, and how to fix it?
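One way to narrow this down might be to time a small jitted scan on both machines and compare the first call (which includes tracing and compilation) against later calls (which should reuse the compiled program). The sketch below is only an illustration, not the code from this issue: `step`, `run`, and `time_call` are hypothetical names, and the workload is a placeholder.

```python
import time

import jax
import jax.numpy as jnp


def step(carry, x):
    # Toy recurrence standing in for the real scan body (hypothetical).
    carry = carry * 0.5 + x
    return carry, carry


@jax.jit
def run(xs):
    # Wrapping the scan in jit means it is traced once per input shape/dtype.
    final, _ = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return final


def time_call(fn, *args):
    # block_until_ready waits for the asynchronous dispatch to finish,
    # so the measurement covers the actual computation.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


xs = jnp.ones(1000)
first = time_call(run, xs)                          # tracing + compilation + execution
later = min(time_call(run, xs) for _ in range(5))   # execution only, if cached

print(f"first call: {first:.6f} s, best later call: {later:.6f} s")
```

If later calls are nearly as slow as the first one on the cluster only, that could point toward repeated tracing or compilation there rather than slow compiled code, though this is a guess and not a confirmed diagnosis.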

This issue might be related: #13407.

System info (python version, jaxlib version, accelerator, etc.)

Laptop:

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.12.2 (main, Feb  6 2024, 20:19:44) [Clang 15.0.0 (clang-1500.1.0.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Carloss-MacBook-Pro-2.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:54:05 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6031', machine='arm64')

Cluster:

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='marvel-0-15.eth', release='3.10.0-957.1.3.el7.x86_64', version='#1 SMP Thu Nov 29 14:49:43 UTC 2018', machine='x86_64')
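
For anyone adding their own numbers to this thread, a report in the same layout as the two above can be produced with `jax.print_environment_info`. This is a minimal sketch; that the blocks above were generated this way is an assumption based on their format.

```python
import jax

# return_string=True returns the report instead of printing it,
# which makes it easy to paste into an issue or save to a file.
info = jax.print_environment_info(return_string=True)
print(info)
```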

carlosgmartin, Mar 29 '24 22:03

Same here.

matteoguarrera, Jun 28 '24 19:06