
Fix Z-Image FP16 overflow via downscaling

Open vanDuven opened this issue 1 month ago • 5 comments

This PR improves FP16 stability for Z-Image by using scaling instead of clamping.

Because the tensor passes through Linear and RMSNorm layers, the fp16 tensor can be scaled down to prevent overflow. The scale value (2^x) was chosen based on testing. There is no noticeable impact on inference speed. Tested with Z-Image and Lumina 2.

The clamp_fp16 function can be safely removed, or kept just in case. Attachments: workflow.json, comparison_result
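
For readers unfamiliar with the trick, a minimal sketch of the idea (not the PR's code; the 64-wide layers, the 2^-8 exponent, and the rms_norm helper are placeholders chosen for illustration). A bias-free Linear is homogeneous, linear(x * s) == linear(x) * s, and RMSNorm divides out the overall scale, so downscaling the fp16 input by a power of two shrinks every intermediate value without changing the normalized output:

import torch

# Illustration only: why downscaling before Linear + RMSNorm is (nearly) lossless.
# Requires a PyTorch build with fp16 matmul support on the chosen device
# (any recent version on CPU or CUDA).
def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute the statistics in fp32, as typical RMSNorm implementations do,
    # then return the result in the input dtype.
    x32 = x.float()
    return (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)

torch.manual_seed(0)
linear = torch.nn.Linear(64, 64, bias=False).half()
x = torch.randn(4, 64, dtype=torch.float16) * 100.0

ref = rms_norm(linear(x))                # baseline path
scaled = rms_norm(linear(x * 2.0**-8))   # downscaled path: intermediates are 256x smaller
print(torch.allclose(ref, scaled, atol=1e-2))  # expected: True, up to fp16 rounding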

vanDuven avatar Dec 08 '25 13:12 vanDuven

This will get looked at and potentially merged after the next stable release!

Kosinkadink avatar Dec 09 '25 04:12 Kosinkadink

Does this fix the fried/low-quality images with Lumina 2 in fp16?

gelukuMLG avatar Dec 10 '25 13:12 gelukuMLG

The reason I don't enable fp16 in Lumina 2 is that the Neta Yume 3.5 model breaks with fp16 + my clamping. It also breaks when using the downscaling in this PR.

comfyanonymous avatar Dec 10 '25 23:12 comfyanonymous

@vanDuven could you remove torch.float16 from supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]? At that point comfy will be happy to test it.

Afterwards, if you do manage to make Neta Yume 3.5 work with fp16, that can be a separate PR.
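
For concreteness, a sketch of the requested edit (the file and class holding this list are not shown in the thread, so the surrounding context is an assumption):

import torch

# Hypothetical illustration of the requested change: drop torch.float16 so the
# model is never auto-selected to run in fp16, leaving bf16 and fp32 as the
# dtypes ComfyUI may pick for inference.
supported_inference_dtypes = [torch.bfloat16, torch.float32]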

Kosinkadink avatar Dec 11 '25 02:12 Kosinkadink

> The reason I don't enable fp16 in Lumina 2 is that the Neta Yume 3.5 model breaks with fp16 + my clamping. It also breaks when using the downscaling in this PR.

Someone made a patch that fixes the issue with Lumina, allowing it to work in fp16 just fine. But after that hack, the patch no longer works at all, making Lumina image generation unusable on pre-30-series GPUs.

In case you are curious, this is the patch for Lumina made by a user named "reakaakasky" on Civitai.

from comfy.ldm.lumina.model import JointTransformerBlock
import torch
from typing import Optional
import logging

# Toggles assumed to be defined earlier in the user's script; defaults are
# supplied here so the snippet runs standalone.
ENABLE_TORCH_COMPILE = False
ENABLE_F32_FALLBACK = True

logging.info("patching Lumina 2 JointTransformerBlock")

if ENABLE_TORCH_COMPILE:
    JointTransformerBlock.forward = torch.compile(JointTransformerBlock.forward)

# Patch to support fp16: run the block as usual and, if the fp16 output
# contains inf/nan, rerun it under an fp32 autocast context and cast back.
def forward_with_fp32_fallback(
    self,
    x: torch.Tensor,
    x_mask: torch.Tensor,
    freqs_cis: torch.Tensor,
    adaln_input: Optional[torch.Tensor] = None,
    transformer_options={},
):
    dtype = x.dtype
    out = self._forward(x, x_mask, freqs_cis, adaln_input, transformer_options)
    if x.dtype == torch.float16 and x.is_cuda:
        isinf, isnan = out.isinf().any(), out.isnan().any()
        if isinf or isnan:
            # Output overflowed in fp16: redo the block under fp32 autocast.
            with torch.amp.autocast_mode.autocast("cuda", torch.float32):
                out = self._forward(
                    x, x_mask, freqs_cis, adaln_input, transformer_options
                )
            # Cast back down and scrub any remaining nan/inf.
            out = out.to(dtype).nan_to_num()
    return out

if ENABLE_F32_FALLBACK:
    # Keep the original forward reachable as _forward, then install the wrapper.
    JointTransformerBlock._forward = JointTransformerBlock.forward
    JointTransformerBlock.forward = forward_with_fp32_fallback
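
Note that the fallback only recomputes a block whose fp16 output actually contains inf or nan, so well-behaved blocks keep their fp16 speed and only the overflowing ones pay for an fp32 rerun.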

gelukuMLG avatar Dec 11 '25 11:12 gelukuMLG