From 953574117da9f6961ccb6b4903ffd0186d90ce15 Mon Sep 17 00:00:00 2001
From: Jash Shah
Date: Thu, 26 Feb 2026 11:58:47 -0800
Subject: [PATCH] Fix None dereference crashes and mutable default arguments

In model_loading_utils.py, chained .get().get() would crash with
AttributeError if the first key was missing from the remapping dict.

In autoencoder_kl_cosmos.py, _WAVELETS.get(patch_method).clone() crashes
with AttributeError if patch_method is not a recognized wavelet type
(both CosmosPatchEmbed3d and CosmosUnpatcher3d).

In autoencoder_kl_wan.py, 8 forward() methods used feat_idx=[0] as a
mutable default argument. Since the list is mutated via feat_idx[0] += 1
during forward passes, the shared default list accumulates state across
calls, corrupting the cache index.
---
 .../autoencoders/autoencoder_kl_cosmos.py     | 10 ++++--
 .../models/autoencoders/autoencoder_kl_wan.py | 32 ++++++++++++++-----
 src/diffusers/models/model_loading_utils.py   |  2 +-
 3 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
index 4fe1f62890be..11305b00c05b 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -111,7 +111,10 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        wavelets = _WAVELETS.get(patch_method).clone()
+        wavelets = _WAVELETS.get(patch_method)
+        if wavelets is None:
+            raise ValueError(f"Unknown patch_method '{patch_method}'. Supported methods: {list(_WAVELETS.keys())}")
+        wavelets = wavelets.clone()
         arange = torch.arange(wavelets.shape[0])
 
         self.register_buffer("wavelets", wavelets, persistent=False)
@@ -191,7 +194,10 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        wavelets = _WAVELETS.get(patch_method).clone()
+        wavelets = _WAVELETS.get(patch_method)
+        if wavelets is None:
+            raise ValueError(f"Unknown patch_method '{patch_method}'. Supported methods: {list(_WAVELETS.keys())}")
+        wavelets = wavelets.clone()
         arange = torch.arange(wavelets.shape[0])
 
         self.register_buffer("wavelets", wavelets, persistent=False)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index ea5d2efe642f..a9121cff50f3 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -259,7 +259,9 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         else:
             self.resample = nn.Identity()
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         b, c, t, h, w = x.size()
         if self.mode == "upsample3d":
             if feat_cache is not None:
@@ -336,7 +338,9 @@ def __init__(
         self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
         self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         # Apply shortcut connection
         h = self.conv_shortcut(x)
 
@@ -449,7 +453,9 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         # First residual block
         x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
 
@@ -489,7 +495,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
         else:
             self.downsampler = None
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         x_copy = x.clone()
         for resnet in self.resnets:
             x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
@@ -580,7 +588,9 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
@@ -677,7 +687,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
         """
         Forward pass through the upsampling block.
 
@@ -689,6 +699,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         Returns:
             torch.Tensor: Output tensor
         """
+        if feat_idx is None:
+            feat_idx = [0]
         x_copy = x.clone()
 
         for resnet in self.resnets:
@@ -752,7 +764,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=None):
         """
         Forward pass through the upsampling block.
 
@@ -764,6 +776,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         Returns:
             torch.Tensor: Output tensor
         """
+        if feat_idx is None:
+            feat_idx = [0]
         for resnet in self.resnets:
             if feat_cache is not None:
                 x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
@@ -869,7 +883,9 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
+        if feat_idx is None:
+            feat_idx = [0]
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 04642ad5d401..f20a0d9a99f1 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -117,7 +117,7 @@ def _determine_device_map(
 
 def _fetch_remapped_cls_from_config(config, old_class):
     previous_class_name = old_class.__name__
-    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name, {}).get(config["norm_type"], None)
 
     # Details:
     # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818