Commit e136b6d

dequantization offload accounting (fixes Flux2 OOMs - incl TEs) (#11171)
* make set_attr safe for non-existent attributes

  Handle the case where the attribute doesn't exist by returning a static sentinel (distinct from None). If the sentinel is passed in as the set value, delete the attribute instead.

* Account for dequantization and type-casts in offload costs

  When measuring the cost of offloading, identify weights that need a type change or a dequantization and add the size of the conversion result to the offload cost. This is mutually exclusive with lowvram patches, which already carry a large conservative estimate that won't overlap the dequant cost, so don't double count.

* Set the compute type on CLIP ModelPatchers

  So that the loader can know the size of weights for dequant accounting.
1 parent d50f342 commit e136b6d

3 files changed: 22 additions, 8 deletions

comfy/model_patcher.py

Lines changed: 13 additions & 6 deletions
@@ -35,6 +35,7 @@
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


@@ -665,12 +666,18 @@ def _load_list(self):
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
-                    weight_key = "{}.weight".format(n)
-                    bias_key = "{}.bias".format(n)
-                    if weight_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                    if bias_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                    def check_module_offload_mem(key):
+                        if key in self.patches:
+                            return low_vram_patch_estimate_vram(self.model, key)
+                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if model_dtype is None or weight is None:
+                            return 0
+                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                            return weight.numel() * model_dtype.itemsize
+                        return 0
+                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
                 loading.append((module_offload_mem, module_mem, n, m, params))
         return loading

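The rule the new check_module_offload_mem encodes: a weight that is not already in the model's compute dtype (or that is a QuantizedTensor) must be converted before use, so offloading that module also has to budget for the converted copy, sized as numel × itemsize of the compute dtype. A minimal standalone sketch of that arithmetic, assuming only torch; the helper name and tensor below are illustrative, not part of the commit:

import torch

def extra_conversion_bytes(weight, compute_dtype, is_quantized=False):
    # Sketch of the accounting above: a weight that needs a dtype change
    # (or a dequantization) adds the size of its converted form.
    if compute_dtype is None or weight is None:
        return 0
    if weight.dtype != compute_dtype or is_quantized:
        return weight.numel() * compute_dtype.itemsize
    return 0

# Example: a 10M-element weight stored in fp16 but computed in fp32
w = torch.empty(10_000_000, dtype=torch.float16)
print(extra_conversion_bytes(w, torch.float32))  # 40_000_000 bytes (~40 MB)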

comfy/sd.py

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,8 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        # Match the torch.float32 hardcoded upcast in the TE implementation
+        self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None
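With the compute dtype pinned to float32, the accounting added in model_patcher.py above can size each text-encoder weight at its upcast footprint rather than its stored footprint. A rough illustration of the difference; the tensor shape below is made up for the example, not ComfyUI code:

import torch

# A text-encoder weight stored in fp16 but upcast to float32 at compute time:
stored = torch.empty(768 * 3072, dtype=torch.float16)    # illustrative CLIP-sized MLP weight
stored_bytes = stored.numel() * stored.element_size()    # 4,718,592 bytes (~4.5 MiB)
upcast_bytes = stored.numel() * torch.float32.itemsize   # 9,437,184 bytes (~9 MiB)
print(stored_bytes, upcast_bytes)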

comfy/utils.py

Lines changed: 7 additions & 2 deletions
@@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)

+ATTR_UNSET={}
+
 def set_attr(obj, attr, value):
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
-    prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], value)
+    prev = getattr(obj, attrs[-1], ATTR_UNSET)
+    if value is ATTR_UNSET:
+        delattr(obj, attrs[-1])
+    else:
+        setattr(obj, attrs[-1], value)
     return prev

 def set_attr_param(obj, attr, value):
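A small usage sketch of the new sentinel behaviour, using toy classes and assuming comfy.utils is importable; this example is not part of the commit:

import comfy.utils

class Leaf: pass
class Root: pass

root = Root()
root.leaf = Leaf()

# Setting a previously missing attribute now returns ATTR_UNSET instead of raising AttributeError.
prev = comfy.utils.set_attr(root, "leaf.scale", 2.0)
assert prev is comfy.utils.ATTR_UNSET
assert root.leaf.scale == 2.0

# Passing the sentinel as the value deletes the attribute and returns the previous value.
old = comfy.utils.set_attr(root, "leaf.scale", comfy.utils.ATTR_UNSET)
assert old == 2.0
assert not hasattr(root.leaf, "scale")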
