class MlxDitInitMixin:
    """Mixin that brings up the native MLX DiT decoder for the host handler.

    The host supplies ``self.config`` and ``self.model``. The mixin writes
    ``self.mlx_decoder``, ``self.use_mlx_dit`` and ``self.mlx_dit_compiled``.
    """

    def _init_mlx_dit(self, compile_model: bool = False) -> bool:
        """Build the MLX DiT decoder and load the converted weights into it.

        Every MLX-related import happens inside the call rather than at module
        import time.

        Args:
            compile_model: Value recorded in ``self.mlx_dit_compiled``. Nothing
                is compiled by this method itself.

        Returns:
            bool: ``True`` once the decoder has been built, loaded and stored
            on ``self``. ``False`` when ``mlx_available()`` reports no support
            (host state is left as it was) or when any step raises (host state
            is reset to the disabled values and a warning is logged).
        """
        try:
            from acestep.models.mlx import mlx_available

            if mlx_available():
                from acestep.models.mlx.dit_model import MLXDiTDecoder
                from acestep.models.mlx.dit_convert import convert_and_load

                decoder = MLXDiTDecoder.from_config(self.config)
                convert_and_load(self.model, decoder)
                # Publish the decoder only after the weights loaded cleanly.
                self.mlx_decoder = decoder
                self.use_mlx_dit = True
                self.mlx_dit_compiled = compile_model
                logger.info(
                    "[MLX-DiT] Native MLX DiT decoder initialized successfully "
                    f"(mx.compile={compile_model})."
                )
                return True

            logger.info("[MLX-DiT] MLX not available on this platform; skipping.")
            return False
        except Exception as exc:
            # Best-effort path: any failure is logged as non-fatal and the
            # MLX DiT state is switched off.
            logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}")
            self.mlx_decoder = None
            self.use_mlx_dit = self.mlx_dit_compiled = False
            return False
class _DitHost(MlxDitInitMixin):
    """Bare-bones stand-in for the handler that ``MlxDitInitMixin`` plugs into."""

    def __init__(self):
        """Seed the attributes ``_init_mlx_dit`` reads and writes."""
        # Inputs consumed by the mixin.
        self.config = dict(size="tiny")
        self.model = object()
        # Outputs, starting in the "MLX disabled" state.
        self.mlx_decoder = None
        self.use_mlx_dit = self.mlx_dit_compiled = False
