vq-vae-2-pytorch
Support for torch.cuda.amp in VQ-VAE training
Feature request for AMP support in VQ-VAE training.
So far I've tried naively modifying the train function in train_vqvae.py like so:
```python
# ...
for i, (img, label) in enumerate(loader):
    model.zero_grad()

    img = img.to(device)

    # run the forward pass and loss computation under autocast
    with torch.cuda.amp.autocast():
        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss

    # scaler is a torch.cuda.amp.GradScaler created before the loop
    scaler.scale(loss).backward()

    if scheduler is not None:
        scheduler.step()

    scaler.step(optimizer)
    scaler.update()
# ...
```
The MSE loss appears normal, but the latent loss becomes infinite. I'm going to try a few ideas when I have the time. I suspect that half precision and/or gradient scaling doesn't play well with the EMA codebook updates. One "workaround" is to replace the EMA update with the second (codebook) loss term from the original VQ-VAE paper, so that the codebook is updated purely by gradients, but that is far from ideal.
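Roughly what I mean by that workaround (a hypothetical sketch, not code from this repo; here z_e is the encoder output, e the matching codebook entries gathered without detaching, and beta the commitment weight):

```python
import torch
import torch.nn.functional as F

def vq_codebook_losses(z_e: torch.Tensor, e: torch.Tensor, beta: float = 0.25):
    # Codebook loss ||sg[z_e] - e||^2: gradients flow only into the codebook,
    # taking the place of the EMA update.
    codebook_loss = F.mse_loss(e, z_e.detach())
    # Commitment loss beta * ||z_e - sg[e]||^2: keeps the encoder output close
    # to the chosen codebook entries.
    commitment_loss = beta * F.mse_loss(z_e, e.detach())
    return codebook_loss + commitment_loss
```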
Thanks!
I think it would be safer to use fp32 for the entire quantize operation.
So, wrapping Quantize.forward in @torch.cuda.amp.autocast(enabled=False) and casting the buffers to torch.float32? The input might also need to be cast.
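Something like this, maybe (a rough sketch; FP32Quantize is a hypothetical wrapper used only to illustrate the enabled=False plus cast-to-float32 idea, rather than a patch to Quantize itself):

```python
import torch
from torch import nn

class FP32Quantize(nn.Module):
    """Hypothetical wrapper that runs an existing Quantize module in fp32."""

    def __init__(self, quantize: nn.Module):
        super().__init__()
        self.quantize = quantize.float()  # keep codebook parameters/buffers in fp32

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, input):
        # Activations arriving from an autocast region may be fp16, so cast them up.
        return self.quantize(input.float())
```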
Yes. It may work.
Okay! I can make a pull request for this if you want? If not, I can just close this.
If it suffices to reproduce the results of fp32 training, it would definitely be nice to have.
For some reason I can't improve forward-pass speed under FP16 (maybe it's bottlenecked by the FP32 quantize operations?). Memory usage is improved, though. I'll play around with this a little more and then maybe make a pull request.
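In case it's useful, a rough way to compare the two settings (hypothetical sketch, not from this repo; model is any nn.Module already on a CUDA device):

```python
import torch

def time_forward(model, img, use_amp: bool, n_iters: int = 50):
    """Average forward-pass time (ms) and peak memory (MB) with/without autocast."""
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    with torch.no_grad():
        for _ in range(n_iters):
            with torch.cuda.amp.autocast(enabled=use_amp):
                model(img)
    end.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(end) / n_iters
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return ms_per_iter, peak_mb
```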