Allow pmap on CPUs by making CPU cores visible directly in JAX.
XLA does detect the number of CPU cores automatically, but by default it presents all the cores as one device and uses the cores for its own intra-op parallelism.
XLA_FLAGS="--xla_force_host_platform_device_count=8" tells XLA to split up the cores so that you can manually do parallel programming with them (e.g. using pmap).
That might be a good idea if you have embarrassingly parallel work to do, like running several MCMC chains in parallel, which in some cases is a better use of the cores than XLA's intra-op parallelism.
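As a minimal sketch of the usage described above (not an official recipe): the flag is set from Python before JAX is imported, and pmap then runs one stand-in "chain" per forced CPU device. The function `chain_mean` is a hypothetical placeholder for real per-chain work, and setting `JAX_PLATFORMS` to `cpu` is an assumption made only so the example behaves the same on machines that also have accelerators.

```python
import os

# The flag has to be in place before the CPU backend is initialized,
# so set it before importing jax. This overwrites any existing XLA_FLAGS.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# Assumption: restrict JAX to the CPU backend for this example.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # number of CPU devices XLA now exposes


def chain_mean(key):
    # Hypothetical stand-in for one independent chain: draw samples, reduce.
    return jnp.mean(jax.random.normal(key, (1000,)))


# One PRNG key per device; pmap maps over the leading axis.
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
means = jax.pmap(chain_mean)(keys)
print(n_dev, means.shape)
```

Each device receives one key and returns one scalar, so `means` has one entry per forced device.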
I just realized that we can set xla_force_host_platform_device_count=200 on a 12-core processor and still get the desired result. Does doing so have any side effects?
cc @skye
I filed an internal XLA issue requesting that they expose the device count setting programmatically, so we can create a JAX API on top of that.
We'll still need to set the device count when the CPU backend is initialized, i.e. right before any jax operations are run or earlier. Once it's initialized, I don't think you'll be able to change it unfortunately. I'm imagining an API like jax.set_cpu_device_count(n)
that should be called right after importing jax, what do you think?
an API like jax.set_cpu_device_count(n) that should be called right after importing jax
Thanks @skye, I love this idea! Currently, we have to provide a utility, set_host_devices, for this purpose, and we have to include a big warning that we don't fully understand the effect of the xla_force_host_platform_device_count flag. Having an API call directly from JAX would be very helpful!
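For illustration, a utility of this kind could look roughly like the sketch below. The name `set_host_device_count` and its behaviour are hypothetical; this is not the actual implementation of the utility mentioned above. It only edits the `XLA_FLAGS` environment variable, so it has an effect only if called before JAX initializes its CPU backend.

```python
import os
import re


def set_host_device_count(n):
    """Hypothetical helper: ask XLA to expose n CPU devices via XLA_FLAGS."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Remove any earlier setting of the flag so repeated calls do not stack,
    # while leaving unrelated flags untouched.
    flags = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", flags
    ).split()
    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + flags
    )


set_host_device_count(4)
set_host_device_count(8)  # the later call replaces the earlier value
print(os.environ["XLA_FLAGS"])
```

Prepending the flag and preserving the rest keeps other user-supplied XLA flags intact, which seems safer than overwriting the variable.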
cc @neerajprad
@skye having a JAX function to do this would be great, even if it has to happen right at the beginning. I can even warn the user if they request more devices than are physically available.
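A warning of that kind might be sketched as follows. The function `check_device_request` is hypothetical, and note that `os.cpu_count()` reports logical cores, which can differ from the number of physical cores.

```python
import os
import warnings


def check_device_request(n):
    """Hypothetical check: warn when n exceeds the cores the OS reports."""
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    if n > available:
        warnings.warn(
            f"Requested {n} CPU devices but only {available} logical "
            "cores are available; devices will share cores."
        )
        return False
    return True
```

The check only warns rather than failing, since the thread reports that oversubscribing (e.g. 200 devices on 12 cores) still produces the desired result.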
@skye - I thought I'd circle back on this after our brief discussion at NeurIPS. Our motivation for setting the device count is to be able to use pmap for running MCMC chains in parallel. So calling jax.set_cpu_device_count(n) with a sufficiently high value of n after importing JAX will work for our use cases (the number of parallel chains will be capped at n), but I wonder if there are any trade-offs or unanticipated effects of initializing this across all platforms.
@skye What is the status on this?
@gnecula Unstarted. We still need to do the XLA plumbing to allow setting the --xla_force_host_platform_device_count
flag programmatically.
@neerajprad Sorry for not replying earlier! What do you mean by initializing across all platforms? (Specifically, what is a platform in this case?)
I have a simpler suggestion: allow the replica -> device mapping to be non-one-to-one, i.e., allow two replicas to be mapped to the same device. I think very little work is needed; mostly I would imagine we'd need to allow some larger threadpools and to change the implementation of CPU collectives to use the replica number rather than a thread number when identifying participants.
I don't think that would be simpler to implement. Beyond the XLA changes you mentioned, we'd also need to teach pmap that the CPU backend doesn't use multiple devices, and there may be further side effects of changing the programming model that we don't realize yet. It would be simpler to use in the end though. If you want to go down this path I won't stop you, but I personally would start with plumbing the CPU device count up to the LocalClient constructor (maybe this will be hard for some reason as well?).
@neerajprad Sorry for not replying earlier! What do you mean by initializing across all platforms? (Specifically, what is a platform in this case?)
Thanks for looking into this, @skye. For us to be able to run parallel MCMC chains, we would need to initialize this to some high number (say, 1000) beforehand for both CPU and GPU (that's what I meant by platform), regardless of the number of cores or GPUs the user might have. We weren't sure how safe that would be and if this would cause us to miss any existing or planned XLA optimizations. I think we did try and didn't notice any slowdown on CPU but wanted to confirm this. cc @fehiepsi
Ah ok, I didn't realize you'd also like multiple threads on each GPU (sorry if I forgot from NeurIPS). I don't think XLA has any interface for this currently, and I'm not sure how hard it would be to implement. @hawkinsp may know more.
@skye The PyTorch XLA frontend has its own distributed multiprocessing implementation, which does not seem to depend on the "XLA plumbing":
https://github.com/pytorch/xla/tree/master/torch_xla/distributed
Just chiming in - this would be useful even if it doesn't do multiple threads per GPU. I'm using this right now to accelerate CPU-only inference (perhaps there's a better way?). With the JAX ResNet from big_transfer on 224x224 images, it maxes out at roughly 5 cores at batch size 4, and core usage doesn't increase with larger batches. Using 8-way pmap and manually setting XLA_FLAGS to specify 8 devices speeds it up linearly.
(Obviously, this could also be done by supercharging vmap for CPU as noted in #5506, which would be preferable, but this seems like a shorter path to getting something working quickly.)
What is the status of this, and how does setting XLA_FLAGS="--xla_force_host_platform_device_count=200" affect the behaviour?
Looking to test pmap on multiple cpus.
Bumping this for hopefully more visibility!
+1. This is also potentially very useful for testing that pmap'ed functions act as expected in CPU-based test environments.
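A rough sketch of that testing use case, assuming a CPU-only environment: force a few host devices, then check a pmap'ed function that uses a collective against a plain reference computation. The names `normalize` and `check_pmap_matches_reference` are hypothetical, and setting `JAX_PLATFORMS` is an assumption to keep the run on the CPU backend.

```python
import os

# Set before importing jax so the CPU backend starts with 4 devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_PLATFORMS"] = "cpu"  # assumption: CPU-only test run

import jax
import jax.numpy as jnp
import numpy as np


def normalize(x):
    # Divide each device's value by the sum across all devices.
    return x / jax.lax.psum(x, axis_name="i")


def check_pmap_matches_reference():
    n = jax.local_device_count()
    x = jnp.arange(1.0, n + 1.0)  # one scalar per device
    out = jax.pmap(normalize, axis_name="i")(x)
    # Reference result computed without pmap.
    np.testing.assert_allclose(out, x / x.sum(), rtol=1e-6)
    return out


result = check_pmap_matches_reference()
```

In a real test suite the environment variables would likely be set once, for example in the test runner's configuration, rather than in each test module.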
Hi, are there any updates on this? What is best practice for using multiple cores with JAX?