imagen-pytorch
Any specific reason sampling is not in FP16?
During training, the forward method casts tensors to FP16, but during sampling it does not:
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
    self.print_untrained_unets()

    if not self.is_main:
        kwargs["use_tqdm"] = False

    output = self.imagen.sample(*args, device=self.device, **kwargs)
    return output
@partial(cast_torch_tensor, cast_fp16=True)
def forward(self, *args, unet_number=None, **kwargs):
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    assert (
        not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number
    ), f"you can only train unet #{self.only_train_unet_number}"

    with self.accelerator.accumulate(self.unet_being_trained):
        with self.accelerator.autocast():
            loss = self.imagen(*args, unet=self.unet_being_trained, unet_number=unet_number, **kwargs)

        if self.training:
            self.accelerator.backward(loss)

    return loss
I tried casting to FP16, but something in the sampling loop converts back to float32 even when the inputs are float16. I wonder whether you have already run into this, and whether that is the reason there is no FP16 casting during sampling.

Best regards, and thanks for the great repo.
How do I run this code? Please provide the steps for text-to-image generation.
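For reference, a minimal text-to-image flow roughly following the repo's README looks like the sketch below. The hyperparameters, captions, and random training images are purely illustrative; in practice you would train on a real image-caption dataset for far more steps.

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# a single small base unet; sizes here are illustrative, not tuned
unet = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True),
)

imagen = Imagen(
    unets = (unet,),
    image_sizes = (64,),
    timesteps = 1000,
    cond_drop_prob = 0.1,
)

trainer = ImagenTrainer(imagen).cuda()

# toy data: random images plus matching captions, just to show the call signature
images = torch.randn(4, 3, 64, 64).cuda()
texts = ['a puppy in the grass', 'a red bicycle', 'a bowl of ramen', 'a snowy mountain']

for _ in range(100):
    loss = trainer(images, texts = texts, unet_number = 1, max_batch_size = 4)
    trainer.update(unet_number = 1)

# text-to-image sampling from the trained unet
sampled_images = trainer.sample(texts = ['a puppy in the grass'], cond_scale = 3.)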
When training neural networks, especially large models, it's common to use mixed precision training to save memory and speed up computation. This means running certain operations in FP16 (half precision) while keeping FP32 (full precision) for operations where the extra precision matters. The code snippet you posted uses mixed precision during training but not during sampling. Sampling (inference) doesn't need the gradient-scaling machinery that training does, but it can still benefit from FP16 for memory efficiency and speed.
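As a generic illustration of the training side (not specific to imagen-pytorch), a standard PyTorch mixed-precision step combines autocast with a gradient scaler, roughly like this:

import torch

# toy model and optimizer; requires a CUDA device
model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 512, device='cuda')
target = torch.randn(8, 512, device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    # eligible ops run in FP16, precision-sensitive ops stay in FP32
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()  # scale the loss so FP16 gradients don't underflow
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adjusts the scale factor for the next iteration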
To let the sampling loop also benefit from FP16, you can add the cast and wrap the call in an autocast context:
import torch
from functools import partial

@torch.no_grad()
@partial(cast_torch_tensor, cast_fp16=True)
def sample(self, *args, **kwargs):
    self.print_untrained_unets()

    if not self.is_main:
        kwargs["use_tqdm"] = False

    with torch.cuda.amp.autocast():  # run inference under autocast so eligible ops use FP16
        output = self.imagen.sample(*args, device=self.device, **kwargs)

    return output
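To see where precision is lost, it can also help to inspect dtypes after sampling. A quick, hypothetical check, assuming a trainer set up as in the sketch above:

# sample once and inspect dtypes; the caption and cond_scale are illustrative
images = trainer.sample(texts = ['a photo of a dog'], cond_scale = 3.)
print(images.dtype)  # shows whether the returned images stayed in FP16 or were upcast

# model weights remain FP32 unless the model itself is halved; mixing FP16 inputs
# with FP32 weights is a common source of silent upcasting back to float32
print(next(trainer.imagen.parameters()).dtype)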
For reference, a simplified version of the cast_torch_tensor decorator that supports FP16 casting:
from functools import wraps, partial
import torch

def cast_torch_tensor(func=None, cast_fp16=False):
    # supports both @cast_torch_tensor and @partial(cast_torch_tensor, cast_fp16=True)
    if func is None:
        return partial(cast_torch_tensor, cast_fp16=cast_fp16)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # cast every tensor argument and keyword argument to half precision when requested
        args = tuple(arg.half() if cast_fp16 and isinstance(arg, torch.Tensor) else arg for arg in args)
        kwargs = {k: v.half() if cast_fp16 and isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper
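A quick, hypothetical usage example (reusing the imports above) to confirm the decorator halves tensor arguments:

@partial(cast_torch_tensor, cast_fp16=True)
def report_dtype(x):
    return x.dtype

print(report_dtype(torch.randn(2, 3)))  # torch.float16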
Hope this helps. Thanks!