Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions acestep/core/generation/handler/generate_music_payload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,33 @@ def _progress(value, desc=None):
self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu")
self.assertEqual(progress_calls[0][0], 0.99)

def test_build_success_payload_handles_missing_optional_outputs_without_progress(self):
"""It handles absent optional output keys and no progress callback."""
host = _Host()
outputs = {}
pred_wavs = torch.ones(1, 2, 8)
pred_latents_cpu = torch.ones(1, 4, 3)
time_costs = {"total_time_cost": 2.0}

payload = host._build_generate_music_success_payload(
outputs=outputs,
pred_wavs=pred_wavs,
pred_latents_cpu=pred_latents_cpu,
time_costs=time_costs,
seed_value_for_ui=11,
actual_batch_size=1,
progress=None,
)

self.assertTrue(payload["success"])
self.assertIsNone(payload["error"])
self.assertEqual(payload["status_message"], "Generation completed successfully!")
self.assertEqual(payload["extra_outputs"]["spans"], [])
self.assertIsNone(payload["extra_outputs"]["encoder_hidden_states"])
self.assertIsNone(payload["extra_outputs"]["encoder_attention_mask"])
self.assertIsNone(payload["extra_outputs"]["context_latents"])
self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu")


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion acestep/core/generation/handler/mlx_dit_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _init_mlx_dit(self, compile_model: bool = False) -> bool:
f"(mx.compile={compile_model})."
)
return True
except Exception as exc:
except Exception as exc: # noqa: BLE001
logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}")
self.mlx_decoder = None
self.use_mlx_dit = False
Expand Down
2 changes: 1 addition & 1 deletion acestep/core/generation/handler/mlx_dit_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_init_mlx_dit_success_sets_decoder(self):
fake_dit_model.MLXDiTDecoder = type(
"FakeDecoder",
(),
{"from_config": classmethod(lambda cls, _cfg: object())},
{"from_config": classmethod(lambda _cls, _cfg: object())},
)
fake_dit_convert = types.ModuleType("acestep.models.mlx.dit_convert")
fake_dit_convert.convert_and_load = Mock()
Expand Down
80 changes: 47 additions & 33 deletions acestep/core/generation/handler/mlx_vae_init_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""Unit tests for extracted MLX VAE initialization mixin."""

import importlib.util
import os
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import Mock, patch


def _load_handler_module(filename: str, module_name: str):
"""Load handler mixin module directly from file path."""
def _load_handler_module(filename: str, module_name: str) -> types.ModuleType:
"""Load a handler mixin module for isolated tests.

Args:
filename: Module filename in the current test directory.
module_name: Fully-qualified module name used for import execution.
Returns:
Loaded module object.
Raises:
FileNotFoundError, ImportError, SyntaxError: On module load failures.
"""
repo_root = Path(__file__).resolve().parents[4]
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
Expand All @@ -20,27 +26,32 @@ def _load_handler_module(filename: str, module_name: str):
"acestep.core.generation": repo_root / "acestep" / "core" / "generation",
"acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
}
for package_name, package_path in package_paths.items():
if package_name in sys.modules:
continue
package_module = types.ModuleType(package_name)
package_module.__path__ = [str(package_path)]
sys.modules[package_name] = package_module
module_path = Path(__file__).with_name(filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module


previous_modules = {name: sys.modules.get(name) for name in package_paths}
try:
for package_name, package_path in package_paths.items():
if package_name in sys.modules:
continue
package_module = types.ModuleType(package_name)
package_module.__path__ = [str(package_path)]
sys.modules[package_name] = package_module
module_path = Path(__file__).with_name(filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load module spec for {module_name}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
finally:
for package_name, previous in previous_modules.items():
if previous is None:
sys.modules.pop(package_name, None)
else:
sys.modules[package_name] = previous
MLX_VAE_INIT_MODULE = _load_handler_module(
"mlx_vae_init.py",
"acestep.core.generation.handler.mlx_vae_init",
)
MlxVaeInitMixin = MLX_VAE_INIT_MODULE.MlxVaeInitMixin


class _VaeHost(MlxVaeInitMixin):
"""Minimal host exposing VAE init state used by tests."""

Expand All @@ -50,8 +61,6 @@ def __init__(self):
self.mlx_vae = None
self.use_mlx_vae = False
self._mlx_vae_dtype = None


class _FakeMlxVae:
"""Simple fake MLX VAE object with decode/encode callables."""

Expand All @@ -75,10 +84,19 @@ def update(self, _params):
def parameters(self):
"""Return placeholder params tree."""
return {}
def _build_fake_mx_core(raise_compile: bool) -> tuple[types.ModuleType, dict[str, int]]:
"""Build fake ``mlx.core`` and compile-call tracking for tests.

