Fix Z-Image FP16 overflow via downscaling
This PR improves FP16 stability for Z-Image by using scaling instead of clamping.
Because the tensor passes through Linear and RMSNorm layers, it can be scaled down in fp16 to prevent overflow.
The scale value (2^x) was chosen based on testing.
No noticeable impact on inference speed.
Tested with: Z-Image and Lumina 2.
The clamp_fp16 function can be safely removed, or kept just in case.
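To illustrate the idea with a minimal sketch (the helper names and the 2^-8 factor below are illustrative assumptions, not the PR's actual code): RMSNorm is scale-invariant up to its eps term, so downscaling the fp16 input keeps the squared activations inside the fp16 range (max ~65504) without changing the result.

import torch

DOWNSCALE = 2.0 ** -8  # hypothetical exponent; the PR picks the 2^x value empirically

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm: the mean of squares can overflow fp16 for large activations.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def rms_norm_fp16_safe(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Downscale first so x**2 stays well inside the fp16 range; RMSNorm's scale
    # invariance (up to eps) means the output is effectively unchanged.
    if x.dtype == torch.float16:
        x = x * DOWNSCALE
    return rms_norm(x, eps)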
workflow.json
this will get looked at and potentially merged after next stable!
Does this fix the fried/low-quality images with lumina 2 in fp16?
The reason I don't enable fp16 in Lumina 2 is that the Neta Yume 3.5 model breaks with fp16 + my clamping. It also breaks with the downscaling in this PR.
@vanDuven could you remove torch.float16 from supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]? At that point comfy will be happy to test it.
Afterwards, if you do manage to make Neta Yume 3.5 work with fp16, that can be a separate PR.
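For clarity, dropping fp16 from that list would leave:

supported_inference_dtypes = [torch.bfloat16, torch.float32]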
Someone made a patch that fixes the issue with Lumina, allowing it to work in fp16 just fine. But after the clamping hack was added, the patch no longer works at all, making Lumina image generation unusable on pre-30-series GPUs.
In case you are curious, this is the patch for Lumina made by a user named "reakaakasky" on Civitai:
from typing import Optional
import logging

import torch

from comfy.ldm.lumina.model import JointTransformerBlock

# Flags from the original script that toggle the two patches (values here are illustrative).
ENABLE_TORCH_COMPILE = False
ENABLE_F32_FALLBACK = True

logging.info("patching Lumina 2 JointTransformerBlock")

if ENABLE_TORCH_COMPILE:
    JointTransformerBlock.forward = torch.compile(JointTransformerBlock.forward)


# Patch to support fp16: rerun a block in fp32 only when its fp16 output overflows.
def forward_with_fp32_fallback(
    self,
    x: torch.Tensor,
    x_mask: torch.Tensor,
    freqs_cis: torch.Tensor,
    adaln_input: Optional[torch.Tensor] = None,
    transformer_options={},
):
    dtype = x.dtype
    out = self._forward(x, x_mask, freqs_cis, adaln_input, transformer_options)
    if x.dtype == torch.float16 and x.is_cuda:
        isinf, isnan = out.isinf().any(), out.isnan().any()
        if isinf or isnan:
            # print(f"inf {isinf}, nan {isnan}")
            # Recompute this block in fp32 under autocast, then cast back to fp16.
            with torch.amp.autocast_mode.autocast("cuda", torch.float32):
                out = self._forward(
                    x, x_mask, freqs_cis, adaln_input, transformer_options
                )
            # print(f"fixed out: dtype {out.dtype}, max {out.abs().max().item()}")
            out = out.to(dtype).nan_to_num()
    return out


if ENABLE_F32_FALLBACK:
    # Keep the original forward around as _forward and route calls through the fallback.
    JointTransformerBlock._forward = JointTransformerBlock.forward
    JointTransformerBlock.forward = forward_with_fp32_fallback
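For context, the patch above runs each block in fp16 as usual and only recomputes it in fp32 (under autocast) when the fp16 output contains inf or NaN, so the extra cost is limited to the blocks that actually overflow.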