
Allowing pmap on CPUs by making CPU cores visible directly in JAX

Open • sguada opened this issue 5 years ago • 16 comments

XLA does detect the number of CPU cores automatically, but by default it presents all the cores as a single device and uses them for its own intra-op parallelism.

Setting XLA_FLAGS="--xla_force_host_platform_device_count=8" tells XLA to split the cores into 8 separate host devices, so that you can do parallel programming with them manually (e.g. using pmap).
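A minimal sketch of that workflow, assuming a recent JAX release: the flag is set through the environment before JAX creates its CPU backend, and jax.pmap then maps over the forced host devices. The count of 8 is arbitrary, and setting JAX_PLATFORMS=cpu is an addition here (not from the thread) to keep the example on the host even on a machine with a GPU.

```python
import os

# The flag must be in place before JAX initializes its CPU backend
# (i.e. before the first JAX operation runs); later changes are ignored.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()
# Assumption: restrict JAX to the CPU so pmap targets the forced host devices.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 8 if the flag took effect
print(n_dev, jax.devices())

# pmap maps over the leading axis, whose size must equal the device count.
xs = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)
sums = jax.pmap(jnp.sum)(xs)  # one row summed per CPU device
print(sums)
```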

That might be a good idea if, e.g., you have embarrassingly parallel work to do, like running several MCMC chains in parallel, which in some cases may be a better use of the cores than XLA's intra-op parallelism.
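As an illustration of the MCMC use case, here is a hedged sketch that runs one random-walk Metropolis chain per forced CPU device with pmap. The target density (a standard normal), step size and chain length are placeholders chosen for the example, not anything from the thread.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # assumption: CPU-only run

import jax
import jax.numpy as jnp


def log_density(x):
    # Placeholder target: unnormalized standard normal.
    return -0.5 * x ** 2


def run_chain(key, num_steps=1000, step_size=1.0):
    """One random-walk Metropolis chain; returns all num_steps samples."""

    def step(x, key):
        k_prop, k_acc = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(k_prop)
        log_ratio = log_density(proposal) - log_density(x)
        accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, jnp.zeros(()), keys)
    return samples


n_dev = jax.local_device_count()
chain_keys = jax.random.split(jax.random.PRNGKey(0), n_dev)  # one key per chain
samples = jax.pmap(run_chain)(chain_keys)  # shape (n_dev, num_steps)
print(samples.shape, samples.mean(axis=1))
```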

sguada, Sep 27 '19

I just realized that we can set xla_force_host_platform_device_count=200 on a 12-core processor and still get the desired result. Does doing it that way have any side effects?

cc @skye

fehiepsi, Sep 28 '19

I filed an internal XLA issue requesting that they expose the device count setting programmatically, so we can create a JAX API on top of that.

We'll still need to set the device count when the CPU backend is initialized, i.e. right before any JAX operations are run, or earlier. Once it's initialized, I don't think you'll be able to change it, unfortunately. I'm imagining an API like jax.set_cpu_device_count(n) that would be called right after importing jax. What do you think?
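Until something like jax.set_cpu_device_count(n) exists, a user-side stand-in could look like the sketch below. set_cpu_device_count here is a hypothetical helper, not a JAX API: it only rewrites XLA_FLAGS, so, as noted above, it has no effect once the CPU backend has been initialized.

```python
import os
import re

_FLAG = "--xla_force_host_platform_device_count"


def set_cpu_device_count(n: int) -> None:
    """Hypothetical helper: ask XLA to expose n host (CPU) devices.

    Only takes effect if called before JAX initializes its CPU backend,
    i.e. before the first JAX operation runs.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier value of the flag so it cannot shadow the new one,
    # while keeping every other XLA flag the user has set.
    flags = re.sub(rf"{_FLAG}=\S+", "", flags)
    os.environ["XLA_FLAGS"] = " ".join(flags.split() + [f"{_FLAG}={n}"])
```

Usage would be to call set_cpu_device_count(8) before the first JAX operation (ideally before import jax) and then confirm the result with jax.device_count().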

skye, Oct 01 '19

an API like jax.set_cpu_device_count(n) that should be called right after importing jax

Thanks, @skye, I love this idea! Currently, we had to write a utility, set_host_devices, for this purpose, but we have to include a big warning that we don't fully understand the effects of the xla_force_host_platform_device_count flag. Having an API call directly in JAX would be very helpful!

cc @neerajprad

fehiepsi, Oct 01 '19

@skye having a JAX function to do this would be great, even if it has to happen right at the beginning. We could even warn the user if they request more devices than are physically available.
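The warning suggested above could be a simple comparison against os.cpu_count(). A sketch, assuming the logical CPU count reported by the OS is the right threshold; check_requested_cpu_devices is a hypothetical name. Since XLA accepts larger counts (e.g. 200 on a 12-core machine, as reported earlier in the thread), this warns rather than failing.

```python
import os
import warnings


def check_requested_cpu_devices(n: int) -> int:
    """Hypothetical check: warn when more host devices are requested than
    the OS reports logical CPUs. Returns the logical CPU count used."""
    available = os.cpu_count() or 1  # cpu_count() may return None
    if n > available:
        # Oversubscription is allowed, but the extra devices share cores.
        warnings.warn(
            f"Requested {n} CPU devices but only {available} logical CPUs "
            "are available; devices will share cores.",
            RuntimeWarning,
        )
    return available
```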

sguada, Oct 01 '19

@skye - I thought I'd circle back on this after our brief discussion at NeurIPS. Our motivation for setting the device count is to be able to use pmap for running MCMC chains in parallel. So calling jax.set_cpu_device_count(n) with a sufficiently high value of n after importing JAX will work for our use cases (the number of parallel chains will be capped at n), but I wonder whether there are any trade-offs or unanticipated effects of initializing this across all platforms.

neerajprad, Jan 09 '20

@skye What is the status on this?

gnecula, Apr 27 '20

@gnecula Unstarted. We still need to do the XLA plumbing to allow setting the --xla_force_host_platform_device_count flag programmatically.

@neerajprad Sorry for not replying earlier! What do you mean by initializing across all platforms? (Specifically, what is a platform in this case?)

skye, Apr 27 '20

I have a simpler suggestion: allow the replica -> device mapping to be non one-to-one. i.e., allow two replicas to be mapped to the same device. I think very little work is needed; mostly I would imagine we'd need to allow some larger threadpools and to change the implementation of CPU collectives to use the replica number rather than a thread number when identifying participants.

hawkinsp, Apr 27 '20

I don't think that would be simpler to implement. Beyond the XLA changes you mentioned, we'd also need to teach pmap that the CPU backend doesn't use multiple devices, and there may be further side effects of changing the programming model that we don't realize yet. It would be simpler to use in the end though. If you want to go down this path I won't stop you, but I personally would start with plumbing the CPU device count up to the LocalClient constructor (maybe this will be hard for some reason as well?).

skye, Apr 27 '20

@neerajprad Sorry for not replying earlier! What do you mean by initializing across all platforms? (Specifically, what is a platform in this case?)

Thanks for looking into this, @skye. For us to be able to run parallel MCMC chains, we would need to initialize this to some high number (say, 1000) beforehand for both CPU and GPU (that's what I meant by platform), regardless of the number of cores or GPUs the user might have. We weren't sure how safe that would be, or whether it would cause us to miss any existing or planned XLA optimizations. I think we did try it and didn't notice any slowdown on CPU, but we wanted to confirm this. cc @fehiepsi
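If the goal is many chains (say 1000) regardless of how many devices exist, one user-level option that avoids forcing a huge device count is to nest vmap inside pmap: pmap over the devices that actually exist and vmap over the chains assigned to each. This is a sketch of that alternative, not of the XLA-level change discussed in this thread; run_chain is a hypothetical placeholder for a real sampler.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # assumption: CPU-only run

import jax
import jax.numpy as jnp


def run_chain(key, num_steps=500):
    # Placeholder per-chain computation: a Gaussian random walk.
    steps = jax.random.normal(key, (num_steps,))
    return jnp.cumsum(steps)


num_chains = 1000
n_dev = jax.local_device_count()
chains_per_device = -(-num_chains // n_dev)  # ceiling division

keys = jax.random.split(jax.random.PRNGKey(0), n_dev * chains_per_device)
keys = keys.reshape(n_dev, chains_per_device, *keys.shape[1:])

# pmap over devices, vmap over the chains that share each device.
samples = jax.pmap(jax.vmap(run_chain))(keys)
# Flatten (device, chain) and drop any padding chains from the ceiling division.
samples = samples.reshape(n_dev * chains_per_device, -1)[:num_chains]
print(samples.shape)  # (1000, 500)
```

The number of chains is then decoupled from the device count: each device simply runs a vectorized batch of chains.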

neerajprad, Apr 28 '20

Ah ok, I didn't realize you'd also like multiple threads on each GPU (sorry if I forgot from NeurIPS). I don't think XLA has any interface for this currently, and I'm not sure how hard it would be to implement. @hawkinsp may know more.

skye, Apr 29 '20

@skye the PyTorch/XLA frontend wrote its own distributed multiprocessing implementation, which does not seem to depend on "XLA plumbing":

https://github.com/pytorch/xla/tree/master/torch_xla/distributed

den-run-ai, Dec 30 '20

Just chiming in - this would be useful even if it doesn't do multiple threads per GPU. I'm using this right now to accelerate CPU-only inference (perhaps there's a better way?). With the JAX ResNet from big_transfer on 224x224 images, inference maxes out at roughly 5 cores at batch size 4 and doesn't use more with larger batches. Using 8-way pmap and manually setting XLA_FLAGS to specify 8 devices speeds it up linearly.
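For batch inference like this, a common pattern is to reshape the batch so its leading axis equals the device count and pmap the model's apply function with the parameters broadcast (in_axes=None). In this sketch predict is a hypothetical one-layer stand-in for the big_transfer ResNet, small 32x32 inputs keep it cheap, and the batch size is assumed to be divisible by the number of devices.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # assumption: CPU-only run

import jax
import jax.numpy as jnp


def predict(params, images):
    # Hypothetical stand-in for a model's apply function:
    # a single dense layer over flattened pixels.
    flat = images.reshape(images.shape[0], -1)
    return flat @ params["w"] + params["b"]


def sharded_predict(params, images):
    n_dev = jax.local_device_count()
    batch = images.shape[0]
    assert batch % n_dev == 0, "batch size must be divisible by the device count"
    shards = images.reshape(n_dev, batch // n_dev, *images.shape[1:])
    # in_axes=(None, 0): replicate params, split the images across devices.
    out = jax.pmap(predict, in_axes=(None, 0))(params, shards)
    return out.reshape(batch, *out.shape[2:])


key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (32 * 32 * 3, 10)), "b": jnp.zeros(10)}
images = jnp.ones((jax.local_device_count() * 2, 32, 32, 3))
logits = sharded_predict(params, images)
print(logits.shape)
```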

(Obviously, this could also be done by supercharging vmap for CPU as noted in #5506, which would be preferable, but this seems like a shorter path to getting something working quickly.)

dave-andersen, Mar 18 '21

What is the status on this, and how should using XLA_FLAGS="--xla_force_host_platform_device_count=200" affect the behaviour?

Looking to test pmap on multiple CPUs.

xmax1, May 11 '21

Bumping this for hopefully more visibility!

proteneer, Oct 29 '21

+1. This is also potentially very useful for testing that pmap'ed functions act as expected in CPU-based test environments.
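For CPU-only test environments, one possible pattern is to force several host devices and check a pmap'ed function against the same function under vmap, which also accepts an axis_name and supports collectives such as jax.lax.pmean. A pytest-style sketch; normalize_across_devices is a hypothetical function under test.

```python
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # assumption: CPU-only run

import jax
import jax.numpy as jnp
import numpy as np


def normalize_across_devices(x):
    # Hypothetical function under test: subtract the mean over the mapped axis.
    return x - jax.lax.pmean(x, axis_name="batch")


def test_pmap_matches_vmap():
    # Run with pytest; on CPU this exercises pmap across the forced devices.
    n_dev = jax.local_device_count()
    x = jnp.arange(n_dev * 3.0).reshape(n_dev, 3)
    got = jax.pmap(normalize_across_devices, axis_name="batch")(x)
    want = jax.vmap(normalize_across_devices, axis_name="batch")(x)
    np.testing.assert_allclose(np.asarray(got), np.asarray(want), rtol=1e-6)
```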

sdenton4, May 18 '22

Hi, are there any updates on this? What is best practice for using multiple cores with JAX?

pkairys, Jul 17 '23