Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions acestep/core/generation/handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from .lora_manager import LoraManagerMixin
from .memory_utils import MemoryUtilsMixin
from .metadata_utils import MetadataMixin
from .mlx_dit_init import MlxDitInitMixin
from .mlx_vae_decode_native import MlxVaeDecodeNativeMixin
from .mlx_vae_encode_native import MlxVaeEncodeNativeMixin
from .mlx_vae_init import MlxVaeInitMixin
from .padding_utils import PaddingMixin
from .prompt_utils import PromptMixin
from .progress import ProgressMixin
Expand Down Expand Up @@ -49,6 +53,10 @@
"LoraManagerMixin",
"MemoryUtilsMixin",
"MetadataMixin",
"MlxDitInitMixin",
"MlxVaeDecodeNativeMixin",
"MlxVaeEncodeNativeMixin",
"MlxVaeInitMixin",
"PaddingMixin",
"PromptMixin",
"ProgressMixin",
Expand Down
43 changes: 43 additions & 0 deletions acestep/core/generation/handler/mlx_dit_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""MLX DiT initialization helpers for Apple Silicon acceleration."""

from loguru import logger


class MlxDitInitMixin:
    """Mixin that wires up the native MLX DiT decoder on supported hosts."""

    def _init_mlx_dit(self, compile_model: bool = False) -> bool:
        """Set up the MLX DiT decoder when platform support is available.

        Any failure is non-fatal: MLX state is cleared and ``False`` returned.

        Args:
            compile_model: Whether MLX diffusion should use ``mx.compile``.

        Returns:
            bool: ``True`` when MLX DiT is initialized successfully, else ``False``.
        """
        try:
            from acestep.models.mlx import mlx_available

            if not mlx_available():
                logger.info("[MLX-DiT] MLX not available on this platform; skipping.")
                return False

            # Imported lazily so non-Apple platforms never pay for these modules.
            from acestep.models.mlx.dit_convert import convert_and_load
            from acestep.models.mlx.dit_model import MLXDiTDecoder

            decoder = MLXDiTDecoder.from_config(self.config)
            convert_and_load(self.model, decoder)
            self.mlx_decoder = decoder
            self.use_mlx_dit = True
            self.mlx_dit_compiled = compile_model
            logger.info(
                f"[MLX-DiT] Native MLX DiT decoder initialized successfully "
                f"(mx.compile={compile_model})."
            )
            return True
        except Exception as exc:
            # Best-effort feature: log, reset all MLX DiT state, and fall back.
            logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}")
            self.mlx_decoder = None
            self.use_mlx_dit = False
            self.mlx_dit_compiled = False
            return False
96 changes: 96 additions & 0 deletions acestep/core/generation/handler/mlx_dit_init_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Unit tests for extracted MLX DiT initialization mixin."""

import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import Mock, patch


