Skip to content

Commit 86f7a1f

Browse files
committed
fix jit classmethod -> staticmethod
1 parent f3471dd commit 86f7a1f

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

mindone/diffusers/models/attention_dispatch.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,37 +158,40 @@ class _AttentionBackendRegistry:
158158
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
159159
_checks_enabled = DIFFUSERS_ATTN_CHECKS
160160

161-
@classmethod
161+
@staticmethod
162162
def register(
163-
cls,
164163
backend: AttentionBackendName,
165164
constraints: Optional[List[Callable]] = None,
166165
supports_context_parallel: bool = False,
167166
):
167+
Registry = _AttentionBackendRegistry
168168
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
169169

170170
def decorator(func):
171-
cls._backends[backend] = func
172-
cls._constraints[backend] = constraints or []
173-
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
174-
cls._supports_context_parallel[backend] = supports_context_parallel
171+
Registry._backends[backend] = func
172+
Registry._constraints[backend] = constraints or []
173+
Registry._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
174+
Registry._supports_context_parallel[backend] = supports_context_parallel
175175
return func
176176

177177
return decorator
178178

179-
@classmethod
180-
def get_active_backend(cls):
181-
return cls._active_backend, cls._backends[cls._active_backend]
179+
@staticmethod
180+
def get_active_backend():
181+
Registry = _AttentionBackendRegistry
182+
return Registry._active_backend, Registry._backends[Registry._active_backend]
182183

183-
@classmethod
184-
def list_backends(cls):
185-
return list(cls._backends.keys())
184+
@staticmethod
185+
def list_backends():
186+
Registry = _AttentionBackendRegistry
187+
return list(Registry._backends.keys())
186188

187-
@classmethod
189+
@staticmethod
188190
def _is_context_parallel_enabled(
189-
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
191+
backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
190192
) -> bool:
191-
supports_context_parallel = backend in cls._supports_context_parallel
193+
Registry = _AttentionBackendRegistry
194+
supports_context_parallel = backend in Registry._supports_context_parallel
192195
is_degree_greater_than_1 = parallel_config is not None and (
193196
parallel_config.context_parallel_config.ring_degree > 1
194197
or parallel_config.context_parallel_config.ulysses_degree > 1

mindone/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def apply_rotary_emb(
114114
out[..., 1::2] = x1 * sin + x2 * cos
115115
return out.type_as(hidden_states)
116116

117-
query = apply_rotary_emb(query, rotary_emb)
118-
key = apply_rotary_emb(key, rotary_emb)
117+
query = apply_rotary_emb(query, *rotary_emb)
118+
key = apply_rotary_emb(key, *rotary_emb)
119119

120120
# I2V task
121121
hidden_states_img = None
@@ -420,13 +420,13 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
420420
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
421421
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
422422

423-
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1))
424-
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1))
425-
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1))
423+
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).broadcast_to((ppf, pph, ppw, -1))
424+
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).broadcast_to((ppf, pph, ppw, -1))
425+
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).broadcast_to((ppf, pph, ppw, -1))
426426

427-
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1))
428-
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1))
429-
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1))
427+
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).broadcast_to((ppf, pph, ppw, -1))
428+
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).broadcast_to((ppf, pph, ppw, -1))
429+
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).broadcast_to((ppf, pph, ppw, -1))
430430

431431
freqs_cos = mint.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
432432
freqs_sin = mint.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

0 commit comments

Comments
 (0)