diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
index 4fe1f62890be..11305b00c05b 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -111,7 +111,10 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        wavelets = _WAVELETS.get(patch_method).clone()
+        wavelets = _WAVELETS.get(patch_method)
+        if wavelets is None:
+            raise ValueError(f"Unknown patch_method '{patch_method}'. Supported methods: {list(_WAVELETS.keys())}")
+        wavelets = wavelets.clone()
         arange = torch.arange(wavelets.shape[0])
 
         self.register_buffer("wavelets", wavelets, persistent=False)
@@ -191,7 +194,10 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
         self.patch_size = patch_size
         self.patch_method = patch_method
 
-        wavelets = _WAVELETS.get(patch_method).clone()
+        wavelets = _WAVELETS.get(patch_method)
+        if wavelets is None:
+            raise ValueError(f"Unknown patch_method '{patch_method}'. Supported methods: {list(_WAVELETS.keys())}")
+        wavelets = wavelets.clone()
         arange = torch.arange(wavelets.shape[0])
 
         self.register_buffer("wavelets", wavelets, persistent=False)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index ea5d2efe642f..a9121cff50f3 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -259,7 +259,9 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         else:
             self.resample = nn.Identity()
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         b, c, t, h, w = x.size()
         if self.mode == "upsample3d":
             if feat_cache is not None:
@@ -336,7 +338,9 @@ def __init__(
         self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
         self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         # Apply shortcut connection
         h = self.conv_shortcut(x)
 
@@ -449,7 +453,9 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         # First residual block
         x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
 
@@ -489,7 +495,9 @@ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample
         else:
             self.downsampler = None
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         x_copy = x.clone()
         for resnet in self.resnets:
             x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
@@ -580,7 +588,9 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=None):
+        if feat_idx is None:
+            feat_idx = [0]
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
@@ -677,7 +687,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
         """
         Forward pass through the upsampling block.
 
@@ -689,6 +699,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         Returns:
             torch.Tensor: Output tensor
         """
+        if feat_idx is None:
+            feat_idx = [0]
         x_copy = x.clone()
 
         for resnet in self.resnets:
@@ -752,7 +764,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=None):
         """
         Forward pass through the upsampling block.
 
@@ -764,6 +776,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         Returns:
             torch.Tensor: Output tensor
         """
+        if feat_idx is None:
+            feat_idx = [0]
         for resnet in self.resnets:
             if feat_cache is not None:
                 x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
@@ -869,7 +883,9 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
+        if feat_idx is None:
+            feat_idx = [0]
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 04642ad5d401..f20a0d9a99f1 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -117,7 +117,7 @@ def _determine_device_map(
 
 
 def _fetch_remapped_cls_from_config(config, old_class):
     previous_class_name = old_class.__name__
-    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name, {}).get(config["norm_type"], None)
     # Details:
     # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818