
NCCL out of memory on dual 2080ti training

Open · rwightman opened this issue Nov 19, 2020

I'm working on a training script (https://github.com/rwightman/efficientnet-jax/blob/master/tf_linen_train.py) based on the Flax Linen ImageNet example (https://github.com/google/flax/blob/master/linen_examples/imagenet/imagenet_lib.py). It worked great on a system with 2 x Titan RTX. The same setup on 2 x 2080 Ti either hung completely (kill required) or crashed with a non-specific NCCL error (unhandled system error).
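For context, here is a minimal sketch (illustrative only, not the actual tf_linen_train.py code) of the pmap + pmean pattern the Flax ImageNet example uses for gradient sync; the `lax.pmean` is what lowers to the NCCL all-reduce that fails in the log further down. Names, shapes, and the loss are placeholders.

```python
import functools
import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(params, batch):
    # Placeholder quadratic "loss"; the real script evaluates an EfficientNet.
    return jnp.mean((batch @ params) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Cross-device gradient averaging -> NCCL all-reduce on GPU.
    grads = lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

n = jax.local_device_count()  # 2 on the dual-GPU systems discussed here
params = jax.device_put_replicated(jnp.zeros((8,)), jax.local_devices())
batch = jnp.ones((n, 16, 8))  # leading axis = device axis for pmap
params = train_step(params, batch)
```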

After some debugging with NCCL_DEBUG=INFO enabled, it appears that NCCL init is hitting an OOM condition. NCCL only logs the OOM as a warning, and it eventually propagates as a non-specific error. See the log dump below.
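For anyone hitting the same opaque error, this is roughly how the extra logging was obtained (a sketch; the variable can equally be exported in the shell or passed with `docker run -e`). NCCL reads it when the communicator is initialized, so it just has to be set before the first collective runs.

```python
import os

# Must be in the environment before the first NCCL collective executes.
os.environ.setdefault("NCCL_DEBUG", "INFO")

import jax  # imported after the env var is set
```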

I'm running this within a Docker container, NGC TensorFlow 2 20.10, which has NCCL 2.7.8, TF 2.3.1, and CUDA 11.1. JAX and Flax were built/installed from git master as of today. The cuda11 jaxlib pip wheel also had the issue.

Setting the env variable XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 fixed the problem on the 2080 Ti system. I did not do an exhaustive search for the threshold. Is the default 0.9?
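The workaround, spelled out (a sketch; exporting the variable in the shell or via `docker run -e` works the same). The XLA GPU client reads it when the backend initializes, so it must be set before the first JAX call allocates the GPU pool. 0.85 is just the value that happened to work here, not a tuned recommendation.

```python
import os

# Reserve 85% of GPU memory for XLA's preallocated pool instead of the default,
# leaving headroom for NCCL's own allocations during communicator setup.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.85"

import jax
print(jax.devices())  # backend init happens here, with the smaller pool
```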

I'm somewhat confused as to why NCCL reports version 2.7.3 in the log when the container I'm running in (and built jaxlib in) only has a 2.7.8 lib...

Log output

1de1c28a7548:559:653 [1] NCCL INFO Bootstrap : Using [0]eth0:172.17.0.2<0>
1de1c28a7548:559:653 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
1de1c28a7548:559:653 [1] NCCL INFO NET/IB : No device found.
1de1c28a7548:559:653 [1] NCCL INFO NET/Socket : Using [0]eth0:172.17.0.2<0>
1de1c28a7548:559:653 [1] NCCL INFO Using network Socket
NCCL version 2.7.3+cudaCUDA_MAJOR.CUDA_MINOR
1de1c28a7548:559:1475 [0] NCCL INFO Channel 00/04 :    0   1
1de1c28a7548:559:1476 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
1de1c28a7548:559:1475 [0] NCCL INFO Channel 01/04 :    0   1
1de1c28a7548:559:1476 [1] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] 0/-1/-1->1->-1|-1->1->0/-1/-1 [2] -1/-1/-1->1->0|0->1->-1/-1/-1 [3] 0/-1/-1->1->-1|-1->1->0/-1/-1
1de1c28a7548:559:1475 [0] NCCL INFO Channel 02/04 :    0   1
1de1c28a7548:559:1475 [0] NCCL INFO Channel 03/04 :    0   1
1de1c28a7548:559:1475 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
1de1c28a7548:559:1475 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] -1/-1/-1->0->1|1->0->-1/-1/-1 [2] 1/-1/-1->0->-1|-1->0->1/-1/-1 [3] -1/-1/-1->0->1|1->0->-1/-1/-1
1de1c28a7548:559:1476 [1] NCCL INFO Channel 00 : 1[2000] -> 0[1000] via P2P/direct pointer
1de1c28a7548:559:1475 [0] NCCL INFO Channel 00 : 0[1000] -> 1[2000] via P2P/direct pointer

1de1c28a7548:559:1475 [0] bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/alloc.h:41 NCCL WARN Cuda failure 'out of memory'
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/transport/p2p.cc:184 -> 1
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/transport.cc:30 -> 1
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/transport.cc:49 -> 1
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/init.cc:771 -> 1
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/init.cc:845 -> 1
1de1c28a7548:559:1475 [0] NCCL INFO external/nccl_archive/src/group.cc:73 -> 1 [Async thread]
1de1c28a7548:559:1476 [1] NCCL INFO Channel 01 : 1[2000] -> 0[1000] via P2P/direct pointer
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying
1de1c28a7548:559:1476 [1] NCCL INFO Call to connect returned Connection refused, retrying

1de1c28a7548:559:1476 [1] bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/socket.h:403 NCCL WARN Connect to 172.17.0.2<51505> failed : Connection refused
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/bootstrap.cc:95 -> 2
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/bootstrap.cc:363 -> 2
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/transport.cc:59 -> 2
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/init.cc:771 -> 2
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/init.cc:845 -> 2
1de1c28a7548:559:1476 [1] NCCL INFO external/nccl_archive/src/group.cc:73 -> 2 [Async thread]
2020-11-19 02:22:41.227591: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:265] Non-OK-status: status_ status: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc:311: NCCL operation ncclGroupEnd() failed: unhandled system error
Fatal Python error: Aborted

rwightman · Nov 19 '20

We bundle NCCL with jaxlib, so that's why you see 2.7.3.

I guess there are two things we should do here: a) we should fail more gracefully on NCCL OOM. b) we should perhaps lower the default memory fraction.

hawkinsp · Nov 21 '20

@hawkinsp The graceful failure / better reporting would definitely save some time arriving at a solution; without NCCL logging at INFO, the 'unhandled system error' wasn't much to go on.

Otherwise I guess it's a matter of tuning the default fraction. Maybe it should vary with total GPU memory; is there another variable that specifies a minimum amount to keep free instead of a fraction? The default fraction was fine on the larger-memory cards.
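For reference, the GPU memory knobs JAX documents besides the fraction are preallocation on/off and the allocator choice; neither expresses "leave N GiB free" directly. A sketch of the existing options, set before the backend initializes:

```python
import os

# Disable the upfront preallocation entirely; memory then grows on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: allocate/free per op via the platform allocator (slow, debugging only).
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
print(jax.devices())
```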

rwightman · Nov 21 '20

@hawkinsp An update on this one: I rebuilt my containers with yesterday's jax/jaxlib master (rest of the environment unchanged) after #5096 was resolved for me. The 2 x 2080 Ti setup worked out of the box without any mem-fraction fiddling.

rwightman · Dec 11 '20

@rwightman can we consider this resolved?

sudhakarsingh27 · Aug 12 '22

@sudhakarsingh27 yup

rwightman · Aug 12 '22