
Bogus gradient value from value_and_grad for committed DeviceArray on a multi-GPU host

Open yiiyama opened this issue 3 years ago • 6 comments

Hi,

I'm reporting this issue with the disclaimer that it's probably not reproducible (sorry!), but I hope it's of some use to the awesome JAX developers:

We have a machine with 5 NVIDIA A100 (80GB) cards, and we see that the gradient computed by value_and_grad is wrong when a DeviceArray committed to either of two specific devices is passed as the function argument. Here is the simplest example:

import jax

fn = jax.value_and_grad(lambda x: x * 1)

for i in range(jax.device_count()):
    x = jax.device_put(1., device=jax.devices()[i])
    v, g = fn(x)
    print(v, g)

Printout:

1.0 1.0
1.0 1.0
1.0 1e-45
1.0 1e-45
1.0 1.0

When the function is lambda x: x or lambda x: x * 1. (note the dot after 1, i.e. multiplying by a float), the results from all devices are correct (1.0 1.0).
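The value 1e-45 in the printout is not an arbitrary tiny number: it is how a float32 equal to 2**-149 (the smallest positive subnormal) prints, and that float32 has exactly the bit pattern of the int32 value 1. This fits the observation that only the integer multiplier misbehaves, but it is a numerical coincidence worth checking, not a diagnosis anyone in this thread confirmed. The snippet below uses only Python's struct module to show the correspondence; the helper names are illustrative.

```python
import struct


def reinterpret_int32_as_float32(i):
    """Read the 4 bytes of a little-endian int32 as an IEEE-754 float32."""
    return struct.unpack("<f", struct.pack("<i", i))[0]


def float32_bits(f):
    """Return the int32 bit pattern of `f` after rounding it to float32."""
    return struct.unpack("<i", struct.pack("<f", f))[0]


tiny = reinterpret_int32_as_float32(1)
print(tiny == 2.0 ** -149)           # True: int32 1 read as float32 is 2**-149
print(hex(float32_bits(1.0)))        # 0x3f800000, the bits of a correct 1.0
print(float32_bits(tiny))            # 1
```

If this hypothesis were right, float32_bits(float(g)) for a corrupted gradient g would return 1 rather than some unrelated small number.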

Also, if the call to fn is put in the default_device context, as in:

import jax

fn = jax.value_and_grad(lambda x: x * 1)

for i in range(jax.device_count()):
    x = jax.device_put(1., device=jax.devices()[i])
    with jax.default_device(jax.devices()[i]):
        v, g = fn(x)
    print(v, g)

the output gradient is correct for all devices. This makes me suspect that uncommitted arrays are involved somewhere in the gradient calculation even when the input is committed, and that this happens only in some specific cases.
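The two observations so far (an int multiplier breaks the gradient while a float one does not, and wrapping the call in jax.default_device hides the problem) can be checked in one pass. This is a diagnostic sketch, not a fix; sweep_devices and mismatches are hypothetical helper names, and the expected gradient is simply d(x * m)/dx = m.

```python
import jax


def sweep_devices(multipliers=(1, 1.0), use_default_device=False):
    """Run the reproducer from this issue on every local device.

    For each device, a scalar is committed to it with jax.device_put and
    differentiated through x * m for each multiplier m (int and float).
    Returns a list of (device_index, m, value, grad) tuples.
    """
    results = []
    for i, device in enumerate(jax.devices()):
        x = jax.device_put(1.0, device=device)
        for m in multipliers:
            # Bind m as a default argument so each lambda keeps its own value.
            fn = jax.value_and_grad(lambda x, m=m: x * m)
            if use_default_device:
                # Workaround from the issue: also make `device` the default.
                with jax.default_device(device):
                    v, g = fn(x)
            else:
                v, g = fn(x)
            results.append((i, m, float(v), float(g)))
    return results


def mismatches(results):
    """Entries whose gradient differs from d(x * m)/dx = m."""
    return [r for r in results if r[3] != r[1]]


if __name__ == "__main__":
    for flag in (False, True):
        print("default_device:", flag,
              mismatches(sweep_devices(use_default_device=flag)))
```

On a healthy host both lists are empty; on the machine described here one would expect entries with m == 1 for the affected devices only when use_default_device is False.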

The only information that may be even remotely relevant that I can think of is the device topology:

$ nvidia-smi topo -m
	GPU0	GPU1	GPU2	GPU3	GPU4	CPU Affinity	NUMA Affinity
GPU0	 X 	NV12	PXB	PXB	SYS	0-31	0
GPU1	NV12	 X 	PXB	PXB	SYS	0-31	0
GPU2	PXB	PXB	 X 	NV12	SYS	0-31	0
GPU3	PXB	PXB	NV12	 X 	SYS	0-31	0
GPU4	SYS	SYS	SYS	SYS	 X 	32-63	1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

That is, device pairs 0&1 and 2&3 are each connected via NVLINK, and 4 is isolated.
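The matrix above can also be read programmatically, which helps when lining up the corrupted devices against the link types. The sketch below is a hypothetical helper (the function names are made up for illustration) that assumes the plain-text layout shown above: a header row of GPU names, one row per GPU, then the legend. Stripping ANSI escape codes is a precaution in case the captured output contains terminal formatting.

```python
import re


def parse_topo(text):
    """Parse the GPU-to-GPU matrix printed by `nvidia-smi topo -m`.

    Returns (gpus, links) where links[(a, b)] is the connection type
    between GPUs a and b, e.g. "NV12", "PXB", "PIX" or "SYS".
    """
    text = re.sub(r"\x1b\[[0-9;]*m", "", text)  # drop terminal formatting
    lines = [line for line in text.splitlines() if line.strip()]
    gpus = [col for col in lines[0].split() if col.startswith("GPU")]
    links = {}
    for line in lines[1:]:
        cols = line.split()
        if not cols[0].startswith("GPU"):
            break  # reached the legend or another section
        for other, kind in zip(gpus, cols[1:1 + len(gpus)]):
            links[(cols[0], other)] = kind
    return gpus, links


def nvlink_pairs(links):
    """Unordered GPU pairs joined by NVLink (entries starting with 'NV')."""
    return sorted({tuple(sorted(pair)) for pair, kind in links.items()
                   if kind.startswith("NV")})


def sys_only(gpus, links):
    """GPUs whose every link to another GPU crosses the NUMA interconnect."""
    return [g for g in gpus
            if all(links[(g, o)] == "SYS" for o in gpus if o != g)]


SAMPLE_TOPO = """
        GPU0    GPU1    GPU2    GPU3    GPU4    CPU Affinity    NUMA Affinity
GPU0     X      NV12    PXB     PXB     SYS     0-31    0
GPU1    NV12     X      PXB     PXB     SYS     0-31    0
GPU2    PXB     PXB      X      NV12    SYS     0-31    0
GPU3    PXB     PXB     NV12     X      SYS     0-31    0
GPU4    SYS     SYS     SYS     SYS      X      32-63   1

Legend:

  X    = Self
"""

gpus, links = parse_topo(SAMPLE_TOPO)
print(nvlink_pairs(links))   # [('GPU0', 'GPU1'), ('GPU2', 'GPU3')]
print(sys_only(gpus, links))  # ['GPU4']
```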

Also, GPU 4 was recently installed, but the problem existed even before that.

The problem does not reproduce on the other multi-GPU machines that we have (although their hardware configurations are different), so it's OK if this issue is dismissed. As I said at the top, I'm just reporting in case this is of any interest to the developers.

yiiyama, Jul 13 '22 08:07

I'm curious about your issue. If you run it many times, is it always the same GPUs that have an issue?

If so, can you try this: CUDA_VISIBLE_DEVICES=2,3,4,0,1 python your_script.py

This changes the order of the GPUs. Does the error stay on the same GPUs?

nouiz, Jul 29 '22 21:07

Thanks for picking up this issue! Yes, with the example above it's always the third and fourth lines (GPUs 2 & 3) that have the issue.

On the other hand, changing the order of the devices does affect the symptom. With CUDA_VISIBLE_DEVICES=2,3,4,0,1 it was the fourth and fifth lines (GPUs 0 & 1) whose gradient became 1e-45.

This looks like a lead, so I tested all 120 permutations of the device order. The result is:

  • Whenever GPU 0 or 1 comes first, the gradients computed on 2 and 3 become corrupted.
  • Whenever GPU 2 or 3 comes first, the gradients computed on 0 and 1 become corrupted.
  • When GPU 4 comes first, all gradients are properly computed.
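A sweep over all device orders like the one described above can be scripted as follows. This is a sketch under assumptions: REPRO_SCRIPT is a hypothetical file holding the reproducer from the top of the issue, and each order runs in a fresh process because CUDA reads CUDA_VISIBLE_DEVICES only when it initializes.

```python
import itertools
import os
import subprocess
import sys

# Hypothetical file name for the reproducer at the top of this issue.
REPRO_SCRIPT = "repro.py"


def device_orders(n_gpus):
    """Every ordering of physical GPU indices, as CUDA_VISIBLE_DEVICES strings."""
    return [",".join(map(str, p)) for p in itertools.permutations(range(n_gpus))]


def physical_ids(order):
    """Map output line i of the reproducer back to a physical GPU index."""
    return [int(s) for s in order.split(",")]


def run_with_order(order, script=REPRO_SCRIPT, timeout=300):
    """Run `script` in a fresh process that sees only the GPUs in `order`.

    A new process is needed because CUDA reads CUDA_VISIBLE_DEVICES when
    it initializes, so it cannot be changed inside a running process.
    """
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=order)
    proc = subprocess.run([sys.executable, script], env=env,
                          capture_output=True, text=True, timeout=timeout)
    return proc.stdout.splitlines()
```

For order "2,3,4,0,1", printed lines 4 and 5 correspond to physical_ids(order)[3] and physical_ids(order)[4], i.e. GPUs 0 and 1, which is how the reordered result above was read.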

yiiyama, Aug 01 '22 05:08

Also, I realized I had forgotten to give the most basic information: I'm seeing this effect with

  • JAX 0.3.14 + CUDA 11.4.0
  • JAX 0.3.15 + CUDA 11.6.2.

JAX was installed with

pip3 install "jax[cuda]==<version>" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

yiiyama, Aug 01 '22 05:08

Thanks for the results. What computer is this? 5 GPUs isn't a common configuration. Which GPUs are they? Are they all the same? If not (as on DGX Stations), does it work correctly if you enable only the GPUs that are identical?

nouiz, Aug 01 '22 15:08

@yiiyama maybe you could also try debugging with nccl-tests

sudhakarsingh27, Aug 01 '22 20:08

@yiiyama Any update on the questions above (what computer this is, and whether all the GPUs are identical)? If we do not have an update within 1 week, we will close this issue.

nouiz, Aug 11 '22 18:08

@nouiz I'm sorry for the long silence. I was running nccl-tests as suggested by @sudhakarsingh27 and trying to understand what I was seeing.

First, to answer your question: the machine is a custom-built server that originally came with 2 A100s. We bought additional A100s and installed them in the open PCIe slots. In fact, we recently added two more, so the machine now has 7 A100s. All are the same product (80GB), but I can't rule out slight internal differences, because they were purchased at different times.

@sudhakarsingh27 As for nccl-tests: I built the binaries as instructed in the README and ran all_reduce_perf. The commands

./all_reduce_perf -t 1

and

./all_reduce_perf -t 2

return some test results, but with -t 3 the test does not finish even after 5 minutes. Do you know if that's normal, or is there indeed something wrong with our configuration?
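In nccl-tests, -t sets the number of threads and -g the number of GPUs per thread (default 1), so -t 3 on its own uses three GPUs. Assuming the GPUs are taken in order starting from 0, -t 2 exercises only GPUs 0 and 1, an NVLink pair in both topologies posted in this thread, while -t 3 adds GPU 2, which reaches them only over PCIe. One way to narrow this down is to fix the set of visible GPUs and toggle NCCL's peer-to-peer transport. The driver below is a hypothetical sketch: the binary path is an assumption, and only standard nccl-tests flags and the NCCL_DEBUG / NCCL_P2P_DISABLE environment variables are used.

```python
import os
import subprocess

# Assumed location of the binary built per the nccl-tests README.
ALL_REDUCE_PERF = "./all_reduce_perf"


def nccl_cmd(threads=1, gpus_per_thread=1, min_bytes="8", max_bytes="128M"):
    """Build one all_reduce_perf command line.

    -t is the number of threads and -g the number of GPUs per thread,
    so the run uses threads * gpus_per_thread GPUs.
    """
    return [ALL_REDUCE_PERF, "-b", min_bytes, "-e", max_bytes, "-f", "2",
            "-t", str(threads), "-g", str(gpus_per_thread)]


def nccl_env(visible_devices, disable_p2p=False):
    """Environment restricting the run to `visible_devices` (e.g. "0,2")."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=visible_devices,
               NCCL_DEBUG="INFO")
    if disable_p2p:
        # Disable NCCL's GPU peer-to-peer transport for this run.
        env["NCCL_P2P_DISABLE"] = "1"
    return env


def run_nccl(cmd, env, timeout=300):
    """Return the exit code, or None if the run exceeds `timeout` seconds."""
    try:
        proc = subprocess.run(cmd, env=env, capture_output=True, text=True,
                              timeout=timeout)
    except subprocess.TimeoutExpired:
        return None  # treated as a hang
    return proc.returncode
```

For example, call run_nccl(nccl_cmd(gpus_per_thread=2), nccl_env(pair)) for an NVLink pair such as "0,1" and a PCIe-only pair such as "0,2", each with and without disable_p2p=True. If only the PCIe-only pair hangs, and only with P2P enabled, that would point at PCIe peer-to-peer transfers rather than at JAX.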

yiiyama, Aug 18 '22 15:08

If you limit yourself to only the first 4 GPUs, does it work correctly? Also, what is the motherboard? Few motherboards can hold 7 GPUs.

nouiz, Aug 19 '22 20:08

No, with CUDA_VISIBLE_DEVICES=0,1,2,3 I get a gradient 1e-45 for GPUs 2 and 3.

The motherboard is a SuperMicro X12DPG-OA6.

With the recent addition of two GPUs, the topology has changed from what I wrote above. I don't want to bombard you with too much information or drag this issue out for too long (my intention was to check whether JAX experts would have some quick idea about what I was seeing), but just in case it helps, the current output of nvidia-smi topo -m is:

	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	CPU Affinity	NUMA Affinity
GPU0	 X 	NV12	PXB	PXB	PXB	SYS	SYS	0-31	0
GPU1	NV12	 X 	PIX	PXB	PXB	SYS	SYS	0-31	0
GPU2	PXB	PIX	 X 	PXB	PXB	SYS	SYS	0-31	0
GPU3	PXB	PXB	PXB	 X 	NV12	SYS	SYS	0-31	0
GPU4	PXB	PXB	PXB	NV12	 X 	SYS	SYS	0-31	0
GPU5	SYS	SYS	SYS	SYS	SYS	 X 	PXB	32-63	1
GPU6	SYS	SYS	SYS	SYS	SYS	PXB	 X 	32-63	1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

In terms of PCIe connections, the output of lspci -t -v (with the GPU device ID added at the end of the relevant lines) is:

 +-[0000:c9]-+-00.0  Intel Corporation Device 09a2
 |           +-00.1  Intel Corporation Device 09a4
 |           +-00.2  Intel Corporation Device 09a3
 |           +-00.4  Intel Corporation Device 0998
 |           \-02.0-[ca-da]----00.0-[cb-da]--+-00.0-[cc-ce]----00.0-[cd-ce]----10.0-[ce]----00.0  NVIDIA Corporation Device 20b5  ## <- This is GPU 5
 |                                           +-04.0-[cf-d2]----00.0-[d0-d2]--+-00.0-[d1]--
 |                                           |                               \-10.0-[d2]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 6
...
 +-[0000:4a]-+-00.0  Intel Corporation Device 09a2
 |           +-00.1  Intel Corporation Device 09a4
 |           +-00.2  Intel Corporation Device 09a3
 |           +-00.4  Intel Corporation Device 0998
 |           \-02.0-[4b-5b]----00.0-[4c-5b]--+-00.0-[4d-4f]----00.0-[4e-4f]----10.0-[4f]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 0 (NVLINKed with GPU 1)
 |                                           +-04.0-[50-53]----00.0-[51-53]--+-00.0-[52]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 1 (NVLINKed with GPU 0)
 |                                           |                               \-10.0-[53]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 2
 |                                           +-08.0-[54-57]----00.0-[55-57]--+-00.0-[56]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 3 (NVLINKed with GPU 4)
 |                                           |                               \-10.0-[57]----00.0  NVIDIA Corporation Device 20b5  ## <- GPU 4 (NVLINKed with GPU 3)

But then again, I'm not entirely certain that device topology is the right place to look. It does seem to have some relevance, though, because even with the new topology the symptom is:

  • If an NVLINKed device (0, 1, 3, or 4) comes first in CUDA_VISIBLE_DEVICES, the gradient is correct in this device, the NVLINK partner (1, 0, 4, or 3), and devices that are connected via SYS (5 and 6), and incorrect in the other devices (connected via PIX or PXB).
  • If GPU 2 (not NVLINKed to any other GPU) comes first, the gradient is correct in 5 and 6, and incorrect in all other devices.
  • If 5 or 6 comes first, the gradient is correct on all devices.

yiiyama, Aug 22 '22 02:08

From this page: https://www.supermicro.com/en/support/resources/gpu?rsc=fltr_sku%3DSYS-420GP-TNR The A100 GPU isn't officially supported by this server. https://www.supermicro.com/en/products/system/GPU/4U/SYS-420GP-TNR

Sorry, I do not have a magic answer. Did you test other frameworks than JAX?

nouiz, Aug 23 '22 20:08

Hmm, OK, so it may be a problem deep inside the GPU driver and the communication between the PCIe buses.

I myself use JAX exclusively, but others use PyTorch, and I haven't so far gotten complaints about unexpected behaviors. I will tell them to double-check their calculation results though.

I guess issue-wise we can't go any further, then. Thank you again for taking the time!

One last question to @sudhakarsingh27 - as I reported above, I couldn't really get any information out of nccl-tests. Was it expected that ./all_reduce_perf -t 3 takes forever (I killed the process because it ran for more than a couple of hours), or is that a sign that we do indeed have a hardware configuration issue?
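Because the symptoms track the NVLink/PCIe topology, one cheap cross-check of results on such a machine is whether plain device-to-device copies arrive intact. The sketch below (check_transfers is a hypothetical name) copies a known array between every ordered pair of local devices with jax.device_put and compares it with the host reference. How JAX/XLA actually routes each copy (peer-to-peer or via host memory) is not specified here, so a clean result does not rule out the problem described in this issue.

```python
import itertools

import jax
import numpy as np


def check_transfers(n=1 << 20):
    """Copy a known array between every ordered pair of local devices.

    Returns the (src, dst) device-index pairs whose copy did not match
    the host reference exactly.
    """
    devices = jax.devices()
    ref = np.arange(n, dtype=np.float32)  # integers are exact in float32 below 2**24
    bad = []
    for src, dst in itertools.permutations(range(len(devices)), 2):
        a = jax.device_put(ref, devices[src])
        b = jax.device_put(a, devices[dst])  # device-to-device transfer
        if not np.array_equal(np.asarray(b), ref):
            bad.append((src, dst))
    return bad


if __name__ == "__main__":
    print("mismatched transfers:", check_transfers())
```

Without GPUs, setting XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax splits the host into several CPU devices, which is enough to exercise the code path (though not the GPU interconnect).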

yiiyama, Aug 24 '22 02:08