def _load_handler_module(filename: str, module_name: str):
    """Load a handler mixin module directly from its file path.

    Stub parent packages are registered in ``sys.modules`` so the dotted
    ``module_name`` resolves without importing the real package ``__init__``.

    Args:
        filename: Basename of the module file located next to this test file.
        module_name: Fully qualified dotted name to load the module as.

    Returns:
        The executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    package_paths = {
        "acestep": repo_root / "acestep",
        "acestep.core": repo_root / "acestep" / "core",
        "acestep.core.generation": repo_root / "acestep" / "core" / "generation",
        "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
    }
    for package_name, package_path in package_paths.items():
        if package_name in sys.modules:
            continue
        # Minimal namespace-package stand-in: only __path__ is needed for
        # submodule resolution.
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_path)]
        sys.modules[package_name] = package_module
    module_path = Path(__file__).with_name(filename)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    # Register the module BEFORE executing it (importlib recipe) so that any
    # self-referential import or sys.modules lookup during execution succeeds.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Load the mixin by file path rather than via a package import.
# NOTE(review): presumably this sidesteps heavy optional dependencies pulled
# in by the package __init__ — confirm against the handler package.
MLX_DIT_INIT_MODULE = _load_handler_module(
    "mlx_dit_init.py",
    "acestep.core.generation.handler.mlx_dit_init",
)
# Re-export the class under test at module level for the test cases below.
MlxDitInitMixin = MLX_DIT_INIT_MODULE.MlxDitInitMixin


class _DitHost(MlxDitInitMixin):
    """Bare-bones host carrying only the attributes DiT init reads/writes."""

    def __init__(self):
        """Start from cleared MLX DiT state with stub config and model."""
        # MLX DiT state begins fully cleared.
        self.mlx_decoder = None
        self.use_mlx_dit = False
        self.mlx_dit_compiled = False
        # Deterministic placeholders consumed by _init_mlx_dit.
        self.config = {"size": "tiny"}
        self.model = object()


class MlxDitInitMixinTests(unittest.TestCase):
    """Behavior tests for extracted ``MlxDitInitMixin``."""

    def test_init_mlx_dit_unavailable_returns_false(self):
        """It returns False and leaves MLX DiT flags unset when unavailable."""
        host = _DitHost()
        fake_mlx = types.ModuleType("acestep.models.mlx")
        fake_mlx.mlx_available = lambda: False
        with patch.dict(sys.modules, {"acestep.models.mlx": fake_mlx}):
            self.assertFalse(host._init_mlx_dit(compile_model=True))
        self.assertIsNone(host.mlx_decoder)
        self.assertFalse(host.use_mlx_dit)
        # Early bail-out must not record the requested compile flag.
        self.assertFalse(host.mlx_dit_compiled)

    def test_init_mlx_dit_success_sets_decoder(self):
        """It loads converted MLX DiT decoder and stores compile flag."""
        host = _DitHost()
        fake_mlx = types.ModuleType("acestep.models.mlx")
        fake_mlx.mlx_available = lambda: True
        fake_dit_model = types.ModuleType("acestep.models.mlx.dit_model")
        fake_dit_model.MLXDiTDecoder = type(
            "FakeDecoder",
            (),
            {"from_config": classmethod(lambda cls, _cfg: object())},
        )
        fake_dit_convert = types.ModuleType("acestep.models.mlx.dit_convert")
        fake_dit_convert.convert_and_load = Mock()
        with patch.dict(
            sys.modules,
            {
                "acestep.models.mlx": fake_mlx,
                "acestep.models.mlx.dit_model": fake_dit_model,
                "acestep.models.mlx.dit_convert": fake_dit_convert,
            },
        ):
            self.assertTrue(host._init_mlx_dit(compile_model=True))
            self.assertTrue(host.use_mlx_dit)
            self.assertTrue(host.mlx_dit_compiled)
            # The converted decoder instance must be retained on the host.
            self.assertIsNotNone(host.mlx_decoder)
            fake_dit_convert.convert_and_load.assert_called_once()


if __name__ == "__main__":
    # Allow running this test module directly, outside a pytest runner.
    unittest.main()
119 changes: 119 additions & 0 deletions acestep/core/generation/handler/mlx_vae_decode_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Native MLX VAE decode helpers for latent-to-audio conversion."""

import math
import time as _time

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm


