
SLURM job works on one node with 4 GPUs but hits a GPU OOM error in a multi-node, multi-GPU setting

Open • JunMa11 opened this issue 2 years ago • 1 comment

System Info

- `Accelerate` version: 0.20.0.dev0
- Platform: Linux-4.15.0-209-generic-x86_64-with-glibc2.27
- Python version: 3.10.11
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.0.1+cu117 (False)
- PyTorch XPU available: False
- System RAM: 47.16 GB
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: fp16
        - use_cpu: False
        - num_processes: 4
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: all
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • [X] My own task or dataset (give details below)

Reproduction

  1. The following SLURM script works on a single node with 4 GPUs (batch size 8):
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=24
#SBATCH --job-name=5nodes
#SBATCH --mem=200GB
#SBATCH --gres=gpu:4
#SBATCH --partition=a100
#SBATCH --output=logs/%x-%j.out
#SBATCH --error=logs/%x-%j.err
#SBATCH --time=5-00:00:00
#SBATCH --exclude=gpu182,gpu183,gpu184,gpu185,gpu186

set -x -e

# log the sbatch environment
echo "start time: $(date)"
echo "SLURM_JOBID="$SLURM_JOBID
echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST
echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION
echo "SLURM_NNODES"=$SLURM_NNODES
echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE
echo "SLURM_SUBMIT_DIR"=$SLURM_SUBMIT_DIR

# Training setup
GPUS_PER_NODE=4
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=5560
NNODES=$SLURM_NNODES
NODE_RANK=$SLURM_PROCID 
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

echo "MASTER_ADDR"=$MASTER_ADDR
echo "NNODES"=$NNODES
echo "NODE_RANK"=$NODE_RANK

export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=1


CMD=" \
    v2-1_mgpus_train_enc_dec.py \
    -task_name SAM-ViT-B-1node \
    "

LAUNCHER="accelerate launch \
    --multi_gpu \
    --mixed_precision=fp16 \
    --num_machines $NNODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip "$MASTER_ADDR" \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --role $SLURMD_NODENAME: \
    --rdzv_conf rdzv_backend=c10d \
    --max_restarts 1 \
    --tee 3 \
"

SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" # 2>&1

echo "END TIME: $(date)"
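The "Training setup" block derives the rendezvous settings from SLURM variables. Below is a Python sketch of the same arithmetic, which can be handy for checking what each launcher will receive. `expand_nodelist` and `launch_settings` are hypothetical helpers: the first is a simplified stand-in for `scontrol show hostnames` (one bracket group only), and the one-task-per-node check is an assumption drawn from the script itself, since `--machine_rank \$SLURM_PROCID` only lines up with `--num_machines $NNODES` when srun starts exactly one launcher per node.

```python
import re


def expand_nodelist(nodelist: str) -> list[str]:
    """Expand a simple SLURM nodelist such as 'gpu[187-188]' into hostnames.

    Hypothetical, simplified stand-in for `scontrol show hostnames`: it only
    handles a plain hostname or a single prefix with one bracket group.
    """
    m = re.fullmatch(r"([^\[\],]+)\[([\d,\-]+)\]", nodelist)
    if m is None:
        if "[" in nodelist or "," in nodelist:
            raise ValueError(f"unsupported nodelist format: {nodelist!r}")
        return [nodelist]  # a single plain hostname, e.g. 'gpu187'
    prefix, ranges = m.groups()
    hosts = []
    for part in ranges.split(","):
        if "-" in part:
            lo, hi = part.split("-")
            width = len(lo)  # keep zero padding such as 'n007'
            hosts += [f"{prefix}{i:0{width}d}" for i in range(int(lo), int(hi) + 1)]
        else:
            hosts.append(prefix + part)
    return hosts


def launch_settings(env: dict, gpus_per_node: int = 4) -> dict:
    """Mirror the script's 'Training setup' block for a given SLURM environment."""
    nnodes = int(env["SLURM_NNODES"])
    ntasks = int(env.get("SLURM_NTASKS", nnodes))
    # `--machine_rank \$SLURM_PROCID` yields ranks 0..ntasks-1, which only match
    # `--num_machines $NNODES` when srun starts exactly one launcher per node.
    if ntasks != nnodes:
        raise ValueError(f"{ntasks} srun tasks on {nnodes} nodes; expected one per node")
    return {
        "main_process_ip": expand_nodelist(env["SLURM_JOB_NODELIST"])[0],  # like `head -n 1`
        "num_machines": nnodes,
        "num_processes": gpus_per_node * nnodes,  # WORLD_SIZE
    }


print(launch_settings({"SLURM_NNODES": "2", "SLURM_NTASKS": "2",
                       "SLURM_JOB_NODELIST": "gpu[187-188]"}))
```

For the two-node allocation in the log below, this gives `gpu187`, 2 machines and 8 processes, matching the `set -x` trace.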

After switching to a multi-node setting:

#SBATCH --nodes=4
#SBATCH --ntasks=4
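In `LAUNCHER`, `\$SLURM_PROCID` is escaped, so it is passed to `bash -c` as a literal `$SLURM_PROCID` and expanded separately inside each srun task. `$SLURMD_NODENAME` is not escaped, so it is expanded once, on the node that runs the batch script. The `set -x` trace of the multi-node run shows the result (`--role gpu187:`), which is likely why every log line below is prefixed `gpu187`, including output that may come from the second node. A minimal sketch of the effect, driving bash from Python; the inner `bash -c` only imitates what srun does on another node, and the values `gpu188` and `1` are illustrative:

```python
import subprocess

# Bash snippet mimicking the batch script: LAUNCHER is built on the batch host,
# then a "remote" task with its own environment runs it via `bash -c`.
script = r'''
SLURMD_NODENAME=gpu187
LAUNCHER="--role $SLURMD_NODENAME: --machine_rank \$SLURM_PROCID"
SLURMD_NODENAME=gpu188 SLURM_PROCID=1 bash -c "echo $LAUNCHER"
'''

out = subprocess.run(["bash", "-c", script], capture_output=True,
                     text=True, check=True).stdout
print(out.strip())  # the role keeps the batch host's name; the rank is per task
```

Writing the role as `--role \$SLURMD_NODENAME:` would defer its expansion in the same way as the machine rank.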

I got the following OOM error.

++ date
+ echo 'start time: Sat Jun  3 21:57:21 EDT 2023'
+ echo SLURM_JOBID=9767773
+ echo 'SLURM_JOB_NODELIST=gpu[187-188]'
+ echo SLURM_JOB_PARTITION=a100
+ echo SLURM_NNODES=2
+ echo SLURM_GPUS_ON_NODE=4
+ echo SLURM_SUBMIT_DIR=/ssd003/home/junma/MedSAM
+ GPUS_PER_NODE=4
++ scontrol show hostnames 'gpu[187-188]'
++ head -n 1
+ MASTER_ADDR=gpu187
+ MASTER_PORT=5565
+ NNODES=2
+ NODE_RANK=0
+ WORLD_SIZE=8
+ echo MASTER_ADDR=gpu187
+ echo NNODES=2
+ echo NODE_RANK=0
+ export NCCL_DEBUG=INFO
+ NCCL_DEBUG=INFO
+ export NCCL_IB_DISABLE=1
+ NCCL_IB_DISABLE=1
+ CMD='     v2-1_mgpus_train_enc_dec.py     -task_name MedSAM-ViT-B-2nodes     '
+ LAUNCHER='accelerate launch     --multi_gpu     --mixed_precision=fp16     --num_machines 2     --num_processes 8     --main_process_ip gpu187     --main_process_port 5565     --machine_rank $SLURM_PROCID     --role gpu187:     --rdzv_conf rdzv_backend=c10d     --max_restarts 1     --tee 3 '
+ SRUN_ARGS='     --wait=60     --kill-on-bad-exit=1     '
+ srun --wait=60 --kill-on-bad-exit=1 --jobid 9767773 bash -c 'accelerate launch     --multi_gpu     --mixed_precision=fp16     --num_machines 2     --num_processes 8     --main_process_ip gpu187     --main_process_port 5565     --machine_rank $SLURM_PROCID     --role gpu187:     --rdzv_conf rdzv_backend=c10d     --max_restarts 1     --tee 3       v2-1_mgpus_train_enc_dec.py     -task_name MedSAM-ViT-B-2nodes     '
[gpu187:0]:
[gpu187:1]:
[gpu187:1]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:3]:
[gpu187:3]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:2]:
[gpu187:0]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:2]:
[gpu187:2]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:0]:
[gpu187:2]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:1]:
[gpu187:1]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:3]:
[gpu187:3]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:2]:
[gpu187:2]:  0%|          | 0/125 [04:35<?, ?it/s]
[gpu187:2]:Traceback (most recent call last):
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:2]:    main()
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:2]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:2]:    return model_forward(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:2]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:2]:    return func(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:2]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:2]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:2]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:2]:    x = blk(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:2]:    x = self.attn(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:2]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:2]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:2]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 176.00 MiB (GPU 2; 79.35 GiB total capacity; 40.99 GiB already allocated; 155.69 MiB free; 41.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:1]:
[gpu187:1]:  0%|          | 0/125 [04:35<?, ?it/s]
[gpu187:1]:Traceback (most recent call last):
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:1]:    main()
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:1]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:1]:    return model_forward(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:1]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:1]:    return func(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:1]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:1]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:1]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:1]:    x = blk(x)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:1]:    x = self.attn(x)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:1]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:1]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:1]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 1; 79.35 GiB total capacity; 33.66 GiB already allocated; 735.69 MiB free; 34.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:0]:
[gpu187:0]:  0%|          | 0/125 [04:36<?, ?it/s]
[gpu187:0]:Traceback (most recent call last):
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:0]:    main()
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 263, in main
[gpu187:0]:    accelerator.backward(loss)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1815, in backward
[gpu187:0]:    self.scaler.scale(loss).backward(**kwargs)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
[gpu187:0]:    torch.autograd.backward(
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
[gpu187:0]:    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[gpu187:0]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB (GPU 0; 79.35 GiB total capacity; 58.53 GiB already allocated; 5.69 MiB free; 61.87 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:0]:  0%|          | 0/125 [00:00<?, ?it/s][gpu187:2]:
[gpu187:2]:  0%|          | 0/125 [04:36<?, ?it/s]
[gpu187:2]:Traceback (most recent call last):
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:2]:    main()
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:2]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:2]:    return model_forward(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:2]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:2]:    return func(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:2]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:2]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:2]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:2]:    x = blk(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:2]:    x = self.attn(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 236, in forward
[gpu187:2]:    attn = attn.softmax(dim=-1)
[gpu187:2]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB (GPU 2; 79.35 GiB total capacity; 33.56 GiB already allocated; 3.72 GiB free; 34.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:3]:
[gpu187:3]:  0%|          | 0/125 [04:37<?, ?it/s]
[gpu187:3]:Traceback (most recent call last):
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:3]:    main()
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:3]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:3]:    return model_forward(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:3]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:3]:    return func(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:3]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:0]:
[gpu187:0]:  0%|          | 0/125 [04:37<?, ?it/s]
[gpu187:0]:Traceback (most recent call last):
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:0]:    main()
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:0]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:0]:    return forward_call(*args, **kwargs)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:0]:    return model_forward(*args, **kwargs)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:0]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:0]:    return func(*args, **kwargs)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:0]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:0]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:0]:    return forward_call(*args, **kwargs)
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:0]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:0]:    return forward_call(*args, **kwargs)
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:0]:    x = blk(x)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:0]:    return forward_call(*args, **kwargs)
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:0]:    x = self.attn(x)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:0]:    return forward_call(*args, **kwargs)
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 229, in forward
[gpu187:0]:    q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
[gpu187:0]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 174.00 MiB (GPU 0; 79.35 GiB total capacity; 14.30 GiB already allocated; 5.69 MiB free; 14.56 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:3]:
[gpu187:3]:  0%|          | 0/125 [04:37<?, ?it/s]
[gpu187:3]:Traceback (most recent call last):
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:3]:    main()
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:3]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:3]:    return model_forward(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:3]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:3]:    return func(*args, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:3]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:3]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:3]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:3]:    x = blk(x)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:3]:    x = self.attn(x)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:3]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:3]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:3]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 3; 79.35 GiB total capacity; 36.66 GiB already allocated; 735.69 MiB free; 37.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:3]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:3]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:3]:    x = blk(x)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:3]:    x = self.attn(x)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:3]:    return forward_call(*args, **kwargs)
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:3]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:3]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:3]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 3; 79.35 GiB total capacity; 36.66 GiB already allocated; 735.69 MiB free; 37.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:1]:
[gpu187:1]:  0%|          | 0/125 [04:37<?, ?it/s]
[gpu187:1]:Traceback (most recent call last):
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:1]:    main()
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:1]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:1]:    return model_forward(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:1]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:1]:    return func(*args, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:1]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:1]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:1]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:1]:    x = blk(x)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:1]:    x = self.attn(x)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:1]:    return forward_call(*args, **kwargs)
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:1]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:1]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:1]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 176.00 MiB (GPU 1; 79.35 GiB total capacity; 40.99 GiB already allocated; 153.69 MiB free; 41.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 23237) of binary: /h/junma/anaconda3/envs/medsam/bin/python
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 23236) of binary: /h/junma/anaconda3/envs/medsam/bin/python

Note: reducing the batch size from 8 to 2 still results in an OOM error; only batch size 1 works.
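The allocation sizes in the tracebacks can be checked with a little arithmetic. Assuming the SAM ViT-B image-encoder defaults (64 x 64 = 4096 tokens, 12 heads, 14 x 14 windows), one global-attention score map of shape (B * heads, 4096, 4096) takes 384 MiB per sample in fp16. With B = 8 that is exactly the 3.00 GiB requested in `add_decomposed_rel_pos`, and the 6.00 GiB request at `attn.softmax` matches the same tensor in fp32 (autocast runs softmax in float32). The windowed blocks come to about 176 MiB for B = 8, close to the 176.00 MiB requests. So the failing processes appear to process 8 samples per forward pass on each GPU, which matches the single-node batch size of 8 if that value is per process. A sketch of the arithmetic, with the model settings as stated assumptions:

```python
MIB = 2**20


def attn_map_mib(batch: int, num_heads: int, tokens: int, bytes_per_elem: int,
                 num_windows: int = 1) -> float:
    """Size in MiB of an attention-score tensor of shape
    (batch * num_windows * num_heads, tokens, tokens)."""
    return batch * num_windows * num_heads * tokens * tokens * bytes_per_elem / MIB


# Assumed SAM ViT-B image-encoder settings: 1024x1024 input with 16x16 patches
# -> 64x64 = 4096 tokens, 12 heads; windowed blocks pad 64 -> 70 and use
# 14x14 windows -> 25 windows of 196 tokens each.
HEADS, GLOBAL_TOKENS, WINDOW_TOKENS, WINDOWS = 12, 64 * 64, 14 * 14, 25

print(attn_map_mib(8, HEADS, GLOBAL_TOKENS, 2))  # 3072.0 MiB = 3.00 GiB (fp16)
print(attn_map_mib(8, HEADS, GLOBAL_TOKENS, 4))  # 6144.0 MiB = 6.00 GiB (fp32)
print(attn_map_mib(8, HEADS, WINDOW_TOKENS, 2, WINDOWS))  # ~175.85 MiB (fp16)

for batch in (1, 2, 8):
    # the global attention map grows linearly with the per-GPU batch size
    print(batch, attn_map_mib(batch, HEADS, GLOBAL_TOKENS, 2))
```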

Here is the Python script:

#%% setup environment
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import monai
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import argparse
import random
from datetime import datetime
import shutil
import glob
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# set seeds
torch.manual_seed(2023)
# torch.cuda.empty_cache()
torch.distributed.init_process_group(backend="gloo") # default `nccl` doesn't work



#%% create a dataset class to load .npy images and ground-truth masks
class NpyDataset(Dataset): 
    def __init__(self, data_root, image_size=1024, bbox_shift=5, data_aug=False):
        self.data_root = data_root
        self.gt_path = join(data_root, 'gts')
        self.img_path = join(data_root, 'tr_imgs')
        self.gt_path_files = sorted(glob.glob(join(self.gt_path, '**/*.npy'), recursive=True))
        self.gt_path_files = [file for file in self.gt_path_files if os.path.isfile(join(self.img_path, os.path.basename(file)))]
        self.image_size = image_size
        self.bbox_shift = bbox_shift
        self.sam_transform = ResizeLongestSide(image_size)
        self.data_aug = data_aug
        print(f'number of images: {len(self.gt_path_files)}')
    
    def __len__(self):
        return len(self.gt_path_files)

    def __getitem__(self, index):
        img_name = os.path.basename(self.gt_path_files[index])
        img_3c = np.load(join(self.img_path, img_name), 'r', allow_pickle=True) # (H, W, 3)
        resize_img_skimg = transform.resize(img_3c, (self.image_size, self.image_size), order=3, preserve_range=True, mode='constant', anti_aliasing=True)
        resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip(resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3)
        # convert the shape to (3, H, W)
        img_1024 = np.transpose(resize_img_skimg_01, (2, 0, 1))
        assert np.max(img_1024)<=1.0 and np.min(img_1024)>=0.0, 'image should be normalized to [0, 1]'
        gt = np.load(self.gt_path_files[index], 'r', allow_pickle=True) # multiple labels [0, 1,4,5...], (256,256)
        assert img_name == os.path.basename(self.gt_path_files[index]), 'img gt name error: ' + self.gt_path_files[index]
        label_ids = np.unique(gt)[1:]
        gt2D = np.uint8(gt == random.choice(label_ids.tolist())) # only one label, (256, 256)
        # add data augmentation: random fliplr and random flipud
        if self.data_aug:
            if random.random() > 0.5:
                img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-1))
                gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1))
            if random.random() > 0.5:
                img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-2))
                gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2))

        assert np.max(gt2D)==1 and np.min(gt2D)==0.0, 'ground truth should be 0, 1'
        y_indices, x_indices = np.where(gt2D > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = gt2D.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])
        return torch.tensor(img_1024).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float(), img_name

# %% set up parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--tr_npy_path', type=str,
                    default='/scratch/ssd004/datasets/med-img-data/npy_MRCTv5',
                    help='path to training npy files; two subfolders: gts and tr_imgs')
parser.add_argument('-task_name', type=str, default='MedSAM-ViT-B-MRCT-4GPUs')
parser.add_argument('-model_type', type=str, default='vit_b')
parser.add_argument('-checkpoint', type=str, default='work_dir/SAM/sam_vit_b_01ec64.pth')
# parser.add_argument('-device', type=str, default='cuda:0')
parser.add_argument('-work_dir', type=str, default='./work_dir')
parser.add_argument('-data_aug', action='store_true',
                    help='use data augmentation during training')
# train
parser.add_argument('-num_epochs', type=int, default=1000)
parser.add_argument('-batch_size', type=int, default=8)
parser.add_argument('-num_workers', type=int, default=24)
# Optimizer parameters
parser.add_argument('-weight_decay', type=float, default=0.01,
                    help='weight decay (default: 0.01)')
parser.add_argument('-lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (absolute lr)')
parser.add_argument('-use_wandb', action='store_true',
                    help='use wandb to monitor training')
args = parser.parse_args()

if args.use_wandb:
    import wandb
    wandb.login()
    wandb.init(project=args.task_name, config={"lr": args.lr, "batch_size": args.batch_size,
                                            "data_path": args.tr_npy_path,
                                            "model_type": args.model_type,
                                            })

# %% set up model for fine-tuning
# device = args.device
run_id = datetime.now().strftime("%Y%m%d-%H%M")
model_save_path = join(args.work_dir, args.task_name + '-' + run_id)
os.makedirs(model_save_path, exist_ok=True)
shutil.copyfile(__file__, join(model_save_path, run_id + '_' + os.path.basename(__file__)))

#%% define model

class MedSAM(nn.Module):
    def __init__(self, 
                image_encoder, 
                mask_decoder,
                prompt_encoder,
                input_img_size=1024,
                ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.sam_trans = ResizeLongestSide(input_img_size)
        # freeze prompt encoder
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False

    def forward(self, image, box_np, gt2D):
        image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
        # do not compute gradients for prompt encoder
        with torch.no_grad():
            # convert box to 1024x1024 grid
            # box_np = boxes.numpy()
            box = self.sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))
            box_torch = torch.as_tensor(box, dtype=torch.float32, device=gt2D.device)
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :] # (B, 1, 4)

            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding, # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
        )
        
        return low_res_masks



def main():
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=8, kwargs_handlers=[kwargs])
    device = accelerator.device
    sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    medsam_model = MedSAM(image_encoder=sam_model.image_encoder, 
                        mask_decoder=sam_model.mask_decoder,
                        prompt_encoder=sam_model.prompt_encoder,
                        input_img_size=sam_model.image_encoder.img_size,
                        ).to(device)
    

    print('Number of total parameters: ', sum(p.numel() for p in medsam_model.parameters())) # 93735472
    print('Number of trainable parameters: ', sum(p.numel() for p in medsam_model.parameters() if p.requires_grad)) # 93729252

    # only optimize the parameters of the image encoder and mask decoder; do not update the prompt encoder
    img_mask_encdec_params = list(medsam_model.image_encoder.parameters()) + list(medsam_model.mask_decoder.parameters())
    optimizer = torch.optim.AdamW(img_mask_encdec_params, lr=args.lr, weight_decay=args.weight_decay)
    print('Number of image encoder and mask decoder parameters: ', sum(p.numel() for p in img_mask_encdec_params if p.requires_grad)) # 93729252
    seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    #%% train
    num_epochs = args.num_epochs
    iter_num = 0
    losses = []
    best_loss = 1e10
    train_dataset = NpyDataset(args.tr_npy_path, data_aug=args.data_aug)
    print('Number of training samples: ', len(train_dataset))
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                num_workers=args.num_workers, pin_memory=False)
    # prepare objects for accelerator
    medsam_model, optimizer, train_dataloader = accelerator.prepare(
        medsam_model, optimizer, train_dataloader
    )
    medsam_model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        for step, (image, gt2D, boxes, _) in enumerate(tqdm(train_dataloader)):
            with accelerator.accumulate(medsam_model):
                optimizer.zero_grad()
                boxes_np = boxes.detach().cpu().numpy()
                # image, gt2D = image.to(device), gt2D.to(device)
                medsam_pred = medsam_model(image, boxes_np, gt2D)
                loss = seg_loss(medsam_pred, gt2D)
                accelerator.backward(loss)
                optimizer.step()
            epoch_loss += loss.item()
            iter_num += 1
        epoch_loss /= step + 1  # step is 0-based, so step + 1 batches were seen
        losses.append(epoch_loss)
        if args.use_wandb:
            wandb.log({"epoch_loss": epoch_loss})
        print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
        # save the model checkpoint
        checkpoint = {'model': medsam_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch}
        torch.save(checkpoint, join(model_save_path, 'medsam_model_latest.pth'))
        # save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(checkpoint, join(model_save_path, 'medsam_model_best.pth'))

if __name__ == "__main__":
    main()
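As written, every process calls `torch.save` on the same two files after each epoch, and the saved `state_dict` comes from the DDP-wrapped model, so its keys carry a `module.` prefix. This is separate from the OOM, but a common Accelerate pattern is to save from the main process only. A minimal sketch, assuming the `Accelerator` attributes/methods `is_main_process`, `unwrap_model`, and `save`; the helper name `save_checkpoint` is hypothetical.

```python
def save_checkpoint(accelerator, model, optimizer, epoch, path):
    """Save model/optimizer state from the global main process only (hypothetical helper).

    Contains no collective operation, so it is safe to call inside a
    rank-dependent branch such as `if epoch_loss < best_loss:`.
    """
    if not accelerator.is_main_process:
        return False
    checkpoint = {
        # unwrap_model strips the DDP wrapper, so keys have no "module." prefix
        "model": accelerator.unwrap_model(model).state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    accelerator.save(checkpoint, path)  # used in place of torch.save
    return True
```

In the epoch loop, one would call `accelerator.wait_for_everyone()` once after the inner `for` loop (outside any rank-dependent branch), then replace each `torch.save(checkpoint, ...)` with `save_checkpoint(accelerator, medsam_model, optimizer, epoch, join(model_save_path, 'medsam_model_latest.pth'))` and the equivalent call for `medsam_model_best.pth`.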

Expected behavior

Multi-node training should not run out of GPU memory, since the same script trains without OOM errors on a single node.
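This expectation assumes each GPU still hosts exactly one training process in the two-node job. The OOM messages suggest that assumption may not hold: they report only about 41 GiB reserved by PyTorch on a 79.35 GiB card with almost nothing free, so roughly half of the failing GPU was held outside the process that failed. A minimal sketch of a per-rank placement check, using the hostname plus the `LOCAL_RANK`, `RANK`, and `WORLD_SIZE` variables set by the torchrun-style launcher that `accelerate launch --multi_gpu` uses; the helper names are hypothetical.

```python
import os
import socket
from collections import Counter


def placement_record(env=None, hostname=None):
    """Return (hostname, LOCAL_RANK, RANK, WORLD_SIZE) for this process (hypothetical helper)."""
    env = os.environ if env is None else env
    hostname = socket.gethostname() if hostname is None else hostname
    return (
        hostname,
        int(env.get("LOCAL_RANK", -1)),
        int(env.get("RANK", -1)),
        int(env.get("WORLD_SIZE", -1)),
    )


def crowded_gpus(records):
    """Return the (hostname, local_rank) slots claimed by more than one process.

    Accelerate places each process on cuda:LOCAL_RANK, so a repeated slot
    suggests two processes are sharing one physical GPU.
    """
    counts = Counter((host, local_rank) for host, local_rank, _rank, _world in records)
    return sorted(slot for slot, n in counts.items() if n > 1)


# e.g. near the top of main(), then collect these lines from every log file:
print("placement:", placement_record(), flush=True)
```

If both `accelerate launch` invocations had landed on the same node, `crowded_gpus` would report every local rank on that host twice.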

My issue is similar to this one: https://github.com/huggingface/accelerate/issues/1443

but the solution suggested there doesn't work for me.

Any comments or suggestions are highly appreciated.

Comment from JunMa11 (Jun 04 '23 02:06):

I also tried setting split_batches=True and increasing the batch size to 64 (two nodes with 4 GPUs each).

Training runs for a few steps and then hits the OOM error again.
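To compare this run with the single-node one on equal terms, it helps to work out how many images each GPU actually sees per forward pass. As I understand Accelerate's `split_batches` option: when it is `False` (the default), the `batch_size` given to the `DataLoader` is per process; when it is `True`, each batch is split evenly across processes. The plain-arithmetic sketch below also folds in the script's `gradient_accumulation_steps=8`; the helper name `batch_sizes` is hypothetical.

```python
def batch_sizes(dataloader_batch, num_processes, split_batches, grad_accum_steps=1):
    """Return (images per GPU per forward pass, images per optimizer step)."""
    if split_batches:
        # The sketch requires an even split across processes.
        if dataloader_batch % num_processes:
            raise ValueError("batch size must be divisible by the number of processes")
        per_process = dataloader_batch // num_processes
    else:
        per_process = dataloader_batch
    per_step = per_process * num_processes * grad_accum_steps
    return per_process, per_step


# Working run: one node, 4 GPUs, batch_size=8, split_batches=False
print(batch_sizes(8, 4, split_batches=False, grad_accum_steps=8))
# Failing run: two nodes, 8 GPUs, batch_size=64, split_batches=True
print(batch_sizes(64, 8, split_batches=True, grad_accum_steps=8))
```

Both configurations give 8 images per GPU per forward pass (the optimizer step covers 256 vs. 512 images), so per-GPU activation memory should be about the same in the two runs.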

++ date
+ echo 'start time: Sat Jun  3 22:38:26 EDT 2023'
+ echo SLURM_JOBID=9767796
+ echo 'SLURM_JOB_NODELIST=gpu[187-188]'
+ echo SLURM_JOB_PARTITION=a100
+ echo SLURM_NNODES=2
+ echo SLURM_GPUS_ON_NODE=4
+ echo SLURM_SUBMIT_DIR=/ssd003/home/junma/MedSAM
+ GPUS_PER_NODE=4
++ head -n 1
++ scontrol show hostnames 'gpu[187-188]'
+ MASTER_ADDR=gpu187
+ MASTER_PORT=5566
+ NNODES=2
+ NODE_RANK=0
+ WORLD_SIZE=8
+ echo MASTER_ADDR=gpu187
+ echo NNODES=2
+ echo NODE_RANK=0
+ export NCCL_DEBUG=INFO
+ NCCL_DEBUG=INFO
+ export NCCL_IB_DISABLE=1
+ NCCL_IB_DISABLE=1
+ CMD='     v2-1_mgpus_train_enc_dec.py     -task_name SAM-ViT-B-2nodes     '
+ LAUNCHER='accelerate launch     --multi_gpu     --mixed_precision=fp16     --num_machines 2     --num_processes 8     --main_process_ip gpu187     --main_process_port 5566     --machine_rank $SLURM_PROCID     --role gpu187:     --rdzv_conf rdzv_backend=c10d     --max_restarts 1     --tee 3 '
+ SRUN_ARGS='     --wait=60     --kill-on-bad-exit=1     '
+ srun --wait=60 --kill-on-bad-exit=1 --jobid 9767796 bash -c 'accelerate launch     --multi_gpu     --mixed_precision=fp16     --num_machines 2     --num_processes 8     --main_process_ip gpu187     --main_process_port 5566     --machine_rank $SLURM_PROCID     --role gpu187:     --rdzv_conf rdzv_backend=c10d     --max_restarts 1     --tee 3       v2-1_mgpus_train_enc_dec.py     -task_name MedSAM-ViT-B-2nodes     '
[gpu187:0]:
[gpu187:2]:
[gpu187:2]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:0]:
[gpu187:0]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:2]:
[gpu187:0]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:1]:
[gpu187:2]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:3]:
[gpu187:3]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:1]:
[gpu187:1]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:3]:
[gpu187:3]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:0]:
[gpu187:0]:  1%|          | 1/127 [02:46<5:50:10, 166.75s/it][gpu187:0]:
[gpu187:0]:  2%|▏         | 2/127 [02:48<2:25:13, 69.71s/it] [gpu187:0]:
[gpu187:0]:  2%|▏         | 3/127 [02:50<1:19:44, 38.58s/it][gpu187:0]:
[gpu187:0]:  3%|▎         | 4/127 [02:51<49:08, 23.98s/it]  [gpu187:0]:
[gpu187:0]:  4%|▍         | 5/127 [02:53<32:16, 15.88s/it][gpu187:0]:
[gpu187:0]:  5%|▍         | 6/127 [02:54<22:16, 11.04s/it][gpu187:0]:
[gpu187:1]:  0%|          | 0/127 [00:00<?, ?it/s][gpu187:2]:
[gpu187:2]:  0%|          | 0/127 [04:35<?, ?it/s]
[gpu187:2]:Traceback (most recent call last):
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:2]:    main()
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:2]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:2]:    return model_forward(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:2]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:2]:    return func(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:2]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:2]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:2]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:2]:    x = blk(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:2]:    x = self.attn(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:2]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:2]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:2]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 2; 79.35 GiB total capacity; 33.66 GiB already allocated; 709.69 MiB free; 34.91 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[gpu187:0]:  6%|▌         | 7/127 [02:56<15:52,  7.93s/it][gpu187:2]:
[gpu187:2]:  0%|          | 0/127 [04:35<?, ?it/s]
[gpu187:2]:Traceback (most recent call last):
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:2]:    main()
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 261, in main
[gpu187:2]:    medsam_pred = medsam_model(image, boxes_np, gt2D)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 553, in forward
[gpu187:2]:    return model_forward(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/utils/operations.py", line 541, in __call__
[gpu187:2]:    return convert_to_fp32(self.model_forward(*args, **kwargs))
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
[gpu187:2]:    return func(*args, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
[gpu187:2]:    output = self._run_ddp_forward(*inputs, **kwargs)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
[gpu187:2]:    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 193, in forward
[gpu187:2]:    image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 112, in forward
[gpu187:2]:    x = blk(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 174, in forward
[gpu187:2]:    x = self.attn(x)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
[gpu187:2]:    return forward_call(*args, **kwargs)
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 234, in forward
[gpu187:2]:    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/segment_anything/modeling/image_encoder.py", line 358, in add_decomposed_rel_pos
[gpu187:2]:    attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
[gpu187:2]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 176.00 MiB (GPU 2; 79.35 GiB total capacity; 40.99 GiB already allocated; 125.69 MiB free; 41.46 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13378 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13380 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13379 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13381 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13385 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 13384 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 13382) of binary: /h/junma/anaconda3/envs/medsam/bin/python
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 13383) of binary: /h/junma/anaconda3/envs/medsam/bin/python
[gpu187:0]:Traceback (most recent call last):
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:0]:    main()
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 229, in main
[gpu187:0]:    ).to(device)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1145, in to
[gpu187:0]:    return self._apply(convert)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
[gpu187:0]:    module._apply(fn)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
[gpu187:0]:    module._apply(fn)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
[gpu187:0]:    module._apply(fn)
[gpu187:0]:  [Previous line repeated 2 more times]
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 820, in _apply
[gpu187:0]:    param_applied = fn(param)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1143, in convert
[gpu187:0]:    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
[gpu187:0]:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 79.35 GiB total capacity; 227.85 MiB already allocated; 19.69 MiB free; 262.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 19402 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 19403 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 19404 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 19401) of binary: /h/junma/anaconda3/envs/medsam/bin/python
[gpu187:3]:Traceback (most recent call last):
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:3]:    main()
[gpu187:3]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 250, in main
[gpu187:3]:    medsam_model, optimizer, train_dataloader = accelerator.prepare(
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1203, in prepare
[gpu187:3]:    result = tuple(
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1204, in <genexpr>
[gpu187:3]:    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1043, in _prepare_one
[gpu187:3]:    return self.prepare_model(obj, device_placement=device_placement)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1287, in prepare_model
[gpu187:3]:    model = torch.nn.parallel.DistributedDataParallel(
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
[gpu187:3]:    _verify_param_shape_across_processes(self.process_group, parameters)
[gpu187:3]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
[gpu187:3]:    return dist._verify_params_across_processes(process_group, tensors, logger)
[gpu187:3]:RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [172.17.8.187]:11142
[gpu187:0]:Traceback (most recent call last):
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:0]:    main()
[gpu187:0]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 250, in main
[gpu187:0]:    medsam_model, optimizer, train_dataloader = accelerator.prepare(
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1203, in prepare
[gpu187:0]:    result = tuple(
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1204, in <genexpr>
[gpu187:0]:    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1043, in _prepare_one
[gpu187:0]:    return self.prepare_model(obj, device_placement=device_placement)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1287, in prepare_model
[gpu187:0]:    model = torch.nn.parallel.DistributedDataParallel(
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
[gpu187:0]:    _verify_param_shape_across_processes(self.process_group, parameters)
[gpu187:0]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
[gpu187:0]:    return dist._verify_params_across_processes(process_group, tensors, logger)
[gpu187:0]:RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [172.17.8.187]:45478
[gpu187:2]:Traceback (most recent call last):
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:2]:    main()
[gpu187:2]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 250, in main
[gpu187:2]:    medsam_model, optimizer, train_dataloader = accelerator.prepare(
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1203, in prepare
[gpu187:2]:    result = tuple(
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1204, in <genexpr>
[gpu187:2]:    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1043, in _prepare_one
[gpu187:2]:    return self.prepare_model(obj, device_placement=device_placement)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1287, in prepare_model
[gpu187:2]:    model = torch.nn.parallel.DistributedDataParallel(
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
[gpu187:2]:    _verify_param_shape_across_processes(self.process_group, parameters)
[gpu187:2]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
[gpu187:2]:    return dist._verify_params_across_processes(process_group, tensors, logger)
[gpu187:2]:RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [172.17.8.187]:61606
[gpu187:1]:Traceback (most recent call last):
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 292, in <module>
[gpu187:1]:    main()
[gpu187:1]:  File "/ssd003/home/junma/MedSAM/v2-1_mgpus_train_enc_dec.py", line 250, in main
[gpu187:1]:    medsam_model, optimizer, train_dataloader = accelerator.prepare(
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1203, in prepare
[gpu187:1]:    result = tuple(
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1204, in <genexpr>
[gpu187:1]:    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1043, in _prepare_one
[gpu187:1]:    return self.prepare_model(obj, device_placement=device_placement)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/accelerator.py", line 1287, in prepare_model
[gpu187:1]:    model = torch.nn.parallel.DistributedDataParallel(
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
[gpu187:1]:    _verify_param_shape_across_processes(self.process_group, parameters)
[gpu187:1]:  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
[gpu187:1]:    return dist._verify_params_across_processes(process_group, tensors, logger)
[gpu187:1]:RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [172.17.8.187]:13704
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 19414) of binary: /h/junma/anaconda3/envs/medsam/bin/python
Traceback (most recent call last):
  File "/h/junma/anaconda3/envs/medsam/bin/accelerate", line 8, in <module>
Traceback (most recent call last):
  File "/h/junma/anaconda3/envs/medsam/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    sys.exit(main())
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/launch.py", line 932, in launch_command
    args.func(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/launch.py", line 932, in launch_command
    multi_gpu_launcher(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/launch.py", line 627, in multi_gpu_launcher
    multi_gpu_launcher(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/accelerate/commands/launch.py", line 627, in multi_gpu_launcher
    distrib_run.run(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    distrib_run.run(args)
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    elastic_launch(
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/h/junma/anaconda3/envs/medsam/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
v2-1_mgpus_train_enc_dec.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-06-03_22:45:52
  host      : gpu187.cluster.local
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 19415)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-06-03_22:45:52
  host      : gpu187.cluster.local
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 19416)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-06-03_22:45:52
  host      : gpu187.cluster.local
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 19418)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-06-03_22:45:52
  host      : gpu187.cluster.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 19414)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
v2-1_mgpus_train_enc_dec.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-06-03_22:45:01
  host      : gpu187.cluster.local
  rank      : 4 (local_rank: 0)
  exitcode  : 1 (pid: 19401)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
============================================================
srun: error: gpu187: tasks 0-1: Exited with exit code 1
srun: launch/slurm: _step_signal: Terminating StepId=9767796.0

JunMa11, Jun 04 '23 03:06

Hi @JunMa11 ,

Thanks for sharing your slurm.sh. Based on it, I successfully adapted my code to multi-node training. However, I ran into a strange behavior that puzzles me. With the following snippet, everything works:

LAUNCHER="accelerate launch \
    --multi_gpu \
    --mixed_precision=fp16 \
    --num_machines $NNODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip "$MASTER_ADDR" \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --role $SLURMD_NODENAME: \
    --rdzv_conf rdzv_backend=c10d \
    --max_restarts 1 \
    --tee 3 \
"

However, if I remove the double quotes around $MASTER_ADDR, the run fails. Since I'm not very familiar with bash, I wonder what effect the double quotes actually have here.

Hope to hear from you!

Thanks.
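A side note on the snippet above: `\$SLURM_PROCID` is escaped so that it is *not* expanded when `LAUNCHER` is defined on the batch host. It is stored as a literal `$SLURM_PROCID` and only expanded later by each SLURM task's own shell, so every node gets its own `--machine_rank`. The sketch below illustrates this deferred expansion. It assumes the launcher is eventually run through something like `srun bash -c "$LAUNCHER <script>"`; here `srun` is simulated with a plain `bash -c`, and the `MASTER_ADDR`/`MASTER_PORT` values are hypothetical placeholders.

```shell
#!/bin/bash
# Hypothetical placeholder values; in the real job these come from the SLURM script.
MASTER_ADDR=10.0.0.1
MASTER_PORT=29500

# $MASTER_ADDR and $MASTER_PORT are expanded right now, on the batch host.
# \$SLURM_PROCID is stored literally as $SLURM_PROCID for later expansion.
LAUNCHER="accelerate launch \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
"

echo "Stored: $LAUNCHER"

# Simulate two SLURM tasks (assumption: the real script uses srun bash -c "...").
# Each child shell sees its own SLURM_PROCID and expands it itself.
for rank in 0 1; do
    SLURM_PROCID=$rank bash -c "echo Task $rank would run: $LAUNCHER"
done
```

Without the backslash, `$SLURM_PROCID` would be expanded once on the batch host (where it is typically unset or 0), and every node would then claim the same machine rank.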

========================================

Sorry, the failure appears to have been caused by some latency on my machines. Whether or not the double quotes are present makes no difference to the run.
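This matches bash semantics. `LAUNCHER` is itself a double-quoted string, so the inner `"` before `$MASTER_ADDR` closes that string and the `"` after it reopens it; the variable actually sits outside quotes in the "quoted" version. On the right-hand side of an assignment, bash performs no word splitting or filename expansion, so both forms store the same value. A minimal sketch (the addresses and port are hypothetical placeholders):

```shell
#!/bin/bash
# Hypothetical values; the second one contains spaces on purpose.
for MASTER_ADDR in 10.0.0.1 "node with spaces"; do
    # The inner quotes close and reopen the outer string,
    # so $MASTER_ADDR is effectively unquoted here.
    WITH_QUOTES="accelerate launch --main_process_ip "$MASTER_ADDR" --main_process_port 29500"

    # Here $MASTER_ADDR is expanded inside the outer quotes.
    WITHOUT_QUOTES="accelerate launch --main_process_ip $MASTER_ADDR --main_process_port 29500"

    # Assignments are never word-split or globbed, so the values match.
    if [ "$WITH_QUOTES" = "$WITHOUT_QUOTES" ]; then
        echo "identical for '$MASTER_ADDR': $WITH_QUOTES"
    fi
done
```

Quoting would only matter later, when `$LAUNCHER` is expanded unquoted as part of a command; at that point the whole string is split on whitespace regardless of how it was assembled.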

WindVChen, Jun 15 '23 09:06

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jul 09 '23 15:07

Hi @WindVChen ,

In the end, we switched to a plain PyTorch implementation instead of Accelerate.

Here is the SLURM script:

https://github.com/bowang-lab/MedSAM/blob/main/train_multi_gpus.sh

Hope it helps.

JunMa11, Jul 09 '23 16:07