+ "acestep.models.mlx.dit_model": fake_dit_model, + "acestep.models.mlx.dit_convert": fake_dit_convert, + }, + ): + self.assertTrue(host._init_mlx_dit(compile_model=True)) + self.assertTrue(host.use_mlx_dit) + self.assertTrue(host.mlx_dit_compiled) + fake_dit_convert.convert_and_load.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/generation/handler/mlx_vae_decode_native.py b/acestep/core/generation/handler/mlx_vae_decode_native.py new file mode 100644 index 00000000..23b5c21e --- /dev/null +++ b/acestep/core/generation/handler/mlx_vae_decode_native.py @@ -0,0 +1,119 @@ +"""Native MLX VAE decode helpers for latent-to-audio conversion.""" + +import math +import time as _time + +import numpy as np +import torch +from loguru import logger +from tqdm import tqdm + + +class MlxVaeDecodeNativeMixin: + """Decode MLX latents with optional overlap-discard tiling.""" + + def _resolve_mlx_decode_fn(self): + """Resolve the active MLX decode callable from compiled or model state. + + Returns: + Any: Callable that decodes ``[1, T, C]`` MLX latents. + + Raises: + RuntimeError: If no compiled callable exists and ``self.mlx_vae`` is missing. + """ + decode_fn = getattr(self, "_mlx_compiled_decode", None) + if decode_fn is not None: + return decode_fn + if self.mlx_vae is None: + raise RuntimeError("MLX VAE decode requested but mlx_vae is not initialized.") + return self.mlx_vae.decode + + def _mlx_vae_decode(self, latents_torch): + """Decode batched PyTorch latents using native MLX VAE decode. + + Args: + latents_torch: Latent tensor shaped ``[batch, channels, frames]``. + + Returns: + torch.Tensor: Decoded audio shaped ``[batch, channels, samples]``. 
+ """ + import mlx.core as mx + + t_start = _time.time() + latents_np = latents_torch.detach().cpu().float().numpy() + latents_nlc = np.transpose(latents_np, (0, 2, 1)) + batch_size = latents_nlc.shape[0] + latent_frames = latents_nlc.shape[1] + + vae_dtype = getattr(self, "_mlx_vae_dtype", mx.float32) + latents_mx = mx.array(latents_nlc).astype(vae_dtype) + t_convert = _time.time() + + decode_fn = self._resolve_mlx_decode_fn() + audio_parts = [] + for idx in range(batch_size): + decoded = self._mlx_decode_single(latents_mx[idx : idx + 1], decode_fn=decode_fn) + if decoded.dtype != mx.float32: + decoded = decoded.astype(mx.float32) + mx.eval(decoded) + audio_parts.append(np.array(decoded)) + mx.clear_cache() + + t_decode = _time.time() + audio_nlc = np.concatenate(audio_parts, axis=0) + audio_ncl = np.transpose(audio_nlc, (0, 2, 1)) + elapsed = _time.time() - t_start + logger.info( + f"[MLX-VAE] Decoded {batch_size} sample(s), {latent_frames} latent frames -> " + f"audio in {elapsed:.2f}s " + f"(convert={t_convert - t_start:.3f}s, decode={t_decode - t_convert:.2f}s, " + f"dtype={vae_dtype})" + ) + return torch.from_numpy(audio_ncl) + + def _mlx_decode_single(self, z_nlc, decode_fn=None): + """Decode a single MLX latent sample with optional tiling. + + Args: + z_nlc: MLX array in ``[1, frames, channels]`` layout. + decode_fn: Optional decode callable; falls back to compiled decode. + + Returns: + Any: MLX array in ``[1, samples, channels]`` layout. 
class MlxVaeEncodeNativeMixin:
    """Encode MLX audio samples into latent tensors with overlap-discard tiling.

    The host supplies ``self.mlx_vae`` and ``self.disable_tqdm``, and may
    supply ``self._mlx_compiled_encode_sample`` and ``self._mlx_vae_dtype``.
    Long inputs are split into overlapping windows; each window is encoded on
    its own and the overlapping context is trimmed off before the pieces are
    concatenated.
    """

    # Tiling geometry in audio samples, shared by the progress estimate and the
    # tiling loop so the two can never disagree. Written as multiples of 48000,
    # so presumably 30 s windows with 2 s of context per side at 48 kHz
    # (NOTE(review): confirm the sample rate against the caller).
    _MLX_ENCODE_CHUNK = 48000 * 30
    _MLX_ENCODE_OVERLAP = 48000 * 2

    @classmethod
    def _mlx_encode_windows(cls, sample_frames):
        """Plan the overlap-discard windows for an input of ``sample_frames``.

        Args:
            sample_frames: Length of the audio along the time axis.

        Returns:
            list[tuple[int, int, int, int]]: One
            ``(win_start, win_end, core_start, core_end)`` tuple per encode
            call. ``win_*`` bound the samples fed to the encoder; ``core_*``
            bound the part of that window whose latents are kept. Inputs no
            longer than ``_MLX_ENCODE_CHUNK`` yield a single window covering
            the whole input.
        """
        chunk, overlap = cls._MLX_ENCODE_CHUNK, cls._MLX_ENCODE_OVERLAP
        if sample_frames <= chunk:
            return [(0, sample_frames, 0, sample_frames)]

        # Each window carries `overlap` samples of context on both sides, so
        # consecutive cores advance by the chunk size minus both margins.
        stride = chunk - 2 * overlap
        windows = []
        for core_start in range(0, sample_frames, stride):
            core_end = min(core_start + stride, sample_frames)
            win_start = max(0, core_start - overlap)
            win_end = min(sample_frames, core_end + overlap)
            windows.append((win_start, win_end, core_start, core_end))
        return windows

    def _resolve_mlx_encode_fn(self):
        """Resolve the active MLX encode callable from compiled or model state.

        Returns:
            Any: ``self._mlx_compiled_encode_sample`` when it is set and not
            ``None``, otherwise ``self.mlx_vae.encode_and_sample``.

        Raises:
            RuntimeError: If no compiled callable exists and ``self.mlx_vae``
                is ``None``.
        """
        encode_fn = getattr(self, "_mlx_compiled_encode_sample", None)
        if encode_fn is not None:
            return encode_fn
        if self.mlx_vae is None:
            raise RuntimeError("MLX VAE encode requested but mlx_vae is not initialized.")
        return self.mlx_vae.encode_and_sample

    def _mlx_vae_encode_sample(self, audio_torch):
        """Encode batched PyTorch audio to MLX latents.

        Samples are encoded one at a time through ``_mlx_encode_single``; each
        latent is cast to float32 if needed, evaluated, copied to NumPy, and
        the MLX cache is cleared before the next sample.

        Args:
            audio_torch: Audio tensor. It is transposed with ``(0, 2, 1)``
                before encoding, i.e. treated as ``[batch, channels, samples]``.

        Returns:
            torch.Tensor: Latent tensor transposed back to
            ``[batch, channels, frames]``.
        """
        import mlx.core as mx

        audio_np = audio_torch.detach().cpu().float().numpy()
        audio_nlc = np.transpose(audio_np, (0, 2, 1))
        batch_size = audio_nlc.shape[0]
        sample_frames = audio_nlc.shape[1]

        # Progress total comes from the same planner the tiling loop uses.
        chunks_per_sample = len(self._mlx_encode_windows(sample_frames))
        total_work = batch_size * chunks_per_sample

        t_start = _time.time()
        vae_dtype = getattr(self, "_mlx_vae_dtype", mx.float32)
        encode_fn = self._resolve_mlx_encode_fn()

        latent_parts = []
        pbar = tqdm(
            total=total_work,
            desc=f"MLX VAE Encode (native, n={batch_size})",
            disable=self.disable_tqdm,
            unit="chunk",
        )
        # try/finally rather than `with`: only `update`/`close` are required of
        # the progress object, and the bar must be closed even when encoding
        # raises part-way through the batch.
        try:
            for idx in range(batch_size):
                single = mx.array(audio_nlc[idx : idx + 1])
                if single.dtype != vae_dtype:
                    single = single.astype(vae_dtype)
                latent = self._mlx_encode_single(single, pbar=pbar, encode_fn=encode_fn)
                if latent.dtype != mx.float32:
                    latent = latent.astype(mx.float32)
                mx.eval(latent)
                latent_parts.append(np.array(latent))
                mx.clear_cache()
        finally:
            pbar.close()

        elapsed = _time.time() - t_start
        logger.info(
            f"[MLX-VAE] Encoded {batch_size} sample(s), {sample_frames} audio frames -> "
            f"latent in {elapsed:.2f}s (dtype={vae_dtype})"
        )

        latent_nlc = np.concatenate(latent_parts, axis=0)
        latent_ncl = np.transpose(latent_nlc, (0, 2, 1))
        return torch.from_numpy(latent_ncl)

    def _mlx_encode_single(self, audio_nlc, pbar=None, encode_fn=None):
        """Encode one MLX audio sample with optional overlap-discard tiling.

        Args:
            audio_nlc: MLX array in ``[1, samples, channels]`` layout.
            pbar: Optional progress object; ``update(1)`` is called once per
                encode call.
            encode_fn: Optional encode callable; resolved through
                ``_resolve_mlx_encode_fn`` when omitted.

        Returns:
            Any: MLX array in ``[1, frames, channels]`` layout.

        Raises:
            RuntimeError: Propagated from ``_resolve_mlx_encode_fn`` when no
                encode callable is available.
        """
        import mlx.core as mx

        if encode_fn is None:
            encode_fn = self._resolve_mlx_encode_fn()

        sample_frames = audio_nlc.shape[1]
        if sample_frames <= self._MLX_ENCODE_CHUNK:
            # Short input: one pass over the untouched array, no slicing.
            result = encode_fn(audio_nlc)
            mx.eval(result)
            if pbar is not None:
                pbar.update(1)
            return result

        encoded_parts = []
        downsample_factor = None
        for win_start, win_end, core_start, core_end in self._mlx_encode_windows(sample_frames):
            chunk = audio_nlc[:, win_start:win_end, :]
            latent_chunk = encode_fn(chunk)
            mx.eval(latent_chunk)
            if downsample_factor is None:
                # Measured once, from the first window, and reused afterwards.
                downsample_factor = chunk.shape[1] / latent_chunk.shape[1]

            # Convert the sample-domain context margins into latent frames.
            trim_start = int(round((core_start - win_start) / downsample_factor))
            trim_end = int(round((win_end - core_end) / downsample_factor))
            latent_len = latent_chunk.shape[1]
            # A zero tail trim needs the explicit length: slicing to `-0`
            # would produce an empty result.
            end_idx = latent_len - trim_end if trim_end > 0 else latent_len
            encoded_parts.append(latent_chunk[:, trim_start:end_idx, :])
            if pbar is not None:
                pbar.update(1)

        return mx.concatenate(encoded_parts, axis=1)
class MlxVaeInitMixin:
    """Initialize native MLX VAE state and compiled decode/encode callables.

    The host supplies ``self.vae`` (the loaded PyTorch VAE). The mixin writes
    ``self.mlx_vae``, ``self.use_mlx_vae``, ``self._mlx_compiled_decode``,
    ``self._mlx_compiled_encode_sample`` and ``self._mlx_vae_dtype``.
    """

    @staticmethod
    def _mlx_vae_fp16_requested() -> bool:
        """Report whether ``ACESTEP_MLX_VAE_FP16`` opts in to float16 weights.

        The variable is read at call time. ``1``, ``true`` and ``yes`` enable
        float16; the comparison is case-insensitive and ignores surrounding
        whitespace, so a value such as ``" 1"`` or one with a trailing newline
        is not silently treated as "off". Unset or any other value means
        float32.
        """
        flag = os.environ.get("ACESTEP_MLX_VAE_FP16", "0")
        return flag.strip().lower() in ("1", "true", "yes")

    def _init_mlx_vae(self) -> bool:
        """Initialize native MLX VAE runtime state from ``self.vae``.

        Converts the PyTorch VAE in ``self.vae`` to the MLX implementation,
        optionally casts its floating-point parameters to float16 (see
        ``_mlx_vae_fp16_requested``), and prepares decode/encode callables via
        ``mx.compile``.

        Side effects:
            On success sets ``self.mlx_vae``, ``self.use_mlx_vae = True``,
            ``self._mlx_compiled_decode``, ``self._mlx_compiled_encode_sample``
            and ``self._mlx_vae_dtype``. When an exception escapes, those
            attributes are reset to ``None`` / ``False``. When MLX is reported
            unavailable, no attribute is touched.

        Returns:
            bool: ``True`` when MLX VAE is initialized successfully, else
            ``False``.

        Error behavior:
            Never raises for initialization problems; failures are logged as
            non-fatal. A float16 cast failure falls back to float32, and an
            ``mx.compile`` failure falls back to the uncompiled methods; both
            still count as success.
        """
        try:
            from acestep.models.mlx import mlx_available

            if not mlx_available():
                logger.info("[MLX-VAE] MLX not available on this platform; skipping.")
                return False

            import mlx.core as mx
            from mlx.utils import tree_map
            from acestep.models.mlx.vae_model import MLXAutoEncoderOobleck
            from acestep.models.mlx.vae_convert import convert_and_load

            mlx_vae = MLXAutoEncoderOobleck.from_pytorch_config(self.vae)
            convert_and_load(self.vae, mlx_vae)

            use_fp16 = self._mlx_vae_fp16_requested()
            vae_dtype = mx.float16 if use_fp16 else mx.float32

            if use_fp16:
                try:

                    def _to_fp16(value: Any):
                        """Cast floating MLX arrays to float16 while preserving other values."""
                        if isinstance(value, mx.array) and mx.issubdtype(value.dtype, mx.floating):
                            return value.astype(mx.float16)
                        return value

                    mlx_vae.update(tree_map(_to_fp16, mlx_vae.parameters()))
                    mx.eval(mlx_vae.parameters())
                    logger.info("[MLX-VAE] Model weights converted to float16.")
                except Exception as exc:
                    # The cast is an optional optimization; report float32
                    # instead of failing the whole initialization.
                    logger.warning(f"[MLX-VAE] Float16 conversion failed ({exc}); using float32.")
                    vae_dtype = mx.float32

            compiled = True
            try:
                self._mlx_compiled_decode = mx.compile(mlx_vae.decode)
                self._mlx_compiled_encode_sample = mx.compile(mlx_vae.encode_and_sample)
                logger.info("[MLX-VAE] Decode/encode compiled with mx.compile().")
            except Exception as exc:
                compiled = False
                logger.warning(f"[MLX-VAE] mx.compile() failed ({exc}); using uncompiled path.")
                # Overwrite both so decode and encode never end up in a mixed
                # compiled/uncompiled state.
                self._mlx_compiled_decode = mlx_vae.decode
                self._mlx_compiled_encode_sample = mlx_vae.encode_and_sample

            self.mlx_vae = mlx_vae
            self.use_mlx_vae = True
            self._mlx_vae_dtype = vae_dtype
            logger.info(
                f"[MLX-VAE] Native MLX VAE initialized (dtype={vae_dtype}, compiled={compiled})."
            )
            return True
        except Exception as exc:
            logger.warning(f"[MLX-VAE] Failed to initialize MLX VAE (non-fatal): {exc}")
            self.mlx_vae = None
            self.use_mlx_vae = False
            self._mlx_compiled_decode = None
            self._mlx_compiled_encode_sample = None
            self._mlx_vae_dtype = None
            return False
_VaeHost(MlxVaeInitMixin): + """Minimal host exposing VAE init state used by tests.""" + + def __init__(self): + """Initialize deterministic VAE placeholders.""" + self.vae = object() + self.mlx_vae = None + self.use_mlx_vae = False + self._mlx_vae_dtype = None + + +class _FakeMlxVae: + """Simple fake MLX VAE object with decode/encode callables.""" + + def decode(self, value): + """Return tagged decode output.""" + return ("decode", value) + + def encode_and_sample(self, value): + """Return tagged encode output.""" + return ("encode", value) + + @classmethod + def from_pytorch_config(cls, _vae): + """Build fake MLX VAE from PyTorch VAE placeholder.""" + return cls() + + def update(self, _params): + """Accept parameter update requests.""" + return None + + def parameters(self): + """Return placeholder params tree.""" + return {} + + +def _build_fake_mx_core(raise_compile: bool): + """Build fake ``mlx.core`` module with optional compile failure.""" + fake_mx_core = types.ModuleType("mlx.core") + fake_mx_core.float16 = "float16" + fake_mx_core.float32 = "float32" + fake_mx_core.floating = "floating" + fake_mx_core.array = tuple + fake_mx_core.issubdtype = lambda *_args, **_kwargs: False + fake_mx_core.eval = lambda *_args, **_kwargs: None + calls = {"compile": 0} + + def _compile(fn): + """Track compile invocations and optionally simulate compile failure.""" + calls["compile"] += 1 + if raise_compile: + raise RuntimeError("compile failed") + return fn + + fake_mx_core.compile = _compile + return fake_mx_core, calls + + +class MlxVaeInitMixinTests(unittest.TestCase): + """Behavior tests for extracted ``MlxVaeInitMixin``.""" + + def test_init_mlx_vae_unavailable_returns_false(self): + """It returns False and leaves MLX VAE flags unset when unavailable.""" + host = _VaeHost() + fake_mlx = types.ModuleType("acestep.models.mlx") + fake_mlx.mlx_available = lambda: False + with patch.dict(sys.modules, {"acestep.models.mlx": fake_mlx}): + 
self.assertFalse(host._init_mlx_vae()) + self.assertIsNone(host.mlx_vae) + self.assertFalse(host.use_mlx_vae) + + def test_init_mlx_vae_success_sets_compiled_callables(self): + """It initializes MLX VAE and stores compiled decode/encode callables.""" + host = _VaeHost() + fake_mx_core, calls = _build_fake_mx_core(raise_compile=False) + fake_mlx = types.ModuleType("acestep.models.mlx") + fake_mlx.mlx_available = lambda: True + fake_vae_model = types.ModuleType("acestep.models.mlx.vae_model") + fake_vae_model.MLXAutoEncoderOobleck = _FakeMlxVae + fake_vae_convert = types.ModuleType("acestep.models.mlx.vae_convert") + fake_vae_convert.convert_and_load = Mock() + fake_mlx_pkg = types.ModuleType("mlx") + fake_mlx_pkg.__path__ = [] + fake_utils = types.ModuleType("mlx.utils") + fake_utils.tree_map = lambda fn, params: params + with patch.dict( + sys.modules, + { + "mlx": fake_mlx_pkg, + "mlx.core": fake_mx_core, + "mlx.utils": fake_utils, + "acestep.models.mlx": fake_mlx, + "acestep.models.mlx.vae_model": fake_vae_model, + "acestep.models.mlx.vae_convert": fake_vae_convert, + }, + ): + with patch.dict(os.environ, {"ACESTEP_MLX_VAE_FP16": "0"}, clear=False): + self.assertTrue(host._init_mlx_vae()) + self.assertEqual(calls["compile"], 2) + self.assertTrue(host.use_mlx_vae) + self.assertEqual(host._mlx_vae_dtype, "float32") + + def test_init_mlx_vae_compile_failure_falls_back(self): + """It keeps uncompiled methods when ``mx.compile`` fails.""" + host = _VaeHost() + fake_mx_core, _calls = _build_fake_mx_core(raise_compile=True) + fake_mlx = types.ModuleType("acestep.models.mlx") + fake_mlx.mlx_available = lambda: True + fake_vae_model = types.ModuleType("acestep.models.mlx.vae_model") + fake_vae_model.MLXAutoEncoderOobleck = _FakeMlxVae + fake_vae_convert = types.ModuleType("acestep.models.mlx.vae_convert") + fake_vae_convert.convert_and_load = Mock() + fake_mlx_pkg = types.ModuleType("mlx") + fake_mlx_pkg.__path__ = [] + fake_utils = types.ModuleType("mlx.utils") + 
def _load_handler_module(module_filename: str, module_name: str):
    """Load a handler mixin module directly from disk for isolated tests.

    The module is executed from the file next to this test file instead of
    going through a normal package import.

    Side effects:
        * The repository root (four directories above this file's directory)
          is prepended to ``sys.path`` when absent.
        * For each of ``acestep``, ``acestep.core``, ``acestep.core.generation``
          and ``acestep.core.generation.handler`` not already in
          ``sys.modules``, an empty stub package pointing at the real directory
          is registered. The stubs are left in place after the call.
        * The loaded module itself is NOT registered in ``sys.modules``.

    Args:
        module_filename: File name of the module, resolved next to this file.
        module_name: Fully qualified name given to the loaded module.

    Returns:
        types.ModuleType: The executed module.

    Raises:
        ImportError: If no import spec or loader can be built for the file
            (for example an unsupported file suffix).
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

    # Walk acestep -> core -> generation -> handler, registering a stub for
    # every level that is not imported yet.
    package_dir = repo_root
    package_name = ""
    for part in ("acestep", "core", "generation", "handler"):
        package_dir = package_dir / part
        package_name = f"{package_name}.{part}" if package_name else part
        if package_name in sys.modules:
            continue
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_dir)]
        sys.modules[package_name] = package_module

    module_path = Path(__file__).with_name(module_filename)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate explicitly: `assert` disappears under `python -O`, and a None
    # spec would otherwise fail inside module_from_spec with an opaque
    # AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {module_name!r} from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class MlxVaeNativeMixinTests(unittest.TestCase):
    """Exercise the native MLX VAE encode/decode mixins against a NumPy-backed fake."""

    @staticmethod
    def _mlx_stub():
        """Return a context manager installing fake ``mlx`` and ``mlx.core`` modules."""
        package = types.ModuleType("mlx")
        package.__path__ = []
        return patch.dict(
            sys.modules,
            {"mlx": package, "mlx.core": _fake_mx_core_module()},
        )

    @staticmethod
    def _double_time_axis(chunk):
        """Fake decoder: repeat every frame twice along the time axis."""
        return np.repeat(chunk, 2, axis=1)

    def test_mlx_decode_single_without_tiling_uses_decode_fn(self):
        """A short latent is decoded in a single call, with no tiling."""
        host = _Host()
        latents = np.ones((1, 64, 1), dtype=np.float32)
        with self._mlx_stub():
            out = host._mlx_decode_single(latents, decode_fn=self._double_time_axis)
        self.assertEqual(tuple(out.shape), (1, 128, 1))

    def test_mlx_decode_single_with_tiling_concatenates_trimmed_chunks(self):
        """A long latent is tiled and the trimmed pieces add up to the full length."""
        host = _Host()
        latents = np.ones((1, 2200, 1), dtype=np.float32)
        with self._mlx_stub():
            out = host._mlx_decode_single(latents, decode_fn=self._double_time_axis)
        self.assertEqual(tuple(out.shape), (1, 4400, 1))

    def test_mlx_vae_decode_returns_torch_tensor_with_expected_shape(self):
        """Batched NCL latents come back as an NCL torch tensor."""
        host = _Host()
        latents = torch.ones(2, 1, 32, dtype=torch.float32)
        with self._mlx_stub():
            out = host._mlx_vae_decode(latents)
        self.assertEqual(tuple(out.shape), (2, 1, 64))
        self.assertIsInstance(out, torch.Tensor)

    def test_mlx_encode_single_without_tiling_updates_progress(self):
        """Short audio is encoded in one pass and ticks the progress bar once."""
        host = _Host()
        progress = _Progress()
        audio = np.ones((1, 1000, 1), dtype=np.float32)
        with self._mlx_stub():
            out = host._mlx_encode_single(audio, pbar=progress)
        self.assertEqual(tuple(out.shape), (1, 500, 1))
        self.assertEqual(progress.count, 1)

    def test_mlx_encode_single_with_tiling_updates_progress_per_chunk(self):
        """Long audio is tiled and the progress bar ticks once per window."""
        host = _Host()
        progress = _Progress()
        audio = np.ones((1, 1_500_000, 1), dtype=np.float32)
        with self._mlx_stub():
            out = host._mlx_encode_single(audio, pbar=progress)
        self.assertGreater(out.shape[1], 0)
        self.assertEqual(progress.count, 2)

    def test_mlx_vae_encode_sample_returns_torch_tensor(self):
        """Batched audio is encoded into an NCL torch latent tensor."""
        host = _Host()
        audio = torch.ones(2, 1, 1200, dtype=torch.float32)
        # Swap the module's tqdm for the counting fake.
        with self._mlx_stub(), patch.object(MLX_VAE_ENCODE_NATIVE_MODULE, "tqdm", _Progress):
            out = host._mlx_vae_encode_sample(audio)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(tuple(out.shape), (2, 1, 600))

    def test_mlx_decode_single_raises_when_mlx_vae_missing(self):
        """Decoding without any MLX VAE state fails with a clear RuntimeError."""
        host = _Host()
        host._mlx_compiled_decode = None
        host.mlx_vae = None
        latents = np.ones((1, 16, 1), dtype=np.float32)
        with self._mlx_stub(), self.assertRaises(RuntimeError) as ctx:
            host._mlx_decode_single(latents)
        self.assertIn("mlx_vae is not initialized", str(ctx.exception))

    def test_mlx_encode_single_raises_when_mlx_vae_missing(self):
        """Encoding without any MLX VAE state fails with a clear RuntimeError."""
        host = _Host()
        host._mlx_compiled_encode_sample = None
        host.mlx_vae = None
        audio = np.ones((1, 1000, 1), dtype=np.float32)
        with self._mlx_stub(), self.assertRaises(RuntimeError) as ctx:
            host._mlx_encode_single(audio)
        self.assertIn("mlx_vae is not initialized", str(ctx.exception))
