vq-vae-2-pytorch

Stuck at epoch 1 iter 1 when training vqvae with multi-gpu

Open JamesZhutheThird opened this issue 3 years ago • 11 comments

The single-gpu training process works just fine for me, and the output samples are satisfactory. However, when I set --n_gpu 2 or --n_gpu 4, the training process gets stuck at the beginning (Epoch 1, Iter 1), and the time cost of this very first iteration (34 seconds) is much longer than in single-gpu training (3 seconds per iteration).

[screenshot of the training progress output]

I would be grateful if someone could help me figure out what might be wrong here.
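
If it is hard to tell where the processes are waiting, one standard-library trick is to register a stack-dump signal handler; the snippet below is only a debugging sketch (where exactly to put it in the training script is an assumption, not something established in this thread):

    # Debugging aid: dump every thread's Python stack when the process receives
    # SIGUSR1, which helps locate where a stuck worker is blocked.
    import faulthandler
    import signal

    faulthandler.register(signal.SIGUSR1)
    # Afterwards, from a shell:  kill -USR1 <pid of the stuck worker>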

JamesZhutheThird avatar Jul 24 '22 16:07 JamesZhutheThird

I also have the same problem. Have you worked it out?

xuyanging avatar Feb 17 '23 10:02 xuyanging

Well, I think I solved this by adding torch.cuda.empty_cache() at the end of each iteration.

JamesZhutheThird avatar Feb 17 '23 12:02 JamesZhutheThird

May I ask where you added torch.cuda.empty_cache()? I am experiencing the same problem. Thank you.

ekyy2 avatar Mar 10 '23 07:03 ekyy2

Original

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]
            # ...
            if i % 100 == 0:
                model.eval()
                # ...
                with torch.no_grad():
                    out, _ = model(sample)
                # ...
                model.train()

Fixed

        if dist.is_primary():
            with torch.no_grad():
                lr = optimizer.param_groups[0]["lr"]
                # ...
                if i % 100 == 0:
                    model.eval()
                    # ...
                    out, _ = model(sample)
                    # ...
                    torch.cuda.empty_cache()

JamesZhutheThird avatar Mar 10 '23 07:03 JamesZhutheThird

@JamesZhutheThird Thank you so much for your quick reply. I believe the model is still getting stuck at calculating recon_loss.item() for the part_mse_sum, after I made the following changes:

    if dist.is_primary():
        with torch.no_grad():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description(
                (
                    f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                )
            )

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                # with torch.no_grad():
                out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
                torch.cuda.empty_cache()

I am not sure whether I understood the changes to be made correctly, so your advice would be much appreciated. Thanks.

ekyy2 avatar Mar 13 '23 02:03 ekyy2

Perhaps you can try moving model.train() to the beginning of each iteration:

    for i, (img, label) in enumerate(loader):
        model.train()
        model.zero_grad()

I don't see anything else different from my code, but since I have changed other parts of the code and added extra functions, I'm not sure whether they are related to this bug. Anyway, please keep in touch with me about this. @ekyy2

JamesZhutheThird avatar Mar 13 '23 06:03 JamesZhutheThird

I tried the change you suggested, but it does not seem to work. No worries; as you said, it could be something else. Maybe it is that I am using CUDA 11.6 but installed the PyTorch build for CUDA 11.3 (pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113). Another possibility is the blow-up of the latent loss mentioned here: https://github.com/rosinality/vq-vae-2-pytorch/issues/65#issue-870480215, but it is weird that it only happens with multiple GPUs. Anyways, keep in touch.
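
In case it helps narrow down the version-mismatch theory, these standard PyTorch calls report which CUDA build is actually in use (nothing here is specific to this repo):

    # Quick environment check using standard PyTorch APIs.
    import torch

    print(torch.__version__)          # e.g. 1.12.1+cu113
    print(torch.version.cuda)         # CUDA version this PyTorch build targets
    print(torch.cuda.is_available())  # whether the driver/runtime is usable
    print(torch.cuda.device_count())  # number of GPUs visible to PyTorch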

ekyy2 avatar Mar 14 '23 03:03 ekyy2

Environment 1:

    python       3.9.13
    torch        1.11.0+cu113
    torchvision  0.12.0+cu113

Environment 2:

    python       3.9.13
    torch        1.12.0+cu113
    torchaudio   0.12.0+cu113
    torchvision  0.13.0+cu113

I've tested two environments with different PyTorch versions, with CUDA 11.2. I am pretty sure the exact versions are not that important.

btw, I checked the files in ./distributed and they are the same as those in this repo.

JamesZhutheThird avatar Mar 14 '23 06:03 JamesZhutheThird

For anyone who has this problem, I'll share my solution: set find_unused_parameters to True in the DDP wrapper. The problem is triggered by the unused quantized vectors (each DDP process waits for the other processes until the gradients of all learnable model parameters have been reduced).

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True,  # here
        )
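
To see the underlying DDP behaviour in isolation, here is a self-contained toy sketch (CPU, gloo backend; the module and function names are made up, and nothing below is code from this repo). It builds a model with a parameterised branch that never runs in forward(), which is the kind of situation find_unused_parameters=True is meant to handle:

    # Standalone illustration (not from this repo): two processes on the gloo
    # backend, and a model with a branch whose parameters never get gradients.
    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    class TwoBranchNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.used = nn.Linear(8, 8)
            self.unused = nn.Linear(8, 8)  # never called in forward()

        def forward(self, x):
            return self.used(x)


    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Without find_unused_parameters=True, DDP expects a gradient for every
        # parameter, so the per-bucket gradient reduction cannot complete cleanly.
        model = DDP(TwoBranchNet(), find_unused_parameters=True)

        out = model(torch.randn(4, 8))
        out.sum().backward()
        print(f"rank {rank}: backward finished")
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)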

@ekyy2, I hope this solution works for you.

subminu avatar Nov 09 '23 02:11 subminu

I can confirm that this works. Thank you so much! The issue can be closed.

ekyy2 avatar Nov 10 '23 04:11 ekyy2

I still have this problem.

peylnog avatar Mar 08 '24 13:03 peylnog