diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 5c401b9d202b..fa380c6eb1f3 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -580,6 +580,8 @@ def forward(
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1
+        if torch.all(x_attn_mask):
+            x_attn_mask = None

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.noise_refiner:
@@ -608,6 +610,8 @@ def forward(
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1
+        if torch.all(cap_attn_mask):
+            cap_attn_mask = None

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.context_refiner:
@@ -633,6 +637,8 @@ def forward(
         unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(unified_item_seqlens):
             unified_attn_mask[i, :seq_len] = 1
+        if torch.all(unified_attn_mask):
+            unified_attn_mask = None

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.layers: