diff --git a/acestep/core/generation/handler/generate_music_payload_test.py b/acestep/core/generation/handler/generate_music_payload_test.py index f8dd66ee..47a3ba8a 100644 --- a/acestep/core/generation/handler/generate_music_payload_test.py +++ b/acestep/core/generation/handler/generate_music_payload_test.py @@ -107,6 +107,33 @@ def _progress(value, desc=None): self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu") self.assertEqual(progress_calls[0][0], 0.99) + def test_build_success_payload_handles_missing_optional_outputs_without_progress(self): + """It handles absent optional output keys and no progress callback.""" + host = _Host() + outputs = {} + pred_wavs = torch.ones(1, 2, 8) + pred_latents_cpu = torch.ones(1, 4, 3) + time_costs = {"total_time_cost": 2.0} + + payload = host._build_generate_music_success_payload( + outputs=outputs, + pred_wavs=pred_wavs, + pred_latents_cpu=pred_latents_cpu, + time_costs=time_costs, + seed_value_for_ui=11, + actual_batch_size=1, + progress=None, + ) + + self.assertTrue(payload["success"]) + self.assertIsNone(payload["error"]) + self.assertEqual(payload["status_message"], "Generation completed successfully!") + self.assertEqual(payload["extra_outputs"]["spans"], []) + self.assertIsNone(payload["extra_outputs"]["encoder_hidden_states"]) + self.assertIsNone(payload["extra_outputs"]["encoder_attention_mask"]) + self.assertIsNone(payload["extra_outputs"]["context_latents"]) + self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu") + if __name__ == "__main__": unittest.main() diff --git a/acestep/core/generation/handler/mlx_dit_init.py b/acestep/core/generation/handler/mlx_dit_init.py index 3db1eb18..73ff68ea 100644 --- a/acestep/core/generation/handler/mlx_dit_init.py +++ b/acestep/core/generation/handler/mlx_dit_init.py @@ -35,7 +35,7 @@ def _init_mlx_dit(self, compile_model: bool = False) -> bool: f"(mx.compile={compile_model})." ) return True - except Exception as exc: + except Exception as exc: # noqa: BLE001 logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}") self.mlx_decoder = None self.use_mlx_dit = False diff --git a/acestep/core/generation/handler/mlx_dit_init_test.py b/acestep/core/generation/handler/mlx_dit_init_test.py index 324c2aa1..ba0b707a 100644 --- a/acestep/core/generation/handler/mlx_dit_init_test.py +++ b/acestep/core/generation/handler/mlx_dit_init_test.py @@ -74,7 +74,7 @@ def test_init_mlx_dit_success_sets_decoder(self): fake_dit_model.MLXDiTDecoder = type( "FakeDecoder", (), - {"from_config": classmethod(lambda cls, _cfg: object())}, + {"from_config": classmethod(lambda _cls, _cfg: object())}, ) fake_dit_convert = types.ModuleType("acestep.models.mlx.dit_convert") fake_dit_convert.convert_and_load = Mock() diff --git a/acestep/core/generation/handler/mlx_vae_init_test.py b/acestep/core/generation/handler/mlx_vae_init_test.py index 3bbdab2c..a4b1130d 100644 --- a/acestep/core/generation/handler/mlx_vae_init_test.py +++ b/acestep/core/generation/handler/mlx_vae_init_test.py @@ -1,5 +1,4 @@ """Unit tests for extracted MLX VAE initialization mixin.""" - import importlib.util import os import sys @@ -7,10 +6,17 @@ import unittest from pathlib import Path from unittest.mock import Mock, patch - - -def _load_handler_module(filename: str, module_name: str): - """Load handler mixin module directly from file path.""" +def _load_handler_module(filename: str, module_name: str) -> types.ModuleType: + """Load a handler mixin module for isolated tests. + + Args: + filename: Module filename in the current test directory. + module_name: Fully-qualified module name used for import execution. + Returns: + Loaded module object. + Raises: + FileNotFoundError, ImportError, SyntaxError: On module load failures. + """ repo_root = Path(__file__).resolve().parents[4] if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) @@ -20,27 +26,32 @@ def _load_handler_module(filename: str, module_name: str): "acestep.core.generation": repo_root / "acestep" / "core" / "generation", "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler", } - for package_name, package_path in package_paths.items(): - if package_name in sys.modules: - continue - package_module = types.ModuleType(package_name) - package_module.__path__ = [str(package_path)] - sys.modules[package_name] = package_module - module_path = Path(__file__).with_name(filename) - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - assert spec.loader is not None - spec.loader.exec_module(module) - return module - - + previous_modules = {name: sys.modules.get(name) for name in package_paths} + try: + for package_name, package_path in package_paths.items(): + if package_name in sys.modules: + continue + package_module = types.ModuleType(package_name) + package_module.__path__ = [str(package_path)] + sys.modules[package_name] = package_module + module_path = Path(__file__).with_name(filename) + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module spec for {module_name}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + finally: + for package_name, previous in previous_modules.items(): + if previous is None: + sys.modules.pop(package_name, None) + else: + sys.modules[package_name] = previous MLX_VAE_INIT_MODULE = _load_handler_module( "mlx_vae_init.py", "acestep.core.generation.handler.mlx_vae_init", ) MlxVaeInitMixin = MLX_VAE_INIT_MODULE.MlxVaeInitMixin - - class _VaeHost(MlxVaeInitMixin): """Minimal host exposing VAE init state used by tests.""" @@ -50,8 +61,6 @@ def __init__(self): self.mlx_vae = None self.use_mlx_vae = False self._mlx_vae_dtype = None - - class _FakeMlxVae: """Simple fake MLX VAE object with decode/encode callables.""" @@ -75,10 +84,19 @@ def update(self, _params): def parameters(self): """Return placeholder params tree.""" return {} +def _build_fake_mx_core(raise_compile: bool) -> tuple[types.ModuleType, dict[str, int]]: + """Build fake ``mlx.core`` and compile-call tracking for tests. + + Args: + raise_compile: Whether inner ``_compile`` raises ``CompileError``. + Returns: + ``(fake_mx_core, calls)`` where ``calls`` tracks ``_compile`` count. + Raises: + CompileError: Raised by ``_compile`` when ``raise_compile`` is True. + """ + class CompileError(RuntimeError): + """Raised when fake MLX compile is configured to fail.""" - -def _build_fake_mx_core(raise_compile: bool): - """Build fake ``mlx.core`` module with optional compile failure.""" fake_mx_core = types.ModuleType("mlx.core") fake_mx_core.float16 = "float16" fake_mx_core.float32 = "float32" @@ -92,13 +110,11 @@ def _compile(fn): """Track compile invocations and optionally simulate compile failure.""" calls["compile"] += 1 if raise_compile: - raise RuntimeError("compile failed") + raise CompileError("compile failed") return fn fake_mx_core.compile = _compile return fake_mx_core, calls - - class MlxVaeInitMixinTests(unittest.TestCase): """Behavior tests for extracted ``MlxVaeInitMixin``.""" @@ -125,7 +141,7 @@ def test_init_mlx_vae_success_sets_compiled_callables(self): fake_mlx_pkg = types.ModuleType("mlx") fake_mlx_pkg.__path__ = [] fake_utils = types.ModuleType("mlx.utils") - fake_utils.tree_map = lambda fn, params: params + fake_utils.tree_map = lambda _fn, params: params with patch.dict( sys.modules, { @@ -156,7 +172,7 @@ def test_init_mlx_vae_compile_failure_falls_back(self): fake_mlx_pkg = types.ModuleType("mlx") fake_mlx_pkg.__path__ = [] fake_utils = types.ModuleType("mlx.utils") - fake_utils.tree_map = lambda fn, params: params + fake_utils.tree_map = lambda _fn, params: params with patch.dict( sys.modules, { @@ -171,7 +187,5 @@ def test_init_mlx_vae_compile_failure_falls_back(self): self.assertTrue(host._init_mlx_vae()) self.assertEqual(host._mlx_compiled_decode("x"), host.mlx_vae.decode("x")) self.assertEqual(host._mlx_compiled_encode_sample("x"), host.mlx_vae.encode_and_sample("x")) - - if __name__ == "__main__": unittest.main() diff --git a/acestep/core/generation/handler/mlx_vae_native_test.py b/acestep/core/generation/handler/mlx_vae_native_test.py index b13a7048..ef895bc3 100644 --- a/acestep/core/generation/handler/mlx_vae_native_test.py +++ b/acestep/core/generation/handler/mlx_vae_native_test.py @@ -1,5 +1,4 @@ """Unit tests for extracted native MLX VAE encode/decode mixins.""" - import importlib.util import sys import types @@ -9,8 +8,6 @@ import numpy as np import torch from unittest.mock import patch - - def _load_handler_module(module_filename: str, module_name: str): """Load a handler mixin module directly from disk for isolated tests.""" repo_root = Path(__file__).resolve().parents[4] @@ -34,8 +31,6 @@ def _load_handler_module(module_filename: str, module_name: str): assert spec.loader is not None spec.loader.exec_module(module) return module - - MLX_VAE_DECODE_NATIVE_MODULE = _load_handler_module( "mlx_vae_decode_native.py", "acestep.core.generation.handler.mlx_vae_decode_native",