34 commits
243fb59
Reduce RAM and compute time in model saving with Loras
rattus128 Dec 31, 2025
8fda2eb
ops: Do bias dtype conversion on compute stream
rattus128 Dec 22, 2025
f9a225b
mm: Implement cast buffer allocations
rattus128 Dec 23, 2025
2e22711
move string_to_seed to utils.py
rattus128 Jan 8, 2026
095478f
pinned_memory: add python
rattus128 Jan 13, 2026
b06534e
mp: wrap get_free_memory
rattus128 Jan 13, 2026
d2956bb
mp/mm: API expansions for dynamic loading
rattus128 Jan 13, 2026
f00094a
mp: add mode for non comfy weight prioritization
rattus128 Jan 13, 2026
8fe566b
ops/mp: implement aimdo
rattus128 Jan 13, 2026
4ed6c6f
models: Use CoreModelPatcher
rattus128 Jan 13, 2026
bacd916
execution: add aimdo primary pytorch cache integration
rattus128 Jan 13, 2026
33583d9
main: Go live with --fast dynamic_vram
rattus128 Jan 13, 2026
2e1c266
mm: fix sync
rattus128 Jan 13, 2026
17cdb02
write better tx commentary
rattus128 Jan 13, 2026
645c459
add missing del on unpin
rattus128 Jan 13, 2026
5916464
misc cleanup
rattus128 Jan 13, 2026
307d25e
ruff
rattus128 Jan 13, 2026
28dd1c4
sd: empty cache on tiler fallback
rattus128 Jan 13, 2026
cb41b22
clip: support assign load when taking clip from a ckpt
rattus128 Jan 15, 2026
265ae3e
sampling: improve progress meter accuracy for dynamic loading
rattus128 Jan 15, 2026
b0d6f2a
main: Rework aimdo into process
rattus128 Jan 15, 2026
b5806c8
aimdo version bump
rattus128 Jan 15, 2026
82f388f
remove junk arg
rattus128 Jan 15, 2026
932b37d
ops: defer creation of the parameters until state dict load
rattus128 Jan 18, 2026
ec5a81c
implement lightweight safetensors with READ mmap
rattus128 Jan 18, 2026
d4f8950
execution: remove per node gc.collect()
rattus128 Jan 20, 2026
2adbbd6
mm: remove left over hooks draft code
rattus128 Jan 20, 2026
f93e09a
mp: handle blank __new__ call
rattus128 Jan 20, 2026
aef8d00
nodes_model_patch: fix copy-paste coding error
rattus128 Jan 20, 2026
96e5d45
ruff
rattus128 Jan 21, 2026
6e641d8
mp: big bump on the VBAR sizes
rattus128 Jan 21, 2026
4979c07
archive the model defined dtypes
rattus128 Jan 21, 2026
65b9729
ops: fix __init__ return
rattus128 Jan 21, 2026
2d96b2f
MPDynamic: Add support for model defined dtype
rattus128 Jan 21, 2026
4 changes: 2 additions & 2 deletions comfy/audio_encoders/audio_encoders.py
@@ -25,11 +25,11 @@ def __init__(self, config):
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
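Note: the `assign=self.patcher.is_dynamic()` pattern above relies on PyTorch's Module.load_state_dict(assign=True) (available since torch 2.1), which rebinds module parameters to the state-dict tensors instead of copying into preallocated ones. A minimal standalone sketch of the difference, not part of this PR:

import torch

lin = torch.nn.Linear(4, 4, bias=False)
sd = {"weight": torch.randn(4, 4)}

# With assign=True the parameter is rebound to the state-dict tensor,
# so it shares storage with it and no copy is made.
lin.load_state_dict(sd, assign=True)
assert lin.weight.data_ptr() == sd["weight"].data_ptr()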
4 changes: 4 additions & 0 deletions comfy/cli_args.py
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

@@ -257,3 +258,6 @@ def is_valid_directory(path: str) -> str:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
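The new enables_dynamic_vram() helper gates dynamic loading on an explicit `--fast dynamic_vram` opt-in while deferring to `--highvram` and `--gpu-only`. A hypothetical call site; only enables_dynamic_vram comes from this diff:

from comfy.cli_args import enables_dynamic_vram

def choose_load_strategy():
    # Dynamic VRAM management only when opted in and the user has not
    # pinned the whole model to the GPU.
    return "dynamic" if enables_dynamic_vram() else "static"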
4 changes: 2 additions & 2 deletions comfy/clip_vision.py
@@ -47,10 +47,10 @@ def __init__(self, json_config):
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
2 changes: 1 addition & 1 deletion comfy/controlnet.py
@@ -203,7 +203,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
33 changes: 32 additions & 1 deletion comfy/k_diffusion/sampling.py
@@ -1,18 +1,49 @@
import math
import time
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from tqdm.auto import trange as trange_, tqdm

from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management


def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)

pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")

_update = pbar.update

def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iterations
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")

_update(n)

pbar.update = warmup_update
return pbar


def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

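The trange wrapper above corrects tqdm's step-rate estimate when the first iteration absorbs dynamic model-load overhead: after the second iteration it back-dates tqdm's internal start_t by one steady-state step, approximating that step's cost with the duration of the second iteration. A worked example with illustrative numbers:

# Suppose the first step takes 30 s (load + compute) and steady-state
# steps take 2 s each, with the run starting at t = 0.
i1_time = 30.0                   # wall clock after the first iteration
now = 32.0                       # wall clock after the second iteration
second_step = now - i1_time      # 2.0 s, taken as the steady-state cost
start_t = i1_time - second_step  # 28.0: the back-dated effective start
# Rate estimate after two iterations: (now - start_t) / 2 = 2.0 s/it,
# rather than now / 2 = 16.0 s/it with the load overhead included.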
4 changes: 2 additions & 2 deletions comfy/ldm/hunyuan_video/upsampler.py
@@ -109,10 +109,10 @@ def __init__(self, model_type, config):
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
81 changes: 81 additions & 0 deletions comfy/memory_management.py
@@ -0,0 +1,81 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype

def element_size(self):
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
return info.bits // 8

def numel(self):
return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
geometries = []
for t in tensors:
if t is None or isinstance(t, QuantizedTensor):
geometries.append(t)
continue
tdtype = t.dtype
if hasattr(t, "_model_dtype"):
tdtype = t._model_dtype
if dtype is not None:
tdtype = dtype
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
return geometries

def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])

if isinstance(tensor, QuantizedTensor):
inner_tensors, _ = tensor.__tensor_flatten__()
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])

if tensor is None:
return 0

size = tensor.numel() * tensor.element_size()
alignment_req = 1024
return (size + alignment_req - 1) // alignment_req * alignment_req

def interpret_gathered_like(tensors, gathered):
offset = 0
dest_views = []

if gathered.dim() != 1 or gathered.element_size() != 1:
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

for tensor in tensors:

if tensor is None:
dest_views.append(None)
continue

if isinstance(tensor, QuantizedTensor):
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
else:
templates = { "data": tensor }

actuals = {}
for attr, template in templates.items():
size = template.numel() * template.element_size()
if offset + size > gathered.numel():
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
offset += vram_aligned_size(template)

if isinstance(tensor, QuantizedTensor):
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
else:
dest_views.append(actuals["data"])

return dest_views

aimdo_allocator = None
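Taken together, these helpers let several weight tensors share one flat, 1024-byte-aligned byte buffer (in practice a single VRAM allocation) that is reinterpreted as individually shaped views. A minimal sketch using only the functions defined above; the plain CPU buffer stands in for a real device allocation:

import torch
from comfy.memory_management import vram_aligned_size, interpret_gathered_like

weights = [torch.randn(128, 64, dtype=torch.float16), torch.randn(64)]
total = sum(vram_aligned_size(t) for t in weights)  # each rounded up to 1 KiB
gathered = torch.empty(total, dtype=torch.uint8)    # one flat allocation

views = interpret_gathered_like(weights, gathered)  # zero-copy views into it
for src, dst in zip(weights, views):
    dst.copy_(src)                                  # fill the shared buffer
assert torch.equal(views[0], weights[0])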
15 changes: 7 additions & 8 deletions comfy/model_base.py
@@ -148,6 +148,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

comfy.model_management.archive_model_dtypes(self.diffusion_model)

self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@@ -298,15 +300,15 @@ def extra_conds(self, **kwargs):

return out

def load_model_weights(self, sd, unet_prefix=""):
def load_model_weights(self, sd, unet_prefix="", assign=False):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

@@ -321,18 +323,15 @@ def process_latent_in(self, latent):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])

@@ -775,8 +774,8 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
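With unet_state_dict now supplied by the caller, state_dict_for_saving() no longer materializes the diffusion model's weights itself, which is what the "Reduce RAM and compute time in model saving with Loras" commit exploits. A hypothetical caller; the surrounding variable names are illustrative:

# The caller builds the unet state dict itself, where it can be produced
# incrementally or with LoRA weights folded in before saving.
unet_sd = model.diffusion_model.state_dict()
full_sd = model.state_dict_for_saving(unet_sd, clip_state_dict=clip_sd, vae_state_dict=vae_sd)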