Bogus gradient value from value_and_grad for committed DeviceArray on a multi-GPU host
Hi,
I'm reporting this issue with the disclaimer that it's probably not reproducible on your end (sorry!), but I hope it's of some worth to the awesome JAX developers:
We have a machine with 5 NVIDIA A100 (80GB) cards, and we see that the gradient computed by value_and_grad is wrong when a DeviceArray committed to either of two specific devices is passed as the function argument. Here is the simplest example:
import jax
fn = jax.value_and_grad(lambda x: x * 1)
for i in range(jax.device_count()):
    x = jax.device_put(1., device=jax.devices()[i])
    v, g = fn(x)
    print(v, g)
Printout:
1.0 1.0
1.0 1.0
1.0 1e-45
1.0 1e-45
1.0 1.0
When the function is lambda x: x or lambda x: x * 1. (note the dot after the 1 - multiplying by a float here), the results from all devices are correct (1.0 1.0).
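For concreteness, these are the three function bodies I compared in the loop above; only the integer-constant one misbehaves:

import jax

fn_bad  = jax.value_and_grad(lambda x: x * 1)   # integer constant: 1e-45 gradient on GPUs 2 & 3
fn_ok_1 = jax.value_and_grad(lambda x: x)       # identity: correct everywhere
fn_ok_2 = jax.value_and_grad(lambda x: x * 1.)  # float constant: correct everywhere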
Also, if the call to fn is put in the default_device context, as in:
import jax
fn = jax.value_and_grad(lambda x: x * 1)
for i in range(jax.device_count()):
    x = jax.device_put(1., device=jax.devices()[i])
    with jax.default_device(jax.devices()[i]):
        v, g = fn(x)
    print(v, g)
the output gradient is correct for all devices. This makes me suspect that uncommitted arrays are involved somewhere in the gradient calculation, even when the input is committed, and that this happens only in some specific cases.
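As a small sanity check along those lines, I can print where the value and gradient buffers actually end up for each committed input (my own debugging sketch; it assumes DeviceArray.device() is available, which it is in JAX 0.3.x):

import jax

fn = jax.value_and_grad(lambda x: x * 1)
for d in jax.devices():
    x = jax.device_put(1., device=d)  # committed to device d
    v, g = fn(x)
    # if an uncommitted array sneaks in, v or g may end up on a different device than d
    print(d, v.device(), g.device(), float(g))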
The only information that may be even remotely relevant that I can think of is the device topology:
$ nvidia-smi topo -m
GPU0 GPU1 GPU2 GPU3 GPU4 CPU Affinity NUMA Affinity
GPU0 X NV12 PXB PXB SYS 0-31 0
GPU1 NV12 X PXB PXB SYS 0-31 0
GPU2 PXB PXB X NV12 SYS 0-31 0
GPU3 PXB PXB NV12 X SYS 0-31 0
GPU4 SYS SYS SYS SYS X 32-63 1
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
That is, device pairs 0&1 and 2&3 are each connected via NVLINK, and 4 is isolated.
Also, GPU 4 was recently installed, but the problem existed even before that.
The problem does not reproduce on the other multi-GPU machines we have (although their hardware configurations are different), so it's OK if this issue is dismissed. As I said at the top, I'm just reporting it in case it is of any interest to the developers.
I'm curious about your issue. If you run it many times, is it always the same GPUs that have an issue?
If so, can you try this: CUDA_VISIBLE_DEVICES=2,3,4,0,1 python your_script.py
This changes the order of the GPUs. Does the error stay on the same GPUs?
Thanks for picking up this issue! Yes, with the example above it's always the third and fourth lines (GPUs 2 & 3) that have the issue.
On the other hand, changing the order of the devices does affect the symptom: with CUDA_VISIBLE_DEVICES=2,3,4,0,1 it was the fourth and fifth lines (GPUs 0 & 1) whose gradient became 1e-45.
This looks like a lead, so I tested all 120 device-order permutations (a rough sketch of the sweep is below the list). The result is:
- Whenever GPU 0 or 1 comes first, the gradients computed on 2 and 3 become corrupted.
- Whenever GPU 2 or 3 comes first, the gradients computed on 0 and 1 become corrupted.
- When GPU 4 comes first, all gradients are properly computed.
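For reference, the sweep looked roughly like this (a reconstruction, not my exact script; CUDA_VISIBLE_DEVICES has to be set before JAX initializes CUDA, so each permutation runs in a fresh subprocess):

import itertools
import os
import subprocess
import sys

# the plain repro loop from above, run inside each subprocess
CHILD = '''
import jax
fn = jax.value_and_grad(lambda x: x * 1)
for d in jax.devices():
    x = jax.device_put(1., device=d)
    v, g = fn(x)
    print(float(v), float(g))
'''

for perm in itertools.permutations(range(5)):
    order = ",".join(str(i) for i in perm)
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=order)
    out = subprocess.run([sys.executable, "-c", CHILD],
                         env=env, capture_output=True, text=True)
    print(order, out.stdout.split())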
Also, I realized I had forgotten to include the most basic information: I'm seeing this effect with
- JAX 0.3.14 + CUDA 11.4.0
- JAX 0.3.15 + CUDA 11.6.2.
JAX was installed with
pip3 install "jax[cuda]==<version>" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Thanks for the results. What computer is this? 5 GPUs isn't a frequent config. What GPUs are they? Are they all the same? If not (like on DGX Stations), does it work correctly if you enable only the GPUs that are the same?
@yiiyama maybe you could also try debugging with nccl-tests
@yiiyama Any update on this? If we do not have an update within 1 week, we will close this issue.
@nouiz I'm sorry for the long silence. I was trying nccl-tests as suggested by @sudhakarsingh27 and was trying to understand what I was seeing.
First, to answer your question: the machine is a custom-built server that originally came with 2 A100s. We bought additional A100s and installed them into the open PCIe slots. In fact, we recently added two more, so the machine now has 7 A100s. All are the same product (80GB), but I can't rule out slight internal differences, since they were purchased at different times.
@sudhakarsingh27 As for nccl-tests: I built the binaries as instructed in the README and ran all_reduce_perf. Commands
./all_reduce_perf -t 1
and
./all_reduce_perf -t 2
return some test results, but with -t 3 the test did not finish after at least 5 minutes. Do you know whether that's normal, or is there indeed something wrong with our configuration?
If you limit yourself to only the first 4 GPUs, does it work correctly? Also, what is the motherboard? Few motherboards can hold 7 GPUs.
No, with CUDA_VISIBLE_DEVICES=0,1,2,3 I get a gradient of 1e-45 for GPUs 2 and 3.
The motherboard is a SuperMicro X12DPG-OA6.
With the recent addition of two GPUs, the topology has changed from what I wrote above. I don't want to bombard you with too much information or drag this issue out for too long (my intention was just to check whether the JAX experts would have a quick idea about what I was seeing), but just in case it helps, the current output of nvidia-smi topo -m looks like this:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 CPU Affinity NUMA Affinity
GPU0 X NV12 PXB PXB PXB SYS SYS 0-31 0
GPU1 NV12 X PIX PXB PXB SYS SYS 0-31 0
GPU2 PXB PIX X PXB PXB SYS SYS 0-31 0
GPU3 PXB PXB PXB X NV12 SYS SYS 0-31 0
GPU4 PXB PXB PXB NV12 X SYS SYS 0-31 0
GPU5 SYS SYS SYS SYS SYS X PXB 32-63 1
GPU6 SYS SYS SYS SYS SYS PXB X 32-63 1
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
Seen in terms of PCIe connections, the output of lspci -t -v (with the GPU device ID added at the end of the lines) is
+-[0000:c9]-+-00.0 Intel Corporation Device 09a2
| +-00.1 Intel Corporation Device 09a4
| +-00.2 Intel Corporation Device 09a3
| +-00.4 Intel Corporation Device 0998
| \-02.0-[ca-da]----00.0-[cb-da]--+-00.0-[cc-ce]----00.0-[cd-ce]----10.0-[ce]----00.0 NVIDIA Corporation Device 20b5 ## <- This is GPU 5
| +-04.0-[cf-d2]----00.0-[d0-d2]--+-00.0-[d1]--
| | \-10.0-[d2]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 6
...
+-[0000:4a]-+-00.0 Intel Corporation Device 09a2
| +-00.1 Intel Corporation Device 09a4
| +-00.2 Intel Corporation Device 09a3
| +-00.4 Intel Corporation Device 0998
| \-02.0-[4b-5b]----00.0-[4c-5b]--+-00.0-[4d-4f]----00.0-[4e-4f]----10.0-[4f]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 0 (NVLINKed with GPU 1)
| +-04.0-[50-53]----00.0-[51-53]--+-00.0-[52]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 1 (NVLINKed with GPU 0)
| | \-10.0-[53]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 2
| +-08.0-[54-57]----00.0-[55-57]--+-00.0-[56]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 3 (NVLINKed with GPU 4)
| | \-10.0-[57]----00.0 NVIDIA Corporation Device 20b5 ## <- GPU 4 (NVLINKed with GPU 3)
But then again, I'm not entirely certain that device topology is the right place to look. It does seem to have some relevance though, because even with the new topology the symptom is:
- If an NVLINKed device (0, 1, 3, or 4) comes first in CUDA_VISIBLE_DEVICES, the gradient is correct on this device, on its NVLINK partner (1, 0, 4, or 3), and on the devices connected via SYS (5 and 6), and incorrect on the other devices (those connected via PIX or PXB).
- If GPU 2 (not NVLINKed to anything) comes first, the gradient is correct on 5 and 6, and incorrect on all other devices.
- If 5 or 6 comes first, the gradient is correct on all devices.
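If it's useful for narrowing this down, a hypothetical next check (a sketch I haven't run yet, not something from JAX itself) would be to see whether plain committed device-to-device copies show the same topology dependence, independently of value_and_grad:

import numpy as np
import jax

ref = np.arange(16, dtype=np.float32)
for src in jax.devices():
    x = jax.device_put(ref, device=src)      # committed to src
    for dst in jax.devices():
        y = jax.device_put(x, device=dst)    # explicit cross-device copy
        ok = np.array_equal(np.asarray(y), ref)
        print(src, "->", dst, "OK" if ok else "MISMATCH")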
According to this page, https://www.supermicro.com/en/support/resources/gpu?rsc=fltr_sku%3DSYS-420GP-TNR, the A100 GPU isn't officially supported by this server: https://www.supermicro.com/en/products/system/GPU/4U/SYS-420GP-TNR
Sorry, I do not have a magic answer. Did you test other frameworks than JAX?
Hmm, OK, so it may be a problem deep inside the GPU driver and the communication between the PCIe buses...
I myself use JAX exclusively, but others use PyTorch, and I haven't so far gotten complaints about unexpected behaviors. I will tell them to double-check their calculation results though.
I guess issue-wise we can't go any further then - thank you again for taking the time!
One last question to @sudhakarsingh27 - as I reported above, I couldn't really get any information out of nccl-tests. Was it expected that ./all_reduce_perf -t 3 takes forever (I killed the process because it ran for more than a couple of hours), or is that a sign that we do indeed have a hardware configuration issue?