
Unexpectedly high fp16 memory usage

trufty opened this issue on Aug 21, 2022 · 11 comments

I've noticed that enabling fp16 makes very little difference in model size or memory usage. Using the script below (taken directly from your docs) and only changing the flag fp16 = True to fp16 = False yields a difference of about 4% in VRAM usage and exactly the same checkpoint size for both runs.

This seems suspiciously small compared to other projects I've used with fp16 enabled, and a few people on the LAION Discord Imagen channel are noticing the same thing, although others report a bigger difference.

I'm wondering if it could come down to differences in training scripts, since we all seem to be using our own custom variations.

import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

# base unet (trained at 64 x 64, per image_sizes below)
unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)

# super-resolution unet (64 -> 128, per image_sizes below)
unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)

imagen = Imagen(
    condition_on_text = False,
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen,
    fp16 = False  # change this and compare model sizes / memory usage
).cuda()

# dummy batch; images are resized internally to each unet's image size
training_images = torch.randn(4, 3, 256, 256).cuda()

for i in range(100):
    loss = trainer(training_images, unet_number = 1)
    trainer.update(unet_number = 1)

trainer.save("./checkpoint.pt")

trufty avatar Aug 21 '22 18:08 trufty

@trufty Hi Trufty

Which version of Imagen are you on? I recently made a change that may have inadvertently cast everything to float32, but this commit should have fixed it: https://github.com/lucidrains/imagen-pytorch/commit/afed17a8ec724608ffdcc0a33c1ee68718c36b01

lucidrains avatar Aug 21 '22 18:08 lucidrains

I'm testing with the latest version, 1.9.6, since I saw your change, but I also noticed the same behavior on 1.7.x.

trufty avatar Aug 21 '22 19:08 trufty

@trufty ohh, i'm not sure what's going on then

i'll dig into it next week

lucidrains avatar Aug 21 '22 21:08 lucidrains

Just to confirm I'm not going crazy, I interrogated the fp16 = True model from the script above, and the dtype of every layer is float32 😢

unets.0.final_conv.weight    | torch.float32
unets.0.final_conv.bias    | torch.float32
unets.1.null_text_embed    | torch.float32
unets.1.null_text_hidden    | torch.float32
...

trufty avatar Aug 22 '22 14:08 trufty
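
(For reference, a listing like the one above can be produced with plain PyTorch by iterating over the module's parameters; here `imagen` is the object built in the script earlier in the thread, and nothing imagen-pytorch-specific is assumed.)

# print the dtype of every parameter in the Imagen module
for name, param in imagen.named_parameters():
    print(f"{name}\t| {param.dtype}")

# registered buffers can be checked the same way
for name, buf in imagen.named_buffers():
    print(f"{name}\t| {buf.dtype}")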

@trufty let me know if https://github.com/lucidrains/imagen-pytorch/commit/818ab5da6ae45edbb4b407ec2b60f3314dbae10c works

lucidrains avatar Aug 22 '22 17:08 lucidrains

Using imagen-pytorch==1.10.0 I'm still getting all-float32 model layers with the above script with fp16 = True. I verified the installed version with pip list and deleted the existing checkpoint before testing.

Are you able to get anything other than float32 weights with the same script? Or is it just my Docker environment?

trufty avatar Aug 22 '22 17:08 trufty

@trufty yea, check your environment again, i'm seeing the unet output float16 properly

lucidrains avatar Aug 22 '22 18:08 lucidrains

Yea, if it's working for you, it has to be a local env issue... ugh. Thanks for helping so far. (And yes, I had the fp16 flag set correctly.)

trufty avatar Aug 22 '22 18:08 trufty

@trufty the weights are kept in float32 when doing mixed-precision training, and autocast takes care of automatically converting between float32 and float16

check the memory usage again

lucidrains avatar Aug 22 '22 18:08 lucidrains
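
(For readers following the thread: this is standard PyTorch mixed-precision behavior rather than anything specific to imagen-pytorch. A minimal generic sketch with torch.cuda.amp, showing that the master weights stay float32 while activations inside autocast run in float16, so parameter dtypes and checkpoint size alone won't reflect the savings:)

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(512, 512).cuda()           # master weights stay float32
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()                               # scales the loss to avoid fp16 gradient underflow

x = torch.randn(8, 512).cuda()

with autocast():                                    # ops inside run in float16 where safe
    out = model(x)                                  # out.dtype == torch.float16
    loss = out.float().pow(2).mean()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

print(model.weight.dtype)                           # torch.float32 -- checkpointed weights look unchanged
print(out.dtype)                                    # torch.float16 -- the savings show up in activations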

4590 MB vs 4388 MB is the difference I see, which still matches the ~4% I mentioned earlier. I just expected a much larger difference than that.

trufty avatar Aug 22 '22 20:08 trufty

I ended up training Unet1 separately from Unet2 and then splicing the models together as a workaround.

trufty avatar Sep 07 '22 16:09 trufty
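
(A rough illustration of the splicing workaround described above, under two assumptions: both runs used the same Imagen config, and trainer.save() stores the Imagen state dict under a "model" key with "unets.0." / "unets.1." prefixes, as suggested by the layer names printed earlier in the thread. The file names are hypothetical; inspect your own checkpoints' keys before relying on this.)

import torch

# hypothetical file names for the two separately trained runs
ckpt_unet1 = torch.load("./checkpoint_unet1.pt", map_location = "cpu")
ckpt_unet2 = torch.load("./checkpoint_unet2.pt", map_location = "cpu")

# assumed checkpoint layout: the Imagen state dict lives under a "model" key
state_1 = ckpt_unet1["model"]
state_2 = ckpt_unet2["model"]

# take everything except the second unet from run 1,
# then overlay the second unet's weights from run 2
spliced = {k: v for k, v in state_1.items() if not k.startswith("unets.1.")}
spliced.update({k: v for k, v in state_2.items() if k.startswith("unets.1.")})

# `imagen` built with the same config as in the script at the top of the thread
imagen.load_state_dict(spliced)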