MetadataMixin, + MlxDitInitMixin, + MlxVaeDecodeNativeMixin, + MlxVaeEncodeNativeMixin, + MlxVaeInitMixin, PaddingMixin, ProgressMixin, PromptMixin, + ServiceGenerateMixin, TrainingPresetMixin, TaskUtilsMixin, VaeDecodeChunksMixin, @@ -167,493 +166,6 @@ def __init__(self): self.mlx_vae = None self.use_mlx_vae = False - # ------------------------------------------------------------------ - # MLX DiT acceleration helpers - # ------------------------------------------------------------------ - def _init_mlx_dit(self, compile_model: bool = False) -> bool: - """Try to initialize the native MLX DiT decoder for Apple Silicon. - - Args: - compile_model: If True, the diffusion step will be compiled with - ``mx.compile`` for kernel fusion during generation. The - compilation itself happens lazily in ``mlx_generate_diffusion``. - - Returns True on success, False on failure (non-fatal). - """ - try: - from acestep.models.mlx import mlx_available - if not mlx_available(): - logger.info("[MLX-DiT] MLX not available on this platform; skipping.") - return False - - from acestep.models.mlx.dit_model import MLXDiTDecoder - from acestep.models.mlx.dit_convert import convert_and_load - - mlx_decoder = MLXDiTDecoder.from_config(self.config) - convert_and_load(self.model, mlx_decoder) - self.mlx_decoder = mlx_decoder - self.use_mlx_dit = True - self.mlx_dit_compiled = compile_model - logger.info( - f"[MLX-DiT] Native MLX DiT decoder initialized successfully " - f"(mx.compile={compile_model})." - ) - return True - except Exception as exc: - logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}") - self.mlx_decoder = None - self.use_mlx_dit = False - self.mlx_dit_compiled = False - return False - - # ------------------------------------------------------------------ - # MLX VAE acceleration helpers - # ------------------------------------------------------------------ - def _init_mlx_vae(self) -> bool: - """Try to initialize the native MLX VAE for Apple Silicon. 
- - Converts the PyTorch ``AutoencoderOobleck`` weights into a pure-MLX - re-implementation. The PyTorch VAE is kept as a fallback. - - Performance optimizations applied: - - Float16 inference: ~2x throughput from doubled memory bandwidth - on Apple Silicon. Snake1d uses mixed precision internally. - Set ACESTEP_MLX_VAE_FP16=1 to enable float16 inference. - - mx.compile(): kernel fusion reduces Metal dispatch overhead and - improves data locality (used by mlx-lm, vllm-mlx, mlx-audio). - - Returns True on success, False on failure (non-fatal). - """ - try: - from acestep.models.mlx import mlx_available - if not mlx_available(): - logger.info("[MLX-VAE] MLX not available on this platform; skipping.") - return False - - import os - import mlx.core as mx - from mlx.utils import tree_map - from acestep.models.mlx.vae_model import MLXAutoEncoderOobleck - from acestep.models.mlx.vae_convert import convert_and_load - - mlx_vae = MLXAutoEncoderOobleck.from_pytorch_config(self.vae) - convert_and_load(self.vae, mlx_vae) - - # --- Float16 conversion for faster inference --- - # NOTE: Float16 causes audible quality degradation in the Oobleck - # VAE decoder (the Snake activation and ConvTranspose1d chain - # amplify rounding errors). Default to float32 for quality. - # Set ACESTEP_MLX_VAE_FP16=1 to enable float16 inference. 
- use_fp16 = os.environ.get("ACESTEP_MLX_VAE_FP16", "0").lower() in ( - "1", "true", "yes", - ) - vae_dtype = mx.float16 if use_fp16 else mx.float32 - - if use_fp16: - try: - def _to_fp16(x): - """Cast floating MLX arrays to float16 and keep other values unchanged.""" - if isinstance(x, mx.array) and mx.issubdtype(x.dtype, mx.floating): - return x.astype(mx.float16) - return x - mlx_vae.update(tree_map(_to_fp16, mlx_vae.parameters())) - mx.eval(mlx_vae.parameters()) - logger.info("[MLX-VAE] Model weights converted to float16.") - except Exception as e: - logger.warning(f"[MLX-VAE] Float16 conversion failed ({e}); using float32.") - vae_dtype = mx.float32 - - # --- Compile decode / encode for kernel fusion --- - try: - self._mlx_compiled_decode = mx.compile(mlx_vae.decode) - self._mlx_compiled_encode_sample = mx.compile(mlx_vae.encode_and_sample) - logger.info("[MLX-VAE] Decode/encode compiled with mx.compile().") - except Exception as e: - logger.warning(f"[MLX-VAE] mx.compile() failed ({e}); using uncompiled path.") - self._mlx_compiled_decode = mlx_vae.decode - self._mlx_compiled_encode_sample = mlx_vae.encode_and_sample - - self.mlx_vae = mlx_vae - self.use_mlx_vae = True - self._mlx_vae_dtype = vae_dtype - logger.info( - f"[MLX-VAE] Native MLX VAE initialized " - f"(dtype={vae_dtype}, compiled=True)." - ) - return True - except Exception as exc: - logger.warning(f"[MLX-VAE] Failed to initialize MLX VAE (non-fatal): {exc}") - self.mlx_vae = None - self.use_mlx_vae = False - return False - - def _mlx_vae_decode(self, latents_torch): - """Decode latents using native MLX VAE. - - Args: - latents_torch: PyTorch tensor [B, C, T] (NCL format). - - Returns: - PyTorch tensor [B, C_audio, T_audio] (NCL format). 
- """ - import numpy as np - import mlx.core as mx - import time as _time - - t_start = _time.time() - - latents_np = latents_torch.detach().cpu().float().numpy() - latents_nlc = np.transpose(latents_np, (0, 2, 1)) # NCL -> NLC - - B = latents_nlc.shape[0] - T = latents_nlc.shape[1] - - # Convert to model dtype (float16 for speed, float32 fallback) - vae_dtype = getattr(self, '_mlx_vae_dtype', mx.float32) - latents_mx = mx.array(latents_nlc).astype(vae_dtype) - - t_convert = _time.time() - - # Use compiled decode (kernel-fused) when available - decode_fn = getattr(self, '_mlx_compiled_decode', self.mlx_vae.decode) - - # Process batch items sequentially (peak memory stays constant) - audio_parts = [] - for b in range(B): - single = latents_mx[b : b + 1] # [1, T, C] - decoded = self._mlx_decode_single(single, decode_fn=decode_fn) - # Cast back to float32 for downstream torch compatibility - if decoded.dtype != mx.float32: - decoded = decoded.astype(mx.float32) - mx.eval(decoded) - audio_parts.append(np.array(decoded)) - mx.clear_cache() # Free intermediate buffers between samples - - t_decode = _time.time() - - audio_nlc = np.concatenate(audio_parts, axis=0) # [B, T_audio, C_audio] - audio_ncl = np.transpose(audio_nlc, (0, 2, 1)) # NLC -> NCL - - t_elapsed = _time.time() - t_start - logger.info( - f"[MLX-VAE] Decoded {B} sample(s), {T} latent frames -> " - f"audio in {t_elapsed:.2f}s " - f"(convert={t_convert - t_start:.3f}s, decode={t_decode - t_convert:.2f}s, " - f"dtype={vae_dtype})" - ) - - return torch.from_numpy(audio_ncl) - - def _mlx_decode_single(self, z_nlc, decode_fn=None): - """Decode a single sample with optional tiling for very long sequences. - - Args: - z_nlc: MLX array [1, T, C] in NLC format. - decode_fn: Compiled or plain decode callable. Falls back to - ``self._mlx_compiled_decode`` or ``self.mlx_vae.decode``. - - Returns: - MLX array [1, T_audio, C_audio] in NLC format. 
- """ - import mlx.core as mx - - if decode_fn is None: - decode_fn = getattr(self, '_mlx_compiled_decode', self.mlx_vae.decode) - - T = z_nlc.shape[1] - # MLX unified memory: much larger chunk OK than PyTorch MPS. - # 2048 latent frames ~= 87 seconds of audio; covers nearly all use cases. - MLX_CHUNK = 2048 - MLX_OVERLAP = 64 - - if T <= MLX_CHUNK: - # No tiling needed; caller handles mx.eval() - return decode_fn(z_nlc) - - # Overlap-discard tiling for very long sequences - stride = MLX_CHUNK - 2 * MLX_OVERLAP - num_steps = math.ceil(T / stride) - decoded_parts = [] - upsample_factor = None - - for i in tqdm(range(num_steps), desc="Decoding audio chunks", disable=self.disable_tqdm): - core_start = i * stride - core_end = min(core_start + stride, T) - win_start = max(0, core_start - MLX_OVERLAP) - win_end = min(T, core_end + MLX_OVERLAP) - - chunk = z_nlc[:, win_start:win_end, :] - audio_chunk = decode_fn(chunk) - mx.eval(audio_chunk) - - if upsample_factor is None: - upsample_factor = audio_chunk.shape[1] / chunk.shape[1] - - added_start = core_start - win_start - trim_start = int(round(added_start * upsample_factor)) - added_end = win_end - core_end - trim_end = int(round(added_end * upsample_factor)) - - audio_len = audio_chunk.shape[1] - end_idx = audio_len - trim_end if trim_end > 0 else audio_len - decoded_parts.append(audio_chunk[:, trim_start:end_idx, :]) - - return mx.concatenate(decoded_parts, axis=1) - - def _mlx_vae_encode_sample(self, audio_torch): - """Encode audio and sample latent using native MLX VAE. - - Args: - audio_torch: PyTorch tensor [B, C, S] (NCL format). - - Returns: - PyTorch tensor [B, C_latent, T_latent] (NCL format). 
- """ - import numpy as np - import mlx.core as mx - import time as _time - - audio_np = audio_torch.detach().cpu().float().numpy() - audio_nlc = np.transpose(audio_np, (0, 2, 1)) # NCL -> NLC - - B = audio_nlc.shape[0] - S = audio_nlc.shape[1] - - # Determine total work units for progress bar - MLX_ENCODE_CHUNK = 48000 * 30 - MLX_ENCODE_OVERLAP = 48000 * 2 - if S <= MLX_ENCODE_CHUNK: - chunks_per_sample = 1 - else: - stride = MLX_ENCODE_CHUNK - 2 * MLX_ENCODE_OVERLAP - chunks_per_sample = math.ceil(S / stride) - total_work = B * chunks_per_sample - - t_start = _time.time() - - # Convert to model dtype (float16 for speed) - vae_dtype = getattr(self, '_mlx_vae_dtype', mx.float32) - # Use compiled encode when available - encode_fn = getattr(self, '_mlx_compiled_encode_sample', self.mlx_vae.encode_and_sample) - - latent_parts = [] - pbar = tqdm( - total=total_work, - desc=f"MLX VAE Encode (native, n={B})", - disable=self.disable_tqdm, - unit="chunk", - ) - for b in range(B): - single = mx.array(audio_nlc[b : b + 1]) # [1, S, C_audio] - if single.dtype != vae_dtype: - single = single.astype(vae_dtype) - latent = self._mlx_encode_single(single, pbar=pbar, encode_fn=encode_fn) - # Cast back to float32 for downstream torch compatibility - if latent.dtype != mx.float32: - latent = latent.astype(mx.float32) - mx.eval(latent) - latent_parts.append(np.array(latent)) - mx.clear_cache() - pbar.close() - - t_elapsed = _time.time() - t_start - logger.info( - f"[MLX-VAE] Encoded {B} sample(s), {S} audio frames -> " - f"latent in {t_elapsed:.2f}s (dtype={vae_dtype})" - ) - - latent_nlc = np.concatenate(latent_parts, axis=0) # [B, T, C_latent] - latent_ncl = np.transpose(latent_nlc, (0, 2, 1)) # NLC -> NCL - return torch.from_numpy(latent_ncl) - - def _mlx_encode_single(self, audio_nlc, pbar=None, encode_fn=None): - """Encode a single audio sample with optional tiling. - - Args: - audio_nlc: MLX array [1, S, C_audio] in NLC format. - pbar: Optional tqdm progress bar to update. 
- encode_fn: Compiled or plain encode callable. Falls back to - ``self._mlx_compiled_encode_sample`` or - ``self.mlx_vae.encode_and_sample``. - - Returns: - MLX array [1, T_latent, C_latent] in NLC format. - """ - import mlx.core as mx - - if encode_fn is None: - encode_fn = getattr( - self, '_mlx_compiled_encode_sample', self.mlx_vae.encode_and_sample, - ) - - S = audio_nlc.shape[1] - # ~30 sec at 48 kHz (generous for MLX unified memory) - MLX_ENCODE_CHUNK = 48000 * 30 - MLX_ENCODE_OVERLAP = 48000 * 2 - - if S <= MLX_ENCODE_CHUNK: - result = encode_fn(audio_nlc) - mx.eval(result) - if pbar is not None: - pbar.update(1) - return result - - # Overlap-discard tiling - stride = MLX_ENCODE_CHUNK - 2 * MLX_ENCODE_OVERLAP - num_steps = math.ceil(S / stride) - encoded_parts = [] - downsample_factor = None - - for i in range(num_steps): - core_start = i * stride - core_end = min(core_start + stride, S) - win_start = max(0, core_start - MLX_ENCODE_OVERLAP) - win_end = min(S, core_end + MLX_ENCODE_OVERLAP) - - chunk = audio_nlc[:, win_start:win_end, :] - latent_chunk = encode_fn(chunk) - mx.eval(latent_chunk) - - if downsample_factor is None: - downsample_factor = chunk.shape[1] / latent_chunk.shape[1] - - added_start = core_start - win_start - trim_start = int(round(added_start / downsample_factor)) - added_end = win_end - core_end - trim_end = int(round(added_end / downsample_factor)) - - latent_len = latent_chunk.shape[1] - end_idx = latent_len - trim_end if trim_end > 0 else latent_len - encoded_parts.append(latent_chunk[:, trim_start:end_idx, :]) - - if pbar is not None: - pbar.update(1) - - return mx.concatenate(encoded_parts, axis=1) - - @torch.inference_mode() - def service_generate( - self, - captions: Union[str, List[str]], - lyrics: Union[str, List[str]], - keys: Optional[Union[str, List[str]]] = None, - target_wavs: Optional[torch.Tensor] = None, - refer_audios: Optional[List[List[torch.Tensor]]] = None, - metas: Optional[Union[str, Dict[str, Any], 
List[Union[str, Dict[str, Any]]]]] = None, - vocal_languages: Optional[Union[str, List[str]]] = None, - infer_steps: int = 60, - guidance_scale: float = 7.0, - seed: Optional[Union[int, List[int]]] = None, - return_intermediate: bool = False, - repainting_start: Optional[Union[float, List[float]]] = None, - repainting_end: Optional[Union[float, List[float]]] = None, - instructions: Optional[Union[str, List[str]]] = None, - audio_cover_strength: float = 1.0, - cover_noise_strength: float = 0.0, - use_adg: bool = False, - cfg_interval_start: float = 0.0, - cfg_interval_end: float = 1.0, - shift: float = 1.0, - audio_code_hints: Optional[Union[str, List[str]]] = None, - infer_method: str = "ode", - timesteps: Optional[List[float]] = None, - ) -> Dict[str, Any]: - """Generate music latents from text/audio conditioning inputs. - - Args: - captions: Caption text(s) describing target music. - lyrics: Lyric text(s) used for lyric conditioning. - keys: Optional sample identifiers. - target_wavs: Optional target audio tensor for repaint/cover. - refer_audios: Optional reference audio tensors for style conditioning. - metas: Optional metadata strings/dicts per sample. - vocal_languages: Optional lyric language code(s). - infer_steps: Diffusion inference steps. - guidance_scale: Classifier-free guidance scale. - seed: Optional single seed or per-sample seed list. - return_intermediate: Reserved compatibility flag (handled by caller flow). - repainting_start: Optional repaint start time(s) in seconds. - repainting_end: Optional repaint end time(s) in seconds. - instructions: Optional instruction text(s) per sample. - audio_cover_strength: Blend strength for cover mode. - cover_noise_strength: Initial-noise blend strength for cover mode. - use_adg: Whether to enable adaptive diffusion guidance. - cfg_interval_start: CFG schedule start ratio. - cfg_interval_end: CFG schedule end ratio. - shift: Diffusion time-shift parameter. 
- audio_code_hints: Optional serialized audio-code hints. - infer_method: Diffusion method selector. - timesteps: Optional custom timestep schedule. - - Returns: - Dict[str, Any]: Model output payload with latents, masks, spans, timing, and cached - condition tensors required by downstream result handlers. - """ - _ = return_intermediate - normalized = self._normalize_service_generate_inputs( - captions=captions, - lyrics=lyrics, - keys=keys, - metas=metas, - vocal_languages=vocal_languages, - repainting_start=repainting_start, - repainting_end=repainting_end, - instructions=instructions, - audio_code_hints=audio_code_hints, - infer_steps=infer_steps, - seed=seed, - ) - batch = self._prepare_batch( - captions=normalized["captions"], - lyrics=normalized["lyrics"], - keys=normalized["keys"], - target_wavs=target_wavs, - refer_audios=refer_audios, - metas=normalized["metas"], - vocal_languages=normalized["vocal_languages"], - repainting_start=normalized["repainting_start"], - repainting_end=normalized["repainting_end"], - instructions=normalized["instructions"], - audio_code_hints=normalized["audio_code_hints"], - audio_cover_strength=audio_cover_strength, - cover_noise_strength=cover_noise_strength, - ) - payload = self._unpack_service_processed_data(self.preprocess_batch(batch)) - seed_param = self._resolve_service_seed_param(normalized["seed_list"]) - self._ensure_silence_latent_on_device() - generate_kwargs = self._build_service_generate_kwargs( - payload=payload, - seed_param=seed_param, - infer_steps=normalized["infer_steps"], - guidance_scale=guidance_scale, - audio_cover_strength=audio_cover_strength, - cover_noise_strength=cover_noise_strength, - infer_method=infer_method, - use_adg=use_adg, - cfg_interval_start=cfg_interval_start, - cfg_interval_end=cfg_interval_end, - shift=shift, - timesteps=timesteps, - ) - outputs, encoder_hidden_states, encoder_attention_mask, context_latents = ( - self._execute_service_generate_diffusion( - payload=payload, - 
generate_kwargs=generate_kwargs, - seed_param=seed_param, - infer_method=infer_method, - shift=shift, - audio_cover_strength=audio_cover_strength, - ) - ) - return self._attach_service_generate_outputs( - outputs=outputs, - payload=payload, - batch=batch, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - context_latents=context_latents, - ) - def generate_music( self, captions: str, @@ -992,3 +504,4 @@ def generate_music( "success": False, "error": str(e), } +