dequantization offload accounting (fixes Flux2 OOMs - incl TEs)
This is, hopefully, the full root-cause fix for:
https://github.com/comfyanonymous/ComfyUI/issues/10891
Primary commit message:
commit 53bd09926cf0f680d0fd67afcb2d0a289d71940d
Author: Rattus <[email protected]>
Date: Sun Dec 7 21:23:05 2025 +1000
Account for dequantization and type-casts in offload costs
When measuring the cost of offload, identify weights that need a type
change or dequantization and add the size of the conversion result
to the offload cost.
This is mutually exclusive with lowvram patches, which already carry
a large conservative estimate that won't overlap the dequant cost,
so don't double count.
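For illustration, here is a minimal sketch of the accounting idea in Python. The function weight_offload_cost, its parameters, and the has_lowvram_patches flag are hypothetical names for this sketch, not the actual code touched by the commit:

import torch

def weight_offload_cost(weight: torch.Tensor, compute_dtype: torch.dtype,
                        has_lowvram_patches: bool) -> int:
    # Base cost: the stored (possibly quantized) bytes that get streamed
    # back onto the device whenever the offloaded weight is used.
    cost = weight.numel() * weight.element_size()
    if has_lowvram_patches:
        # lowvram patches already carry a large conservative estimate that
        # covers the conversion, so don't count the dequant result again.
        return cost
    if weight.dtype != compute_dtype:
        # A cast or dequantize materializes a full-size result at
        # compute_dtype on the device, so budget that as well.
        cost += weight.numel() * torch.empty((), dtype=compute_dtype).element_size()
    return cost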
Example test case:
RTX3060 Flux2 workflow with ModelComputeDtype node set to fp32
Before:
Requested to load Flux2TEModel_
loaded partially; 10508.42 MB usable, 10025.59 MB loaded, 7155.01 MB offloaded, 480.00 MB buffer reserved, lowvram patches: 0
!!! Exception during processing !!! Allocation on device
...
x = self.mlp(x)
^^^^^^^^^^^
File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/text_encoders/llama.py", line 327, in forward
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/ops.py", line 608, in forward
return self.forward_comfy_cast_weights(input, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/ops.py", line 599, in forward_comfy_cast_weights
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/ops.py", line 127, in cast_bias_weight
weight = weight.dequantize()
^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 197, in dequantize
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 431, in dequantize
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rattus/venv2/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: Allocation on device
Got an OOM, unloading all loaded models.
Prompt executed in 7.47 seconds
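The failure mode is visible in the traceback: cast_bias_weight calls weight.dequantize(), which allocates a full-precision copy on the device that the partial-load budget never accounted for. A tiny illustrative snippet of the arithmetic (the shape and dtypes are assumptions, not values taken from the workflow):

import torch

w_q = torch.zeros(4096, 4096, dtype=torch.float8_e4m3fn)  # ~16 MiB, budgeted at 1 byte/element
w = w_q.to(torch.float32)  # the cast result needs ~64 MiB more, previously unbudgeted
print(w_q.element_size(), w.element_size())  # 1 vs 4 bytes per element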
After:
Confirmed that it resolves https://github.com/comfyanonymous/ComfyUI/issues/10891#issuecomment-3621595666.
!!! Exception during processing !!! 'RMSNorm' object has no attribute 'comfy_cast_weights'
Traceback (most recent call last):
File "/home/stable_diff/ComfyUI/execution.py", line 515, in execute
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/execution.py", line 329, in get_output_data
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/execution.py", line 303, in _async_map_node_over_list
await process_inputs(input_dict, i)
File "/home/stable_diff/ComfyUI/execution.py", line 291, in process_inputs
result = f(**inputs)
File "/home/stable_diff/ComfyUI/nodes.py", line 77, in encode
return (clip.encode_from_tokens_scheduled(tokens), )
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/sd.py", line 205, in encode_from_tokens_scheduled
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
File "/home/stable_diff/ComfyUI/comfy/sd.py", line 267, in encode_from_tokens
self.load_model()
~~~~~~~~~~~~~~~^^
File "/home/stable_diff/ComfyUI/comfy/sd.py", line 301, in load_model
model_management.load_model_gpu(self.patcher)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 706, in load_model_gpu
return load_models_gpu([model])
File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 701, in load_models_gpu
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 506, in model_load
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 536, in model_use_more_vram
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 965, in partially_load
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 934, in partially_unload
m.prev_comfy_cast_weights = m.comfy_cast_weights
^^^^^^^^^^^^^^^^^^^^
File "/home/stable_diff/.local/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1964, in __getattr__
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
AttributeError: 'RMSNorm' object has no attribute 'comfy_cast_weights'. Did you mean: 'comfy_patched_weights'?
I get this when running the basic hidream dev workflow from https://comfyanonymous.github.io/ComfyUI_examples/hidream/ with simulated 16 GB VRAM.
Never mind, I think it was my fault: https://github.com/comfyanonymous/ComfyUI/pull/11201
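For context on the AttributeError above: the unload path in the traceback assumes every module carries a comfy_cast_weights attribute, which a plain torch module such as RMSNorm does not define. A minimal reproduction of the attribute pattern (the hasattr guard is shown only to illustrate the failure mode; it is not the fix from the linked PR):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.RMSNorm(8))
for m in model.modules():
    # Plain torch modules lack comfy_cast_weights, so unconditional access
    # like the traceback's model_patcher.py line 934 raises AttributeError.
    if hasattr(m, "comfy_cast_weights"):
        m.prev_comfy_cast_weights = m.comfy_cast_weights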