mpiexec running out of memory in multi-GPU training
Hello, I am having an issue using mpiexec to distribute training.
It seems that I can run training on a single GPU using the following parameters:
```
MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False"
TRAIN_FLAGS="--lr 1e-4 --batch_size 4 --microbatch 1"
export OPENAI_LOGDIR="/mnt/storage/Jason_improved_diffusion/experiments/test_256/"
python scripts/image_train.py $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```
But when I try to do distributed training with:
```
MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False"
TRAIN_FLAGS="--lr 1e-4 --batch_size 1 --microbatch 1"
export OPENAI_LOGDIR="/mnt/storage/Jason_improved_diffusion/experiments/test_256/"
export CUDA_VISIBLE_DEVICES=1,2,3
mpiexec -n 3 python scripts/image_train.py $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```
it always comes back with an out-of-memory error.
I believe there is some memory overhead on one of the GPUs for the distributed processes to communicate with each other, and I think that is what's causing the out-of-memory error.
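To double-check that, something like the following (a minimal sketch of my own, not the repo's code, assuming mpi4py and PyTorch) should show which physical GPU each rank actually allocates on when launched under mpiexec; watching nvidia-smi while training runs shows the same thing:

```
# check_devices.py (hypothetical filename): run with `mpiexec -n 3 python check_devices.py`.
# If every rank reports the same device, all processes are sharing one GPU,
# which would explain the out-of-memory error.
import os
from mpi4py import MPI
import torch

rank = MPI.COMM_WORLD.Get_rank()
_ = torch.zeros(1, device="cuda")  # force a CUDA context on this rank's default device
idx = torch.cuda.current_device()
print(
    f"rank {rank}: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}, "
    f"default device cuda:{idx} ({torch.cuda.get_device_name(idx)})"
)
```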
I'm not sure if this is the right way to put it, but is there a way for me to set one GPU as the communicator and run the distributed training only on the other GPUs? E.g., if I have three GPUs, I would use a batch size of 2 so that the first GPU doesn't have to hold any of the training data.
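In case it helps clarify what I mean, here is a rough sketch (not the project's actual code, and I'm not sure how its dist_util handles this) of the kind of per-rank GPU pinning I have in mind, where each rank keeps only one entry of CUDA_VISIBLE_DEVICES before PyTorch initializes:

```
# Hypothetical per-rank GPU pinning: must run before the first CUDA call.
# With CUDA_VISIBLE_DEVICES=1,2,3 exported to mpiexec, each rank keeps only
# one entry of that list, so rank 0 -> GPU 1, rank 1 -> GPU 2, rank 2 -> GPU 3.
import os
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
if visible and visible[0]:
    os.environ["CUDA_VISIBLE_DEVICES"] = visible[rank % len(visible)]
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

import torch  # imported after the env var is set so the mapping takes effect

device = torch.device("cuda:0")  # each process now sees exactly one GPU as cuda:0
```

If image_train.py already does something like this internally (e.g. in dist_util), it may override the exported CUDA_VISIBLE_DEVICES, so that would be worth checking before relying on the environment variable alone.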
Below are the dependencies and versions that I believe are relevant:
- PyTorch 1.10.0+cu111
- CUDA 11.4
- NVIDIA GeForce RTX 3070 (all four GPUs)
- mpi4py 3.0.3
Hello, have you solved this issue?
I'm experiencing the same problem
Same issue here. How did you resolve it, @jeong-jasonji @99-WSJ @akrlowicz?