Args:
raise_compile: Whether inner ``_compile`` raises ``CompileError``.
Returns:
``(fake_mx_core, calls)`` where ``calls`` tracks ``_compile`` count.
Raises:
CompileError: Raised by ``_compile`` when ``raise_compile`` is True.
"""
class CompileError(RuntimeError):
"""Raised when fake MLX compile is configured to fail."""


def _build_fake_mx_core(raise_compile: bool):
"""Build fake ``mlx.core`` module with optional compile failure."""
fake_mx_core = types.ModuleType("mlx.core")
fake_mx_core.float16 = "float16"
fake_mx_core.float32 = "float32"
Expand All @@ -92,13 +110,11 @@ def _compile(fn):
"""Track compile invocations and optionally simulate compile failure."""
calls["compile"] += 1
if raise_compile:
raise RuntimeError("compile failed")
raise CompileError("compile failed")
return fn

fake_mx_core.compile = _compile
return fake_mx_core, calls


class MlxVaeInitMixinTests(unittest.TestCase):
"""Behavior tests for extracted ``MlxVaeInitMixin``."""

Expand All @@ -125,7 +141,7 @@ def test_init_mlx_vae_success_sets_compiled_callables(self):
fake_mlx_pkg = types.ModuleType("mlx")
fake_mlx_pkg.__path__ = []
fake_utils = types.ModuleType("mlx.utils")
fake_utils.tree_map = lambda fn, params: params
fake_utils.tree_map = lambda _fn, params: params
with patch.dict(
sys.modules,
{
Expand Down Expand Up @@ -156,7 +172,7 @@ def test_init_mlx_vae_compile_failure_falls_back(self):
fake_mlx_pkg = types.ModuleType("mlx")
fake_mlx_pkg.__path__ = []
fake_utils = types.ModuleType("mlx.utils")
fake_utils.tree_map = lambda fn, params: params
fake_utils.tree_map = lambda _fn, params: params
with patch.dict(
sys.modules,
{
Expand All @@ -171,7 +187,5 @@ def test_init_mlx_vae_compile_failure_falls_back(self):
self.assertTrue(host._init_mlx_vae())
self.assertEqual(host._mlx_compiled_decode("x"), host.mlx_vae.decode("x"))
self.assertEqual(host._mlx_compiled_encode_sample("x"), host.mlx_vae.encode_and_sample("x"))


if __name__ == "__main__":
unittest.main()
5 changes: 0 additions & 5 deletions acestep/core/generation/handler/mlx_vae_native_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Unit tests for extracted native MLX VAE encode/decode mixins."""

import importlib.util
import sys
import types
Expand All @@ -9,8 +8,6 @@
import numpy as np
import torch
from unittest.mock import patch


def _load_handler_module(module_filename: str, module_name: str):
"""Load a handler mixin module directly from disk for isolated tests."""
repo_root = Path(__file__).resolve().parents[4]
Expand All @@ -34,8 +31,6 @@ def _load_handler_module(module_filename: str, module_name: str):
assert spec.loader is not None
spec.loader.exec_module(module)
return module


MLX_VAE_DECODE_NATIVE_MODULE = _load_handler_module(
"mlx_vae_decode_native.py",
"acestep.core.generation.handler.mlx_vae_decode_native",
Expand Down