imagen-pytorch
Unexpectedly high fp16 memory usage
I've been noticing that using fp16 has not resulted in much difference in model size or memory usage. Using the script below (taken directly from your docs) and only changing the flag fp16=True to fp16=False yields a difference of 4% in VRAM usage and exactly the same checkpoint size for both.
This seems suspiciously small compared to other projects I've used with fp16 enabled, and a few people on the LAION Discord Imagen channel are noticing the same thing, although others seem to see a bigger difference.
I'm wondering if it could be a difference of training scripts, since we all seem to be using our own custom variations.
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)

unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)

imagen = Imagen(
    condition_on_text = False,
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen,
    fp16 = False  # change this and compare model sizes / memory usage
).cuda()

training_images = torch.randn(4, 3, 256, 256).cuda()

for i in range(100):
    loss = trainer(training_images, unet_number = 1)
    trainer.update(unet_number = 1)

trainer.save("./checkpoint.pt")
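For reference, the thread doesn't show exactly how the VRAM and checkpoint numbers were collected; the snippet below is just one common way to get comparable figures for the two runs (torch.cuda peak-memory stats plus the size of the saved file), not the original measurement code.

import os
import torch

# Hypothetical measurement helper, not from the original report.
# Call reset_peak_memory_stats() before the training loop above,
# then read the numbers after trainer.save().
torch.cuda.reset_peak_memory_stats()

# ... training loop from the script above runs here ...

peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
ckpt_mb = os.path.getsize("./checkpoint.pt") / 1024 ** 2
print(f"peak VRAM: {peak_mb:.0f} MB | checkpoint on disk: {ckpt_mb:.0f} MB")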
@trufty Hi Trufty
Which version of imagen-pytorch are you on? I recently made a change that may have inadvertently cast everything to float32, but this commit should have fixed it: https://github.com/lucidrains/imagen-pytorch/commit/afed17a8ec724608ffdcc0a33c1ee68718c36b01
I'm testing with the latest version 1.9.6 since I saw your change. But I also noticed the same behavior on 1.7.x as well.
@trufty ohh, i'm not sure what's going on then
i'll dig into it next week
Just to confirm I'm not going crazy, I interrogated the fp16 = True model from the script above, and the dtype of all layers is float32 😢
unets.0.final_conv.weight | torch.float32
unets.0.final_conv.bias | torch.float32
unets.1.null_text_embed | torch.float32
unets.1.null_text_hidden | torch.float32
...
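A minimal sketch of that kind of check, reusing the imagen object built in the script above (this is the generic way to list parameter dtypes in PyTorch, not necessarily the exact snippet used here):

# Walk the live model and print each parameter's dtype, producing a
# listing like the one shown above.
for name, param in imagen.named_parameters():
    print(name, "|", param.dtype)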
@trufty let me know if https://github.com/lucidrains/imagen-pytorch/commit/818ab5da6ae45edbb4b407ec2b60f3314dbae10c works
Using imagen-pytorch==1.10.0, I'm still getting all float32 model layers with the above script with fp16 = True.
I verified the install version with pip list and deleted the existing checkpoint before testing.
Are you able to get anything other than float32 weights with the same script? Or is it just my docker environment?
@trufty yea, check your environment again, i'm seeing the unet output float16 properly
Yea, if it's working for you, it has to be a local env issue... ugh. Thanks for helping so far. (And yea, I had the fp16 flag set correctly.)
@trufty the weights will be kept in float32 when doing mixed precision training, and the autocast takes care of auto converting between float32 and float16
check the memory usage again
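This is standard torch.cuda.amp behavior; here is a minimal standalone illustration (plain PyTorch, not imagen-pytorch internals) of why the stored weights, and therefore the checkpoint, stay float32 while the activations inside autocast are cast to float16:

import torch
import torch.nn as nn

# Parameters stay float32 as "master weights"; ops inside the autocast
# context run in float16, so the savings show up in activations, not in
# the parameter storage or the saved checkpoint.
model = nn.Linear(256, 256).cuda()
x = torch.randn(8, 256, device = "cuda")

with torch.autocast(device_type = "cuda", dtype = torch.float16):
    out = model(x)

print(next(model.parameters()).dtype)  # torch.float32 -- weights unchanged
print(out.dtype)                       # torch.float16 -- activations are cast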
4590 MB vs 4388 MB is the difference I see, which still matches the ~4% I mentioned earlier. I just expected a much larger difference than that.
I ended up training Unet1 separately from Unet2 and then splicing the models together as a workaround.
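For anyone trying the same workaround, here is a rough sketch of the generic state-dict splice. The filenames, and the assumption that each checkpoint has already been reduced to a plain Imagen state_dict (trainer checkpoints wrap extra state), are mine and not from the thread:

import torch

# Take the first unet's weights from one run and the second unet's weights
# from the other, keyed by the "unets.0." / "unets.1." prefixes seen in the
# dtype listing above, then load the merged dict into a freshly built Imagen.
ckpt_unet1 = torch.load("./checkpoint_unet1.pt", map_location = "cpu")
ckpt_unet2 = torch.load("./checkpoint_unet2.pt", map_location = "cpu")

merged = {}
merged.update({k: v for k, v in ckpt_unet1.items() if k.startswith("unets.0.")})
merged.update({k: v for k, v in ckpt_unet2.items() if k.startswith("unets.1.")})

imagen.load_state_dict(merged, strict = False)  # imagen as built in the script above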