class MlxVaeDecodeNativeMixin:
    """Decode MLX latents with optional overlap-discard tiling.

    Hosts are expected to provide ``mlx_vae`` (set by VAE init), optionally a
    compiled ``_mlx_compiled_decode`` callable and ``_mlx_vae_dtype``, plus a
    ``disable_tqdm`` flag.
    """

    def _resolve_mlx_decode_fn(self):
        """Resolve the active MLX decode callable from compiled or model state.

        Returns:
            Any: Callable that decodes ``[1, T, C]`` MLX latents.

        Raises:
            RuntimeError: If no compiled callable exists and ``self.mlx_vae`` is
                missing or ``None``.
        """
        decode_fn = getattr(self, "_mlx_compiled_decode", None)
        if decode_fn is not None:
            return decode_fn
        # Fix: use getattr so a host that never ran VAE init raises the
        # documented RuntimeError instead of an AttributeError.
        mlx_vae = getattr(self, "mlx_vae", None)
        if mlx_vae is None:
            raise RuntimeError("MLX VAE decode requested but mlx_vae is not initialized.")
        return mlx_vae.decode

    def _mlx_vae_decode(self, latents_torch):
        """Decode batched PyTorch latents using native MLX VAE decode.

        Args:
            latents_torch: Latent tensor shaped ``[batch, channels, frames]``.

        Returns:
            torch.Tensor: Decoded audio shaped ``[batch, channels, samples]``.
        """
        import mlx.core as mx

        t_start = _time.time()
        # Move to host memory and flip NCL -> NLC (channels-last) for MLX.
        latents_np = latents_torch.detach().cpu().float().numpy()
        latents_nlc = np.transpose(latents_np, (0, 2, 1))
        batch_size = latents_nlc.shape[0]
        latent_frames = latents_nlc.shape[1]

        vae_dtype = getattr(self, "_mlx_vae_dtype", mx.float32)
        latents_mx = mx.array(latents_nlc).astype(vae_dtype)
        t_convert = _time.time()

        decode_fn = self._resolve_mlx_decode_fn()
        audio_parts = []
        for idx in range(batch_size):
            # Decode one sample at a time and clear the MLX cache after each
            # to bound peak memory use.
            decoded = self._mlx_decode_single(latents_mx[idx : idx + 1], decode_fn=decode_fn)
            if decoded.dtype != mx.float32:
                decoded = decoded.astype(mx.float32)
            mx.eval(decoded)
            audio_parts.append(np.array(decoded))
            mx.clear_cache()

        t_decode = _time.time()
        # Reassemble the batch and flip back NLC -> NCL for PyTorch callers.
        audio_nlc = np.concatenate(audio_parts, axis=0)
        audio_ncl = np.transpose(audio_nlc, (0, 2, 1))
        elapsed = _time.time() - t_start
        logger.info(
            f"[MLX-VAE] Decoded {batch_size} sample(s), {latent_frames} latent frames -> "
            f"audio in {elapsed:.2f}s "
            f"(convert={t_convert - t_start:.3f}s, decode={t_decode - t_convert:.2f}s, "
            f"dtype={vae_dtype})"
        )
        return torch.from_numpy(audio_ncl)

    def _mlx_decode_single(self, z_nlc, decode_fn=None):
        """Decode a single MLX latent sample with optional tiling.

        Long inputs are decoded in overlapping windows; the overlap regions are
        discarded after decoding so chunk boundaries do not produce artifacts.

        Args:
            z_nlc: MLX array in ``[1, frames, channels]`` layout.
            decode_fn: Optional decode callable; falls back to compiled decode.

        Returns:
            Any: MLX array in ``[1, samples, channels]`` layout.
        """
        import mlx.core as mx

        if decode_fn is None:
            decode_fn = self._resolve_mlx_decode_fn()

        latent_frames = z_nlc.shape[1]
        mlx_chunk = 2048  # max latent frames decoded in one window
        mlx_overlap = 64  # frames of context discarded on each side

        # Short inputs decode in one shot — no tiling needed.
        if latent_frames <= mlx_chunk:
            return decode_fn(z_nlc)

        # Each step keeps `stride` core frames; the window adds overlap context.
        stride = mlx_chunk - 2 * mlx_overlap
        num_steps = math.ceil(latent_frames / stride)
        decoded_parts = []
        upsample_factor = None

        for idx in tqdm(range(num_steps), desc="Decoding audio chunks", disable=self.disable_tqdm):
            core_start = idx * stride
            core_end = min(core_start + stride, latent_frames)
            win_start = max(0, core_start - mlx_overlap)
            win_end = min(latent_frames, core_end + mlx_overlap)

            chunk = z_nlc[:, win_start:win_end, :]
            audio_chunk = decode_fn(chunk)
            mx.eval(audio_chunk)
            # Latent->audio upsampling ratio measured from the first chunk;
            # assumes the decoder upsamples uniformly — TODO confirm for
            # decoders that pad their output.
            if upsample_factor is None:
                upsample_factor = audio_chunk.shape[1] / chunk.shape[1]

            # Trim the decoded overlap so only each window's core contributes.
            trim_start = int(round((core_start - win_start) * upsample_factor))
            trim_end = int(round((win_end - core_end) * upsample_factor))
            audio_len = audio_chunk.shape[1]
            end_idx = audio_len - trim_end if trim_end > 0 else audio_len
            decoded_parts.append(audio_chunk[:, trim_start:end_idx, :])

        return mx.concatenate(decoded_parts, axis=1)
Loading