dequantization offload accounting (fixes Flux2 OOMs - incl TEs)

Open rattus128 opened this issue 1 month ago • 1 comments

This is hopefully the full root-cause fix for:

https://github.com/comfyanonymous/ComfyUI/issues/10891

Primary commit message:

commit 53bd09926cf0f680d0fd67afcb2d0a289d71940d
Author: Rattus <[email protected]>
Date:   Sun Dec 7 21:23:05 2025 +1000

    Account for dequantization and type-casts in offload costs
    
    When measuring the cost of offload, identify weights that need a type
    change or dequantization and add the size of the conversion result
    to the offload cost.
    
    This is mutually exclusive with lowvram patches, which already have
    a large conservative estimate and won't overlap the dequant cost, so
    don't double count.
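
The accounting rule described in the commit message can be sketched roughly as follows. This is a minimal illustration and not the actual ComfyUI code: the `Weight` class, the `DTYPE_BYTES` table, the `offload_cost` function, and the choice to include the stored size in the cost are all hypothetical names and assumptions made for the example.

```python
from dataclasses import dataclass

# Bytes per element for a few dtypes (illustrative subset, not ComfyUI's table).
DTYPE_BYTES = {"fp8": 1, "fp16": 2, "bf16": 2, "fp32": 4}


@dataclass
class Weight:
    """Hypothetical stand-in for one offloaded weight tensor."""
    name: str
    numel: int                    # number of elements
    stored_dtype: str             # dtype the weight is stored in
    quantized: bool = False       # True if it must be dequantized before use
    lowvram_patch_bytes: int = 0  # conservative estimate for lowvram patches, if any


def offload_cost(w: Weight, compute_dtype: str) -> int:
    """Bytes to budget on the device when this weight is streamed in for compute."""
    stored = w.numel * DTYPE_BYTES[w.stored_dtype]
    if w.lowvram_patch_bytes:
        # The patch estimate is already large and conservative, so the
        # conversion result is not added on top (no double counting).
        return stored + w.lowvram_patch_bytes
    if w.quantized or w.stored_dtype != compute_dtype:
        # A type change or dequantization materializes a second tensor in
        # the compute dtype; charge for that conversion result as well.
        return stored + w.numel * DTYPE_BYTES[compute_dtype]
    return stored


quantized = Weight("mlp.up_proj.weight", numel=1_000_000, stored_dtype="fp8", quantized=True)
plain = Weight("norm.weight", numel=1_000_000, stored_dtype="fp32")
patched = Weight("attn.q.weight", numel=1_000_000, stored_dtype="fp8",
                 quantized=True, lowvram_patch_bytes=6_000_000)

print(offload_cost(quantized, "fp32"))  # stored fp8 copy plus fp32 conversion result
print(offload_cost(plain, "fp32"))      # no conversion needed
print(offload_cost(patched, "fp32"))    # patch estimate only, no dequant cost added
```

With fp32 as the compute dtype, the conversion result of an fp8 weight is four times the size of its stored copy, so leaving it out of the estimate understates the cost substantially.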

Example Test case:

RTX3060 Flux2 workflow with ModelComputeDtype node set to fp32

Before:

Requested to load Flux2TEModel_
loaded partially; 10508.42 MB usable, 10025.59 MB loaded, 7155.01 MB offloaded, 480.00 MB buffer reserved, lowvram patches: 0
!!! Exception during processing !!! Allocation on device 
...
    x = self.mlp(x)
        ^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/text_encoders/llama.py", line 327, in forward
    return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 608, in forward
    return self.forward_comfy_cast_weights(input, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 599, in forward_comfy_cast_weights
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 127, in cast_bias_weight
    weight = weight.dequantize()
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 197, in dequantize
    return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 431, in dequantize
    plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: Allocation on device 

Got an OOM, unloading all loaded models.
Prompt executed in 7.47 seconds
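
To give a sense of scale for the failure above, the following back-of-the-envelope calculation compares the gap between the "usable" and "loaded" figures in the log with the size of one fp32 conversion result. The weight shape is purely hypothetical (it is not taken from the log or the model), and treating usable minus loaded as the remaining headroom is an assumption about how to read the log line.

```python
MIB = 1024 * 1024

# Figures copied from the "loaded partially" log line (in MB as printed).
usable_mb = 10508.42
loaded_mb = 10025.59
gap_mb = usable_mb - loaded_mb  # headroom under the assumption stated above

# Hypothetical MLP projection weight; the shape is illustrative only.
rows, cols = 5120, 32768
fp32_bytes_per_element = 4
fp32_result_mb = rows * cols * fp32_bytes_per_element / MIB

fits = fp32_result_mb <= gap_mb

print(f"gap between usable and loaded: {gap_mb:.2f} MB")
print(f"fp32 conversion result:        {fp32_result_mb:.2f} MB")
print(f"fits in the gap:               {fits}")
```

Under these assumptions a single fp32 conversion result would not fit in the gap, which is consistent with the allocation failing inside `dequantize`.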

After:

image

rattus128 • Dec 07 '25 11:12

Confirmed that it resolves https://github.com/comfyanonymous/ComfyUI/issues/10891#issuecomment-3621595666.

Balladie • Dec 07 '25 13:12

!!! Exception during processing !!! 'RMSNorm' object has no attribute 'comfy_cast_weights'
Traceback (most recent call last):
  File "/home/stable_diff/ComfyUI/execution.py", line 515, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/execution.py", line 329, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/execution.py", line 303, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "/home/stable_diff/ComfyUI/execution.py", line 291, in process_inputs
    result = f(**inputs)
  File "/home/stable_diff/ComfyUI/nodes.py", line 77, in encode
    return (clip.encode_from_tokens_scheduled(tokens), )
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 205, in encode_from_tokens_scheduled
    pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 267, in encode_from_tokens
    self.load_model()
    ~~~~~~~~~~~~~~~^^
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 301, in load_model
    model_management.load_model_gpu(self.patcher)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 706, in load_model_gpu
    return load_models_gpu([model])
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 701, in load_models_gpu
    loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 506, in model_load
    self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 536, in model_use_more_vram
    return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 965, in partially_load
    self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 934, in partially_unload
    m.prev_comfy_cast_weights = m.comfy_cast_weights
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/.local/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1964, in __getattr__
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
    )
AttributeError: 'RMSNorm' object has no attribute 'comfy_cast_weights'. Did you mean: 'comfy_patched_weights'?

I get this when running the basic HiDream dev workflow from https://comfyanonymous.github.io/ComfyUI_examples/hidream/ with a simulated 16 GB of VRAM.

comfyanonymous • Dec 09 '25 04:12

Never mind, I think it was my fault: https://github.com/comfyanonymous/ComfyUI/pull/11201

comfyanonymous • Dec 09 '25 04:12