
[PJRT:GPU] Add setting for mocked number of hosts per slice

Open jaro-sevcik opened this issue 1 year ago • 7 comments

With the existing enable_mock_nccl setting, it is impossible to warm up the compilation cache when there are multiple processes per node. This is because the cache key includes the topology, and the GPU topology records the number of slices and the number of hosts per slice. The current topology mocking always sets num_hosts_per_slice to 1. However, if a node has multiple GPUs and you run one process per GPU, num_hosts_per_slice must be set to the number of GPUs on the node.

This patch allows setting num_hosts_per_slice explicitly when creating the GPU client.
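To illustrate why a fixed num_hosts_per_slice of 1 defeats cache warm-up, here is a minimal Python sketch. Topology, cache_key and the hashing scheme are hypothetical stand-ins, not XLA's actual cache-key code. The sketch only assumes what the description states: the key covers the number of slices and the number of hosts per slice, so a mocked topology must match the real one field for field.

```python
import hashlib
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    # Hypothetical stand-in for the two topology fields the issue mentions.
    num_slices: int
    num_hosts_per_slice: int


def cache_key(program: str, topology: Topology) -> str:
    # Hypothetical cache key: hashes the program together with the topology,
    # so any topology difference yields a different key.
    payload = f"{program}|{topology.num_slices}|{topology.num_hosts_per_slice}"
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


# Real run: one node with 8 GPUs and one process per GPU, so each
# process counts as a host and the slice has 8 hosts.
real = Topology(num_slices=1, num_hosts_per_slice=8)

# Previous mocking behaviour: num_hosts_per_slice is always 1.
old_mock = Topology(num_slices=1, num_hosts_per_slice=1)

# With an explicit setting, the mock can match the real topology.
new_mock = Topology(num_slices=1, num_hosts_per_slice=8)

program = "module @jit_f"  # placeholder program text

print(cache_key(program, old_mock) == cache_key(program, real))  # False
print(cache_key(program, new_mock) == cache_key(program, real))  # True
```

A cache populated under old_mock is never hit by the real 8-process run, while one populated under new_mock is.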

jaro-sevcik, Jul 31 '24

Looks reasonable to me, but I think @hawkinsp should LGTM as well.

cheshire, Aug 02 '24

@hawkinsp Could you please help take a look?

penpornk, Aug 06 '24

@hawkinsp A friendly nudge to take a look, your review has been requested for this PR.

sgerrard, Aug 12 '24

@cheshire I have updated the patch to accept a parameter of the form mock_gpu_topology=4x2. Is this what you had in mind?
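A parser for a value such as 4x2 might look like the following Python sketch. The function name parse_mock_gpu_topology and the MockTopology type are hypothetical, and reading the two numbers as the number of slices followed by the number of hosts per slice is an assumption for illustration; the comment does not spell out the format.

```python
from typing import NamedTuple


class MockTopology(NamedTuple):
    # Hypothetical container for the parsed mock topology.
    num_slices: int
    num_hosts_per_slice: int


def parse_mock_gpu_topology(value: str) -> MockTopology:
    # Assumed format "<num_slices>x<num_hosts_per_slice>", e.g. "4x2".
    parts = value.strip().lower().split("x")
    if len(parts) != 2:
        raise ValueError(f"expected '<slices>x<hosts_per_slice>', got {value!r}")
    try:
        num_slices, num_hosts = (int(p) for p in parts)
    except ValueError:
        raise ValueError(f"non-integer component in {value!r}") from None
    if num_slices < 1 or num_hosts < 1:
        raise ValueError(f"components must be positive, got {value!r}")
    return MockTopology(num_slices, num_hosts)


print(parse_mock_gpu_topology("4x2"))
# MockTopology(num_slices=4, num_hosts_per_slice=2)
```

Rejecting malformed or non-positive values up front keeps a typo in the flag from silently producing a mock topology that never matches the real one.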

jaro-sevcik, Aug 23 '24

Are you asking how it's going to be exposed in JAX? Here is a draft PR: https://github.com/google/jax/pull/23374

Or were you asking for some JAX scripts that make use of that option?

jaro-sevcik, Sep 02 '24

@hawkinsp Could you please take a look at this PR?

dimitar-asenov, Sep 16 '24

@cheshire Could you take another look, please?

jaro-sevcik, Sep 27 '24