vq-vae-2-pytorch

Support for torch.cuda.amp in VQ-VAE training

Open vvvm23 opened this issue 4 years ago • 6 comments

Feature request for AMP support in VQ-VAE training. So far, I tried naively modifying the train function in train_vqvae.py like so:

#  ...
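# assumes scaler = torch.cuda.amp.GradScaler() was created once, before the loop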
for i, (img, label) in enumerate(loader):
    model.zero_grad()

    img = img.to(device)

    with torch.cuda.amp.autocast():
        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
    scaler.scale(loss).backward()

    if scheduler is not None:
        scheduler.step()
    scaler.step(optimizer)
    scaler.update()
# ...

The MSE loss appears normal, but the latent loss becomes infinite. I'm going to try a few ideas when I have the time. I suspect that half precision and/or gradient scaling doesn't play well with the EMA updates. One "workaround" is to replace EMA with the second term of the loss function from the original paper (the codebook loss), so that all parameters are updated through gradients alone (rough sketch below), but that is far from ideal.

Thanks!

vvvm23 avatar Apr 28 '21 23:04 vvvm23
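A rough sketch of that workaround, assuming the codebook is switched from a registered buffer to an nn.Parameter so it receives gradients (vq_losses and beta are hypothetical names, not from the repo):

import torch

def vq_losses(input, quantize, beta=0.25):
    # codebook loss: the paper's second term; pulls the embeddings toward
    # the (detached) encoder outputs and replaces the EMA update
    codebook_loss = (quantize - input.detach()).pow(2).mean()
    # commitment loss: pulls the encoder outputs toward the embeddings
    commit_loss = (quantize.detach() - input).pow(2).mean()
    return codebook_loss + beta * commit_loss  # beta = 0.25 in the paper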

I think it would be safer to use fp32 for the entire quantize operation.

rosinality avatar Apr 29 '21 11:04 rosinality

So, wrapping Quantize.forward in @torch.cuda.amp.autocast(enabled=False) and casting the buffers to torch.float32? The input might also need to be cast, something like the sketch below.

vvvm23 avatar Apr 29 '21 12:04 vvvm23
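A minimal sketch of that suggestion, simplified from the repo's Quantize class (it omits the EMA codebook update; embed has shape dim x n_embed as in the repo):

import torch
from torch import nn
from torch.nn import functional as F

class Quantize(nn.Module):
    def __init__(self, dim, n_embed):
        super().__init__()
        # fp32 codebook kept as a buffer, as in the repo
        self.register_buffer("embed", torch.randn(dim, n_embed))

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, input):
        # inside an outer autocast region the input may arrive as fp16,
        # so cast it back to fp32 before the distance computation
        input = input.float()
        flatten = input.reshape(-1, input.shape[-1])
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
        # commitment loss and straight-through estimator
        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()
        return quantize, diff, embed_ind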

Yes, that may work.

rosinality avatar Apr 29 '21 14:04 rosinality

Okay! I can make a pull request for this if you want? If not, I can just close this.

vvvm23 avatar Apr 29 '21 14:04 vvvm23

If it suffices to reproduce the results of fp32 training, it would definitely be nice to have.

rosinality avatar Apr 30 '21 15:04 rosinality

For some reason I can't improve forward-pass speed under FP16 (maybe it is bottlenecked by the FP32 quantize operations?). Memory usage is improved, though. I'll play around with this a little more and then maybe make a pull request.

vvvm23 avatar May 05 '21 19:05 vvvm23
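For anyone profiling the same thing, a rough timing sketch (hypothetical helper, not from the thread; assumes model and img already live on the GPU):

import time
import torch

@torch.no_grad()
def time_forward(model, img, use_amp, iters=50):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        with torch.cuda.amp.autocast(enabled=use_amp):
            model(img)
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.time() - start) / iters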