Conversation

@rattus128
Contributor

This is, hopefully, the full root-cause fix for:

#10891

Primary commit message:

commit 53bd09926cf0f680d0fd67afcb2d0a289d71940d
Author: Rattus <rattus128@gmail.com>
Date:   Sun Dec 7 21:23:05 2025 +1000

    Account for dequantization and type-casts in offload costs
    
    When measuring the cost of offload, identify weights that need a type
    change or dequantization and add the size of the conversion result
    to the offload cost.
    
    This is mutually exclusive with lowvram patches, which already have
    a large conservative estimate and won't overlap the dequant cost, so
    don't double count.
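
To make the accounting concrete, here is a minimal, hypothetical sketch of the idea (not the actual ComfyUI implementation): a weight that must be cast or dequantized is counted twice, once for its stored form and once for the materialized copy at the compute dtype, unless a lowvram patch already reserves a conservative estimate. The names `estimated_offload_cost` and `has_lowvram_patch` are placeholders.

```python
# Hypothetical sketch of the offload-cost idea above, not the ComfyUI code.
import torch

def estimated_offload_cost(module: torch.nn.Module,
                           compute_dtype: torch.dtype,
                           has_lowvram_patch: bool = False) -> int:
    cost = 0
    bytes_per_elem = torch.empty((), dtype=compute_dtype).element_size()
    for p in module.parameters():
        cost += p.numel() * p.element_size()  # stored (possibly quantized) weight
        # A weight whose dtype differs from the compute dtype is materialized
        # as a full-size copy during the forward pass; count that copy too,
        # unless a lowvram patch already carries a conservative reserve.
        if not has_lowvram_patch and p.dtype != compute_dtype:
            cost += p.numel() * bytes_per_elem
    return cost

# Example: an fp16 linear layer run at fp32 costs its fp16 storage plus
# the fp32 copy created by the cast.
layer = torch.nn.Linear(4096, 4096).half()
print(estimated_offload_cost(layer, torch.float32))
```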

Example Test case:

RTX3060 Flux2 workflow with ModelComputeDtype node set to fp32

Before:

Requested to load Flux2TEModel_
loaded partially; 10508.42 MB usable, 10025.59 MB loaded, 7155.01 MB offloaded, 480.00 MB buffer reserved, lowvram patches: 0
!!! Exception during processing !!! Allocation on device 
...
    x = self.mlp(x)
        ^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/text_encoders/llama.py", line 327, in forward
    return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 608, in forward
    return self.forward_comfy_cast_weights(input, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 599, in forward_comfy_cast_weights
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/ops.py", line 127, in cast_bias_weight
    weight = weight.dequantize()
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 197, in dequantize
    return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/ComfyUI/comfy/quant_ops.py", line 431, in dequantize
    plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rattus/venv2/lib/python3.12/site-packages/torch/_ops.py", line 841, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: Allocation on device 

Got an OOM, unloading all loaded models.
Prompt executed in 7.47 seconds

After:

(screenshot of the successful run after the fix)

@Balladie
Contributor

Balladie commented Dec 7, 2025

Confirmed that it resolves #10891 (comment).

comfy/sd.py (outdated)

self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher.set_model_compute_dtype(dtype)
@comfyanonymous
Owner

This is slightly wrong because the text encoder is always upcasted to fp32.

@rattus128
Contributor Author

Fixed to match:

--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -127,7 +127,8 @@ class CLIP:
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
-        self.patcher.set_model_compute_dtype(dtype)
+        #Match torch.float32 hardcode upcast in TE implemention
+        self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None

Handle the case where the attribute doesn't exist by returning a static
sentinel (distinct from None). If the sentinel is passed in as the set
value, delete the attribute.

When measuring the cost of offload, identify weights that need a type
change or dequantization and add the size of the conversion result
to the offload cost.

This is mutually exclusive with lowvram patches, which already have
a large conservative estimate and won't overlap the dequant cost, so
don't double count.

So that the loader can know the size of weights for dequant accounting.
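
The first of those commit messages describes a sentinel-based attribute helper. A minimal sketch of that pattern, assuming hypothetical helper names (this is not the actual ComfyUI API):

```python
# Minimal sketch of the sentinel pattern described above; helper names are
# hypothetical, not the actual ComfyUI code.
_MISSING = object()  # static sentinel, deliberately distinct from None

def get_attr_or_sentinel(obj, name):
    # Return the attribute if present, otherwise the sentinel (never raises).
    return getattr(obj, name, _MISSING)

def set_attr_or_delete(obj, name, value):
    # Passing the sentinel back in removes the attribute instead of storing it.
    if value is _MISSING:
        if hasattr(obj, name):
            delattr(obj, name)
    else:
        setattr(obj, name, value)
```

With this convention, code can save and restore an attribute that may never have been set, which is the kind of missing-attribute access the RMSNorm traceback below runs into.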
rattus128 force-pushed the prs/dequant-reserve branch from a03d98e to 9663183 on December 8, 2025.
@comfyanonymous
Owner

!!! Exception during processing !!! 'RMSNorm' object has no attribute 'comfy_cast_weights'
Traceback (most recent call last):
  File "/home/stable_diff/ComfyUI/execution.py", line 515, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/execution.py", line 329, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/execution.py", line 303, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "/home/stable_diff/ComfyUI/execution.py", line 291, in process_inputs
    result = f(**inputs)
  File "/home/stable_diff/ComfyUI/nodes.py", line 77, in encode
    return (clip.encode_from_tokens_scheduled(tokens), )
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 205, in encode_from_tokens_scheduled
    pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 267, in encode_from_tokens
    self.load_model()
    ~~~~~~~~~~~~~~~^^
  File "/home/stable_diff/ComfyUI/comfy/sd.py", line 301, in load_model
    model_management.load_model_gpu(self.patcher)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 706, in load_model_gpu
    return load_models_gpu([model])
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 701, in load_models_gpu
    loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 506, in model_load
    self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_management.py", line 536, in model_use_more_vram
    return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 965, in partially_load
    self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/ComfyUI/comfy/model_patcher.py", line 934, in partially_unload
    m.prev_comfy_cast_weights = m.comfy_cast_weights
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/stable_diff/.local/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1964, in __getattr__
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
    )
AttributeError: 'RMSNorm' object has no attribute 'comfy_cast_weights'. Did you mean: 'comfy_patched_weights'?

I get this when running the basic hidream dev workflow from https://comfyanonymous.github.io/ComfyUI_examples/hidream/ with simulated 16 GB of VRAM.

@comfyanonymous
Owner

Never mind, I think it was my fault: #11201

comfyanonymous merged commit e136b6d into comfyanonymous:master on Dec 9, 2025.
10 checks passed.