
JAX running in CPU only mode only uses a single core

JossWhittle opened this issue 5 years ago • 16 comments

When running JAX installed from pip on a CPU only host while monitoring core usage I only ever see a single core go to 100% utilization. All other cores are idle.

I have observed the same behavior on multiple separate machines with different python versions.

From threads https://github.com/google/jax/issues/743 and https://github.com/google/jax/issues/1539 I have attempted to use XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16" but this makes no difference.

Observing the comment https://github.com/google/jax/issues/1539#issuecomment-578496962 about thread affinity I have also tried running my script prepended with taskset -c 0-15 with and without the above XLA_FLAGS directive but again this makes no difference.
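Concretely, the invocations I have tried look like this (`my_script.py` and the core/thread counts are placeholders for my actual setup):

```shell
# Request multi-threaded Eigen in XLA and pin the process to 16 cores
XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16" \
    taskset -c 0-15 python my_script.py
```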

There must be something I am missing here. All documentation and support threads here on GitHub imply that the default behavior is to detect and use all cores as a single local device via intra-op parallelism, yet I only ever observe single-threaded behavior in practice.

Any help would be much appreciated :)

Additionally does anyone have a link to documentation for all of the available XLA_FLAG options?

JossWhittle avatar Nov 26 '20 22:11 JossWhittle

+1 for this, I've been having the same issue on several different CPUs (Intel, AMD, IBM) and several different versions of jax (0.1.7x - 0.2.5)

f0uriest avatar Jan 11 '21 00:01 f0uriest

I am also having the same trouble. I can only get 1 cpu core utilized. Any help here would be much appreciated!

mrtupek avatar Feb 22 '21 22:02 mrtupek

+1

Same here, on several different CPUs.

Ideally, I would like to have JAX listen to some environment variable along the lines of JAX_NUM_THREADS so users can balance internal multithreading with existing MPI implementations like mpi4jax. That would make JAX really compelling for more broad HPC use cases!

Matematija avatar Nov 24 '21 17:11 Matematija

I am seeing the same behaviour on a large HPC application, where it has become a blocker since CPU behaviour matters as much as GPU behaviour for our targets.

nestordemeure avatar Jan 07 '22 06:01 nestordemeure

+1 Seeing this across multiple machines, multiple operating systems. Even simple ops like jax.numpy.dot run on a single core.

rborder avatar May 27 '22 01:05 rborder

Note that I explored further and discovered that, at least in my case, it seems linked to some worst-case consequences of issue #5506.

It might be interesting to know how the small script I posted there behaves for people in this issue.

nestordemeure avatar May 27 '22 04:05 nestordemeure

+1, it would be great to get this resolved. The JAX XLA-compiled method I'm working on shows a ~4-7x reduction in total CPU time compared to NumPy, but on a multicore machine it's faster to just use NumPy operations that use all cores.

oliverpriebe avatar Jun 30 '22 19:06 oliverpriebe

Same issue with jax-0.3.14 on Intel i7-10875H, 8 physical cores, Linux 5.18.9

Using @nestordemeure's nice test script from #5506 shows a single core used:

```
numpy: cpu usage 1.0/16 wall_time:0.5s
vmap:  cpu usage 1.0/16 wall_time:3.2s
xmap:  cpu usage 1.0/16 wall_time:3.3s
dot:   cpu usage 1.0/16 wall_time:20.4s
```

Similarly, numpyro only heats a single core. Tested both cuda/CPU and CPU-only jaxlib wheels, with various combinations of `XLA_FLAGS="--xla_force_host_platform_device_count=8"`, `taskset -c 0-7`, `numpyro.set_host_device_count(8)`, etc., with no effect.
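For what it's worth, here is the minimal check I use to confirm the device-count flag is even being picked up (it must be set before `jax` is first imported, or it silently has no effect; the count of 8 matches my machine):

```python
import os

# Must be set before the first `import jax`, otherwise it is ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

# This exposes 8 separate CPU devices to pmap-style constructs; it does
# not by itself make single-device ops (e.g. jnp.dot) use more cores.
print(jax.devices("cpu"))
print(jax.device_count("cpu"))
```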

mattja avatar Jul 08 '22 06:07 mattja

I have the same issue on a linux server. lscpu gives 32, but jax only provides me with a device count of 1. Is anyone working on this?

AlexanderFengler avatar Aug 03 '22 20:08 AlexanderFengler

Same problem here. I have an embarrassingly parallel workload but it still uses only 1 core (~120%). Batch size doesn't seem to matter. I can't do the multiple XLA devices on CPU trick because of the memory duplication it would need. I fear I will have to switch away from Jax.

PS: I am hitting target performance for 1 core. So vectorization seems to be working properly.

shlapfish avatar Aug 22 '22 12:08 shlapfish

A quick workaround might be to use the multiprocess mode.

mjsML avatar Aug 22 '22 12:08 mjsML

This is largely working as intended at the moment. JAX doesn't parallelize operations across CPU cores unless you use explicit parallelism constructs like pmap. Some JAX operations (e.g., BLAS or LAPACK calls) have their own internal parallelism.
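For example, a minimal sketch of explicit CPU parallelism with pmap (the device count and array shapes here are illustrative):

```python
import os

# Split the host CPU into 4 XLA devices; must precede the first jax import.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

n_dev = jax.device_count("cpu")  # 4 with the flag above

# One matrix product per CPU device, run in parallel across them.
xs = jnp.ones((n_dev, 256, 256))

@jax.pmap
def square(x):
    return x @ x

out = square(xs)
print(out.shape)  # one result per device, stacked on the leading axis
```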

It's not out of the question we might allow for more implicit parallelism in the future.

hawkinsp avatar Aug 22 '22 13:08 hawkinsp

A quick workaround might be to use the multiprocess mode.

The problem is that I can't shard my data across the processes easily and it's too large to duplicate. It's a graph algorithm, so I really need shared state random memory access to a giant array. Unless I misread the documentation, there is no way I can have a jax numpy array backed by the same data for all processes.

It's not out of the question we might allow for more implicit parallelism in the future.

That would be cool :) I'm switching to Numba for this part, but GPU portions are still JAX.

shlapfish avatar Aug 22 '22 13:08 shlapfish

so I really need shared state random memory access to a giant array.

I think you are reading the docs right; IIUC, JAX might not be the best tool for your use case. I would suggest numba — it provides a multithreaded model that I believe better fits your use case.

mjsML avatar Aug 22 '22 13:08 mjsML

As a first-time user, I was experimenting with summing two large arrays with NumPy and JAX in a Kaggle notebook. NumPy is around 4 times faster than JAX, which matches the number of available threads. I thought I was doing something wrong and tried pmap, etc. I spent some time before finding this issue. If this is not going to be solved soon, a section in the Sharp Bits doc could help first-time users.
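The comparison I ran was along these lines (array size illustrative; `block_until_ready()` is needed for a fair timing because JAX dispatches asynchronously):

```python
import time
import numpy as np
import jax.numpy as jnp

n = 5_000_000
a = np.random.rand(n)
b = np.random.rand(n)
ja, jb = jnp.asarray(a), jnp.asarray(b)

t0 = time.perf_counter()
c_np = a + b
t_np = time.perf_counter() - t0

jnp.add(ja, jb).block_until_ready()  # warm-up / compile
t0 = time.perf_counter()
c_jax = jnp.add(ja, jb).block_until_ready()
t_jax = time.perf_counter() - t0

# Same result either way; only the wall time differs.
print(f"numpy: {t_np:.4f}s  jax: {t_jax:.4f}s")
```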

deadsoul44 avatar Oct 13 '22 17:10 deadsoul44

Is it still not possible to run on multiple cores without complex orchestration? My PennyLane computation tasks using the JAX platform are really getting impacted by this slowdown... :(

ttrenty avatar Jun 24 '24 15:06 ttrenty

+1 same problem here....

SaschaFroelich avatar Aug 05 '24 07:08 SaschaFroelich

How can I disable multicore use in jax.jit?

iperov avatar Aug 06 '24 07:08 iperov

+1 having the same issue here...

slowlightx avatar Nov 15 '24 23:11 slowlightx

+1 having the same problem.

chen-zz20 avatar Dec 10 '24 06:12 chen-zz20