JAX running in CPU-only mode only uses a single core
When running JAX installed from pip on a CPU-only host and monitoring core usage, I only ever see a single core reach 100% utilization. All other cores are idle.
I have observed the same behavior on multiple separate machines with different python versions.
Following the threads https://github.com/google/jax/issues/743 and https://github.com/google/jax/issues/1539, I have tried setting XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16", but this makes no difference.
Based on the comment https://github.com/google/jax/issues/1539#issuecomment-578496962 about thread affinity, I have also tried prepending taskset -c 0-15 to my script invocation, with and without the above XLA_FLAGS setting, but again this makes no difference.
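For concreteness, here is a minimal sketch of the kind of run I am doing (the matrix size is just illustrative, and the flag values are the ones I have been experimenting with):

```python
# Set XLA_FLAGS before importing jax, as suggested in the linked issues.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16"
)

import jax.numpy as jnp

x = jnp.ones((4000, 4000))
y = (x @ x).block_until_ready()  # only one core is busy while this runs
```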
There must be something I am missing here. All documentation and support threads here on GitHub imply that the default behavior is to detect and use all cores as a single local device via intra-op parallelism, yet I only ever observe single-threaded behavior in practice.
Any help would be much appreciated :)
Additionally, does anyone have a link to documentation for all of the available XLA_FLAGS options?
+1 for this. I've been having the same issue on several different CPUs (Intel, AMD, IBM) and several different versions of JAX (0.1.7x-0.2.5).
I am also having the same trouble: I can only get one CPU core utilized. Any help here would be much appreciated!
+1
Same here, on several different CPUs.
Ideally, I would like JAX to honor an environment variable along the lines of JAX_NUM_THREADS, so users can balance internal multithreading against existing MPI-based tools like mpi4jax. That would make JAX really compelling for broader HPC use cases!
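To illustrate what I mean, here is a hypothetical shim along those lines (JAX_NUM_THREADS does not exist today; intra_op_parallelism_threads only bounds the threads XLA's CPU backend may use for ops that are internally threaded):

```python
# Hypothetical: forward a JAX_NUM_THREADS-style variable into XLA's CPU flags.
# This only caps threads for internally threaded ops; it does not
# magically parallelize single-threaded ops.
import os

n = os.environ.get("JAX_NUM_THREADS", "1")
flags = os.environ.get("XLA_FLAGS", "")
flags += f" --xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads={n}"
os.environ["XLA_FLAGS"] = flags.strip()

import jax  # must be imported after XLA_FLAGS is set
```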
I am seeing the same behaviour in a large HPC application, where it has become a blocker, since CPU performance matters as much as GPU performance for our targets.
+1 Seeing this across multiple machines, multiple operating systems. Even simple ops like jax.numpy.dot run on a single core.
Note that I explored further and discovered that, at least in my case, it seems linked to some worst-case consequences of issue #5506.
It might be interesting to know how the small script I posted there behaves for the people in this issue.
+1, it would be great to get this resolved. The JAX/XLA-compiled method I'm working on shows a ~4-7x reduction in total CPU time compared to NumPy, but on a multicore machine it is faster to just use NumPy operations, which use all cores.
Same issue with jax-0.3.14 on Intel i7-10875H, 8 physical cores, Linux 5.18.9
Using @nestordemeure's nice test script from #5506 shows a single core used:
numpy: cpu usage 1.0/16 wall_time:0.5s
vmap: cpu usage 1.0/16 wall_time:3.2s
xmap: cpu usage 1.0/16 wall_time:3.3s
dot: cpu usage 1.0/16 wall_time:20.4s
Similarly, NumPyro only heats up a single core. I tested both the CUDA and CPU-only jaxlib wheels, with various combinations of
XLA_FLAGS="--xla_force_host_platform_device_count=8"
taskset -c 0-7
numpyro.set_host_device_count(8)
etc., with no effect.
I have the same issue on a Linux server.
lscpu reports 32 cores, but JAX only gives me a device count of 1.
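Concretely, this is all I am checking (a minimal sketch; the counts are what I see on this box):

```python
import multiprocessing
import jax

print(multiprocessing.cpu_count())  # 32 logical cores here, matching lscpu
print(jax.device_count())           # 1
print(jax.devices())                # a single CPU device
```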
Is anyone working on this?
Same problem here. I have an embarrassingly parallel workload, but it still uses only one core (~120%). Batch size doesn't seem to matter. I can't use the multiple-XLA-devices-on-CPU trick because of the memory duplication it would require. I fear I will have to switch away from JAX.
PS: I am hitting target performance for a single core, so vectorization seems to be working properly.
A quick workaround might be to use the multiprocess mode.
This is largely working as intended at the moment. JAX doesn't parallelize operations across CPU cores unless you use explicit parallelism constructs like pmap. Some JAX operations (e.g., BLAS or LAPACK calls) have their own internal parallelism.
It's not out of the question we might allow for more implicit parallelism in the future.
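To make the pmap route concrete, here is a minimal sketch (the device count and array shapes are just illustrative) that exposes several XLA CPU devices and maps over them:

```python
# Expose several CPU "devices" to XLA, then parallelize explicitly with pmap.
# The flag must be set before jax is imported; 8 is just an example.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

print(jax.devices())  # should now list 8 CPU devices

@jax.pmap
def f(x):
    return jnp.dot(x, x.T)

xs = jnp.ones((8, 512, 512))  # leading axis must equal the device count
ys = f(xs)                    # each slice runs on its own CPU device
```

As the reply below notes, this only helps when the data can actually be sharded along that leading axis.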
> A quick workaround might be to use the multiprocess mode.
The problem is that I can't shard my data across the processes easily, and it's too large to duplicate. It's a graph algorithm, so I really need shared-memory random access to a giant array. Unless I misread the documentation, there is no way to have a JAX numpy array backed by the same data in all processes.
> It's not out of the question we might allow for more implicit parallelism in the future.
That would be cool :) I'm switching to Numba for this part, but GPU portions are still JAX.
> so I really need shared-memory random access to a giant array.
I think you are reading the docs right. IIUC, JAX might not be the best tool for your use case; I would suggest Numba, which provides a multithreaded model that I believe fits it better.
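For what it's worth, a rough sketch of that Numba route (the kernel and sizes are placeholders, not your actual algorithm): all threads index into the same array in shared memory.

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def gather_sum(big, idx):
    # Each prange iteration may run on a different core; all of them read
    # from the same shared array without copying it.
    out = np.zeros(idx.shape[0])
    for i in prange(idx.shape[0]):
        out[i] = big[idx[i]]
    return out

big = np.random.rand(50_000_000)                      # the "giant array"
idx = np.random.randint(0, big.size, size=1_000_000)  # random accesses into it
print(gather_sum(big, idx)[:5])
```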
As a first-time user, I was experimenting with summing two large arrays with NumPy and JAX in a Kaggle notebook. NumPy is around 4 times faster than JAX, which matches the number of threads. I thought I was doing something wrong and tried pmap, etc., and spent quite some time before finding this issue. If this is not going to be solved soon, a section in the Sharp Bits documentation could help first-time users.
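For reference, a minimal sketch of the comparison I mean (sizes and repeat counts are arbitrary; block_until_ready is needed because JAX dispatches asynchronously):

```python
import time
import numpy as np
import jax.numpy as jnp

a_np = np.random.rand(10_000_000)
b_np = np.random.rand(10_000_000)
a_jx, b_jx = jnp.asarray(a_np), jnp.asarray(b_np)

t0 = time.perf_counter()
for _ in range(100):
    c = a_np + b_np
t1 = time.perf_counter()

(a_jx + b_jx).block_until_ready()  # warm-up
t2 = time.perf_counter()
for _ in range(100):
    c = (a_jx + b_jx).block_until_ready()
t3 = time.perf_counter()

print(f"numpy: {t1 - t0:.3f}s   jax (single core here): {t3 - t2:.3f}s")
```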
Is it still not possible to run on multiple cores without complex orchestration? My PennyLane computations on the JAX backend are really impacted by this slowdown... :(
+1 same problem here....
How do I disable multicore execution for jax.jit?
+1 having the same issue here...
+1 having the same problem.