From ca3fd1bc6e5c32d0b59acaa34fdb0796ec8f6a41 Mon Sep 17 00:00:00 2001 From: Rich Date: Tue, 17 Feb 2026 22:21:03 +0000 Subject: [PATCH 1/2] refactor(handler): extract generate_music orchestration mixins (part 18) --- acestep/core/generation/handler/__init__.py | 6 + .../core/generation/handler/generate_music.py | 190 ++++++++++ .../handler/generate_music_decode.py | 196 ++++++++++ .../handler/generate_music_decode_test.py | 196 ++++++++++ .../handler/generate_music_payload.py | 91 +++++ .../handler/generate_music_payload_test.py | 112 ++++++ .../generation/handler/generate_music_test.py | 197 ++++++++++ acestep/handler.py | 352 +----------------- 8 files changed, 995 insertions(+), 345 deletions(-) create mode 100644 acestep/core/generation/handler/generate_music.py create mode 100644 acestep/core/generation/handler/generate_music_decode.py create mode 100644 acestep/core/generation/handler/generate_music_decode_test.py create mode 100644 acestep/core/generation/handler/generate_music_payload.py create mode 100644 acestep/core/generation/handler/generate_music_payload_test.py create mode 100644 acestep/core/generation/handler/generate_music_test.py diff --git a/acestep/core/generation/handler/__init__.py b/acestep/core/generation/handler/__init__.py index a82955d9..76c52a7c 100644 --- a/acestep/core/generation/handler/__init__.py +++ b/acestep/core/generation/handler/__init__.py @@ -8,7 +8,10 @@ from .conditioning_target import ConditioningTargetMixin from .conditioning_text import ConditioningTextMixin from .diffusion import DiffusionMixin +from .generate_music import GenerateMusicMixin +from .generate_music_decode import GenerateMusicDecodeMixin from .generate_music_execute import GenerateMusicExecuteMixin +from .generate_music_payload import GenerateMusicPayloadMixin from .generate_music_request import GenerateMusicRequestMixin from .init_service import InitServiceMixin from .io_audio import IoAudioMixin @@ -44,7 +47,10 @@ "ConditioningTargetMixin", "ConditioningTextMixin", "DiffusionMixin", + "GenerateMusicMixin", + "GenerateMusicDecodeMixin", "GenerateMusicExecuteMixin", + "GenerateMusicPayloadMixin", "GenerateMusicRequestMixin", "InitServiceMixin", "IoAudioMixin", diff --git a/acestep/core/generation/handler/generate_music.py b/acestep/core/generation/handler/generate_music.py new file mode 100644 index 00000000..3ae2f76f --- /dev/null +++ b/acestep/core/generation/handler/generate_music.py @@ -0,0 +1,190 @@ +"""Top-level ``generate_music`` orchestration mixin. + +This module provides the public ``generate_music`` entry point extracted from +``AceStepHandler`` so orchestration stays separate from lower-level helpers. +""" + +import traceback +from typing import Any, Dict, List, Optional, Union + +from loguru import logger + +from acestep.constants import DEFAULT_DIT_INSTRUCTION + + +class GenerateMusicMixin: + """Coordinate request prep, service execution, decode, and payload assembly. + + The host class is expected to implement helper methods invoked by this + orchestration flow. 
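+
+    In call order, ``generate_music`` delegates to
+    ``_resolve_generate_music_progress``, ``_validate_generate_music_readiness``,
+    ``_resolve_generate_music_task``, ``_prepare_generate_music_runtime``,
+    ``_prepare_reference_and_source_audio``,
+    ``_prepare_generate_music_service_inputs``, and
+    ``_run_generate_music_service_with_progress``, then to the decode and
+    payload helpers contributed by ``GenerateMusicDecodeMixin`` and
+    ``GenerateMusicPayloadMixin``.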
+ """ + + def generate_music( + self, + captions: str, + lyrics: str, + bpm: Optional[int] = None, + key_scale: str = "", + time_signature: str = "", + vocal_language: str = "en", + inference_steps: int = 8, + guidance_scale: float = 7.0, + use_random_seed: bool = True, + seed: Optional[Union[str, float, int]] = -1, + reference_audio=None, + audio_duration: Optional[float] = None, + batch_size: Optional[int] = None, + src_audio=None, + audio_code_string: Union[str, List[str]] = "", + repainting_start: float = 0.0, + repainting_end: Optional[float] = None, + instruction: str = DEFAULT_DIT_INSTRUCTION, + audio_cover_strength: float = 1.0, + cover_noise_strength: float = 0.0, + task_type: str = "text2music", + use_adg: bool = False, + cfg_interval_start: float = 0.0, + cfg_interval_end: float = 1.0, + shift: float = 1.0, + infer_method: str = "ode", + use_tiled_decode: bool = True, + timesteps: Optional[List[float]] = None, + latent_shift: float = 0.0, + latent_rescale: float = 1.0, + progress=None, + ) -> Dict[str, Any]: + """Generate audio from text/reference inputs and return response payload. + + Args: + captions: Text prompt describing requested music. + lyrics: Lyric text used for conditioning. + reference_audio: Optional reference-audio payload. + src_audio: Optional source audio for repaint/cover. + inference_steps: Diffusion step count. + guidance_scale: CFG guidance value. + seed: Optional explicit seed from caller/UI. + infer_method: Diffusion method name. + timesteps: Optional custom timestep schedule. + use_tiled_decode: Whether tiled VAE decode is used. + latent_shift: Additive latent post-processing value. + latent_rescale: Multiplicative latent post-processing value. + progress: Optional callback taking ``(ratio, desc=...)``. + + Returns: + Dict[str, Any]: Standard payload with generated audio tensors, status, + intermediate outputs, success flag, and optional error text. + + Raises: + No exceptions are re-raised. Runtime failures are converted into the + returned error payload. 
+ """ + progress = self._resolve_generate_music_progress(progress) + if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None: + readiness_error = self._validate_generate_music_readiness() + return readiness_error + + task_type, instruction = self._resolve_generate_music_task( + task_type=task_type, + audio_code_string=audio_code_string, + instruction=instruction, + ) + + logger.info("[generate_music] Starting generation...") + if progress: + progress(0.51, desc="Preparing inputs...") + logger.info("[generate_music] Preparing inputs...") + + runtime = self._prepare_generate_music_runtime( + batch_size=batch_size, + audio_duration=audio_duration, + repainting_end=repainting_end, + seed=seed, + use_random_seed=use_random_seed, + ) + actual_batch_size = runtime["actual_batch_size"] + actual_seed_list = runtime["actual_seed_list"] + seed_value_for_ui = runtime["seed_value_for_ui"] + audio_duration = runtime["audio_duration"] + repainting_end = runtime["repainting_end"] + + try: + refer_audios, processed_src_audio, audio_error = self._prepare_reference_and_source_audio( + reference_audio=reference_audio, + src_audio=src_audio, + audio_code_string=audio_code_string, + actual_batch_size=actual_batch_size, + task_type=task_type, + ) + if audio_error is not None: + return audio_error + + service_inputs = self._prepare_generate_music_service_inputs( + actual_batch_size=actual_batch_size, + processed_src_audio=processed_src_audio, + audio_duration=audio_duration, + captions=captions, + lyrics=lyrics, + vocal_language=vocal_language, + instruction=instruction, + bpm=bpm, + key_scale=key_scale, + time_signature=time_signature, + task_type=task_type, + audio_code_string=audio_code_string, + repainting_start=repainting_start, + repainting_end=repainting_end, + ) + service_run = self._run_generate_music_service_with_progress( + progress=progress, + actual_batch_size=actual_batch_size, + audio_duration=audio_duration, + inference_steps=inference_steps, + timesteps=timesteps, + service_inputs=service_inputs, + refer_audios=refer_audios, + guidance_scale=guidance_scale, + actual_seed_list=actual_seed_list, + audio_cover_strength=audio_cover_strength, + cover_noise_strength=cover_noise_strength, + use_adg=use_adg, + cfg_interval_start=cfg_interval_start, + cfg_interval_end=cfg_interval_end, + shift=shift, + infer_method=infer_method, + ) + outputs = service_run["outputs"] + infer_steps_for_progress = service_run["infer_steps_for_progress"] + + pred_latents, time_costs = self._prepare_generate_music_decode_state( + outputs=outputs, + infer_steps_for_progress=infer_steps_for_progress, + actual_batch_size=actual_batch_size, + audio_duration=audio_duration, + latent_shift=latent_shift, + latent_rescale=latent_rescale, + ) + pred_wavs, pred_latents_cpu, time_costs = self._decode_generate_music_pred_latents( + pred_latents=pred_latents, + progress=progress, + use_tiled_decode=use_tiled_decode, + time_costs=time_costs, + ) + return self._build_generate_music_success_payload( + outputs=outputs, + pred_wavs=pred_wavs, + pred_latents_cpu=pred_latents_cpu, + time_costs=time_costs, + seed_value_for_ui=seed_value_for_ui, + actual_batch_size=actual_batch_size, + progress=progress, + ) + except Exception as exc: + error_msg = f"Error: {exc!s}\n{traceback.format_exc()}" + logger.exception("[generate_music] Generation failed") + return { + "audios": [], + "status_message": error_msg, + "extra_outputs": {}, + "success": False, + "error": f"{exc!s}", + } diff --git 
a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py new file mode 100644 index 00000000..79527202 --- /dev/null +++ b/acestep/core/generation/handler/generate_music_decode.py @@ -0,0 +1,196 @@ +"""Decode/validation helpers for ``generate_music`` orchestration.""" + +import os +import time +from typing import Any, Dict, Optional, Tuple + +import torch +from loguru import logger + +from acestep.gpu_config import get_effective_free_vram_gb + + +class GenerateMusicDecodeMixin: + """Validate generated latents and decode them into waveform tensors.""" + + def _prepare_generate_music_decode_state( + self, + outputs: Dict[str, Any], + infer_steps_for_progress: int, + actual_batch_size: int, + audio_duration: Optional[float], + latent_shift: float, + latent_rescale: float, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Collect decode inputs and validate raw diffusion latents. + + Args: + outputs: ``service_generate`` output payload containing target latents and timings. + infer_steps_for_progress: Effective diffusion step count for estimates. + actual_batch_size: Effective generation batch size. + audio_duration: Optional generation duration in seconds. + latent_shift: Additive latent post-processing shift. + latent_rescale: Multiplicative latent post-processing scale. + + Returns: + Tuple containing validated ``pred_latents`` and mutable ``time_costs``. + + Raises: + RuntimeError: If latents contain NaN/Inf values or collapse to all zeros. + """ + logger.info("[generate_music] Model generation completed. Decoding latents...") + pred_latents = outputs["target_latents"] + time_costs = outputs["time_costs"] + time_costs["offload_time_cost"] = self.current_offload_cost + + per_step = time_costs.get("diffusion_per_step_time_cost") + if isinstance(per_step, (int, float)) and per_step > 0: + self._last_diffusion_per_step_sec = float(per_step) + self._update_progress_estimate( + per_step_sec=float(per_step), + infer_steps=infer_steps_for_progress, + batch_size=actual_batch_size, + duration_sec=audio_duration if audio_duration and audio_duration > 0 else None, + ) + + if self.debug_stats: + logger.debug( + f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} " + f"{pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} " + f"{pred_latents.std()=}" + ) + else: + logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype}") + logger.debug(f"[generate_music] time_costs: {time_costs}") + + if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any(): + raise RuntimeError( + "Generation produced NaN or Inf latents. " + "This usually indicates a checkpoint/config mismatch " + "or unsupported quantization/backend combination. " + "Try running with --backend pt or verify your model checkpoints match this release." + ) + if pred_latents.numel() > 0 and pred_latents.abs().sum() == 0: + raise RuntimeError( + "Generation produced zero latents. " + "This usually indicates a checkpoint/config mismatch or unsupported setup." 
+ ) + if latent_shift != 0.0 or latent_rescale != 1.0: + logger.info( + f"[generate_music] Applying latent post-processing: shift={latent_shift}, " + f"rescale={latent_rescale}" + ) + if self.debug_stats: + logger.debug( + f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, " + f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, " + f"std={pred_latents.std():.4f}" + ) + pred_latents = pred_latents * latent_rescale + latent_shift + if self.debug_stats: + logger.debug( + f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, " + f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, " + f"std={pred_latents.std():.4f}" + ) + return pred_latents, time_costs + + def _decode_generate_music_pred_latents( + self, + pred_latents: torch.Tensor, + progress: Any, + use_tiled_decode: bool, + time_costs: Dict[str, Any], + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """Decode predicted latents and update decode timing metrics. + + Args: + pred_latents: Predicted latent tensor shaped ``[batch, frames, dim]``. + progress: Optional progress callback. + use_tiled_decode: Whether tiled VAE decode should be used. + time_costs: Mutable time-cost payload from service generation. + + Returns: + Tuple of decoded waveforms, CPU latents, and updated time-cost payload. + """ + if progress: + progress(0.8, desc="Decoding audio...") + logger.info("[generate_music] Decoding latents with VAE...") + start_time = time.time() + with torch.inference_mode(): + with self._load_model_context("vae"): + pred_latents_cpu = pred_latents.detach().cpu() + pred_latents_for_decode = pred_latents.transpose(1, 2).contiguous().to(self.vae.dtype) + del pred_latents + self._empty_cache() + + logger.debug( + "[generate_music] Before VAE decode: " + f"allocated={self._memory_allocated()/1024**3:.2f}GB, " + f"max={self._max_memory_allocated()/1024**3:.2f}GB" + ) + using_mlx_vae = self.use_mlx_vae and self.mlx_vae is not None + vae_cpu = False + if not using_mlx_vae: + vae_cpu = os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes") + if not vae_cpu: + if self.device == "mps": + logger.info( + "[generate_music] MPS device: skipping VRAM check " + "(unified memory), keeping VAE on MPS" + ) + else: + effective_free = get_effective_free_vram_gb() + logger.info( + "[generate_music] Effective free VRAM before VAE decode: " + f"{effective_free:.2f} GB" + ) + if effective_free < 0.5: + logger.warning( + "[generate_music] Only " + f"{effective_free:.2f} GB free VRAM; auto-enabling CPU VAE decode" + ) + vae_cpu = True + if vae_cpu: + logger.info("[generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...") + vae_device = next(self.vae.parameters()).device + self.vae = self.vae.cpu() + pred_latents_for_decode = pred_latents_for_decode.cpu() + self._empty_cache() + if use_tiled_decode: + logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") + pred_wavs = self.tiled_decode(pred_latents_for_decode) + elif using_mlx_vae: + try: + pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) + except Exception as exc: + logger.warning( + f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch" + ) + decoder_output = self.vae.decode(pred_latents_for_decode) + pred_wavs = decoder_output.sample + del decoder_output + else: + decoder_output = self.vae.decode(pred_latents_for_decode) + pred_wavs = decoder_output.sample + del decoder_output + if vae_cpu: + logger.info("[generate_music] VAE decode on CPU 
complete, restoring to GPU...") + self.vae = self.vae.to(vae_device) + logger.debug( + "[generate_music] After VAE decode: " + f"allocated={self._memory_allocated()/1024**3:.2f}GB, " + f"max={self._max_memory_allocated()/1024**3:.2f}GB" + ) + del pred_latents_for_decode + if pred_wavs.dtype != torch.float32: + pred_wavs = pred_wavs.float() + peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True) + if torch.any(peak > 1.0): + pred_wavs = pred_wavs / peak.clamp(min=1.0) + self._empty_cache() + end_time = time.time() + time_costs["vae_decode_time_cost"] = end_time - start_time + time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"] + time_costs["offload_time_cost"] = self.current_offload_cost + return pred_wavs, pred_latents_cpu, time_costs diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py new file mode 100644 index 00000000..f14bddfe --- /dev/null +++ b/acestep/core/generation/handler/generate_music_decode_test.py @@ -0,0 +1,196 @@ +"""Tests for extracted ``generate_music`` decode helper mixin behavior.""" + +import importlib.util +import types +import sys +import unittest +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import patch + +import torch + + +def _load_generate_music_decode_module(): + """Load ``generate_music_decode.py`` from disk and return its module object. + + Raises ``FileNotFoundError`` or ``ImportError`` when loading fails. + """ + repo_root = Path(__file__).resolve().parents[4] + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + package_paths = { + "acestep": repo_root / "acestep", + "acestep.core": repo_root / "acestep" / "core", + "acestep.core.generation": repo_root / "acestep" / "core" / "generation", + "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler", + } + for package_name, package_path in package_paths.items(): + if package_name in sys.modules: + continue + package_module = types.ModuleType(package_name) + package_module.__path__ = [str(package_path)] + sys.modules[package_name] = package_module + module_path = Path(__file__).with_name("generate_music_decode.py") + spec = importlib.util.spec_from_file_location( + "acestep.core.generation.handler.generate_music_decode", + module_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +GENERATE_MUSIC_DECODE_MODULE = _load_generate_music_decode_module() +GenerateMusicDecodeMixin = GENERATE_MUSIC_DECODE_MODULE.GenerateMusicDecodeMixin + + +class _FakeDecodeOutput: + """Minimal VAE decode output container exposing ``sample`` attribute.""" + + def __init__(self, sample: torch.Tensor): + """Store decoded sample tensor for mixin decode flow.""" + self.sample = sample + + +class _FakeVae: + """Minimal VAE stand-in with dtype, decode, and parameter iteration hooks.""" + + def __init__(self): + """Initialize deterministic dtype/device state for decode tests.""" + self.dtype = torch.float32 + self._param = torch.nn.Parameter(torch.zeros(1)) + + def decode(self, latents: torch.Tensor): + """Return deterministic decoded waveform output.""" + return _FakeDecodeOutput(torch.ones(latents.shape[0], 2, 8)) + + def parameters(self): + """Yield one parameter so `.device` lookups remain valid.""" + yield self._param + + def cpu(self): + """Return self for test-only CPU transfer calls.""" + return self + + def 
to(self, *_args, **_kwargs): + """Return self for test-only device transfer calls.""" + return self + + +class _Host(GenerateMusicDecodeMixin): + """Minimal decode-mixin host exposing deterministic state for assertions.""" + + def __init__(self): + """Initialize deterministic runtime state for decode tests.""" + self.current_offload_cost = 0.25 + self.debug_stats = False + self._last_diffusion_per_step_sec = None + self.estimate_calls = [] + self.progress_calls = [] + self.device = "cpu" + self.use_mlx_vae = True + self.mlx_vae = object() + self.vae = _FakeVae() + + def _update_progress_estimate(self, **kwargs): + """Capture estimate updates for assertions.""" + self.estimate_calls.append(kwargs) + + @contextmanager + def _load_model_context(self, _model_name): + """Provide no-op model context manager for decode tests.""" + yield + + def _empty_cache(self): + """Provide no-op cache clear helper for decode tests.""" + return None + + def _memory_allocated(self): + """Return deterministic allocated-memory value for debug logging.""" + return 0.0 + + def _max_memory_allocated(self): + """Return deterministic max-memory value for debug logging.""" + return 0.0 + + def _mlx_vae_decode(self, latents): + """Return deterministic decoded waveform for MLX decode branch.""" + _ = latents + return torch.ones(1, 2, 8) + + def tiled_decode(self, latents): + """Return deterministic decoded waveform for tiled decode branch.""" + _ = latents + return torch.ones(1, 2, 8) + + +class GenerateMusicDecodeMixinTests(unittest.TestCase): + """Verify decode-state preparation and latent decode helper behavior.""" + + def test_prepare_decode_state_updates_progress_estimates(self): + """It updates timing fields and progress estimate metadata for valid latents.""" + host = _Host() + outputs = { + "target_latents": torch.ones(1, 4, 3), + "time_costs": {"total_time_cost": 1.0, "diffusion_per_step_time_cost": 0.2}, + } + pred_latents, time_costs = host._prepare_generate_music_decode_state( + outputs=outputs, + infer_steps_for_progress=8, + actual_batch_size=1, + audio_duration=12.0, + latent_shift=0.0, + latent_rescale=1.0, + ) + self.assertEqual(tuple(pred_latents.shape), (1, 4, 3)) + self.assertEqual(time_costs["offload_time_cost"], 0.25) + self.assertEqual(host._last_diffusion_per_step_sec, 0.2) + self.assertEqual(host.estimate_calls[0]["infer_steps"], 8) + + def test_prepare_decode_state_raises_for_nan_latents(self): + """It raises runtime error when diffusion latents contain NaN values.""" + host = _Host() + outputs = { + "target_latents": torch.tensor([[[float("nan")]]]), + "time_costs": {"total_time_cost": 1.0}, + } + with self.assertRaises(RuntimeError): + host._prepare_generate_music_decode_state( + outputs=outputs, + infer_steps_for_progress=8, + actual_batch_size=1, + audio_duration=None, + latent_shift=0.0, + latent_rescale=1.0, + ) + + def test_decode_pred_latents_updates_decode_time_and_returns_cpu_latents(self): + """It decodes latents and updates decode timing metrics in time_costs.""" + host = _Host() + pred_latents = torch.ones(1, 4, 3) + time_costs = {"total_time_cost": 1.0} + + def _progress(value, desc=None): + """Capture progress updates for assertions.""" + host.progress_calls.append((value, desc)) + + with patch.object(GENERATE_MUSIC_DECODE_MODULE.time, "time", side_effect=[10.0, 11.5]): + pred_wavs, pred_latents_cpu, updated_costs = host._decode_generate_music_pred_latents( + pred_latents=pred_latents, + progress=_progress, + use_tiled_decode=False, + time_costs=time_costs, + ) + + 
self.assertEqual(tuple(pred_wavs.shape), (1, 2, 8)) + self.assertEqual(pred_latents_cpu.device.type, "cpu") + self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6) + self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6) + self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6) + self.assertEqual(host.progress_calls[0][0], 0.8) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/generation/handler/generate_music_payload.py b/acestep/core/generation/handler/generate_music_payload.py new file mode 100644 index 00000000..d8ee4c7e --- /dev/null +++ b/acestep/core/generation/handler/generate_music_payload.py @@ -0,0 +1,91 @@ +"""Success payload builders for ``generate_music`` orchestration.""" + +from typing import Any, Dict + +from loguru import logger + + +class GenerateMusicPayloadMixin: + """Build audio/metadata payload structures returned by ``generate_music``.""" + + def _build_generate_music_success_payload( + self, + outputs: Dict[str, Any], + pred_wavs, + pred_latents_cpu, + time_costs: Dict[str, Any], + seed_value_for_ui: int, + actual_batch_size: int, + progress: Any, + ) -> Dict[str, Any]: + """Assemble final success response from decoded tensors and model outputs. + + Args: + outputs: Service output payload containing intermediate generation tensors. + pred_wavs: Decoded waveform tensor shaped ``[batch, channels, samples]``. + pred_latents_cpu: CPU latent tensor preserved for extra outputs. + time_costs: Updated time-cost payload including decode/offload timings. + seed_value_for_ui: Seed value displayed in UI outputs. + actual_batch_size: Effective generation batch size. + progress: Optional progress callback. + + Returns: + Dict[str, Any]: Standard success payload returned by ``generate_music``. + """ + logger.info("[generate_music] VAE decode completed. Preparing audio tensors...") + if progress: + progress(0.99, desc="Preparing audio data...") + + audio_tensors = [] + for index in range(actual_batch_size): + audio_tensor = pred_wavs[index].cpu() + audio_tensors.append(audio_tensor) + + status_message = "Generation completed successfully!" + logger.info(f"[generate_music] Done! 
Generated {len(audio_tensors)} audio tensors.") + + src_latents = outputs.get("src_latents") + target_latents_input = outputs.get("target_latents_input") + chunk_masks = outputs.get("chunk_masks") + spans = outputs.get("spans", []) + latent_masks = outputs.get("latent_masks") + + encoder_hidden_states = outputs.get("encoder_hidden_states") + encoder_attention_mask = outputs.get("encoder_attention_mask") + context_latents = outputs.get("context_latents") + lyric_token_idss = outputs.get("lyric_token_idss") + + extra_outputs = { + "pred_latents": pred_latents_cpu, + "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None, + "src_latents": src_latents.detach().cpu() if src_latents is not None else None, + "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None, + "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None, + "spans": spans, + "time_costs": time_costs, + "seed_value": seed_value_for_ui, + "encoder_hidden_states": ( + encoder_hidden_states.detach().cpu() + if encoder_hidden_states is not None + else None + ), + "encoder_attention_mask": ( + encoder_attention_mask.detach().cpu() + if encoder_attention_mask is not None + else None + ), + "context_latents": context_latents.detach().cpu() if context_latents is not None else None, + "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None, + } + + audios = [] + for audio_tensor in audio_tensors: + audios.append({"tensor": audio_tensor, "sample_rate": self.sample_rate}) + + return { + "audios": audios, + "status_message": status_message, + "extra_outputs": extra_outputs, + "success": True, + "error": None, + } diff --git a/acestep/core/generation/handler/generate_music_payload_test.py b/acestep/core/generation/handler/generate_music_payload_test.py new file mode 100644 index 00000000..f8dd66ee --- /dev/null +++ b/acestep/core/generation/handler/generate_music_payload_test.py @@ -0,0 +1,112 @@ +"""Tests for extracted ``generate_music`` success-payload builder behavior. + +The module loads ``acestep.core.generation.handler.generate_music_payload`` +directly from file and validates final payload assembly with deterministic test +fixtures. +""" + +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + + +def _load_generate_music_payload_module(): + """Load ``generate_music_payload.py`` from disk for isolated tests. + + Returns: + types.ModuleType: Loaded module object for + ``acestep.core.generation.handler.generate_music_payload``. + + Raises: + FileNotFoundError: If the target file does not exist. + ImportError: If module execution fails. 
+ """ + repo_root = Path(__file__).resolve().parents[4] + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + package_paths = { + "acestep": repo_root / "acestep", + "acestep.core": repo_root / "acestep" / "core", + "acestep.core.generation": repo_root / "acestep" / "core" / "generation", + "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler", + } + for package_name, package_path in package_paths.items(): + if package_name in sys.modules: + continue + package_module = types.ModuleType(package_name) + package_module.__path__ = [str(package_path)] + sys.modules[package_name] = package_module + module_path = Path(__file__).with_name("generate_music_payload.py") + spec = importlib.util.spec_from_file_location( + "acestep.core.generation.handler.generate_music_payload", + module_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +GENERATE_MUSIC_PAYLOAD_MODULE = _load_generate_music_payload_module() +GenerateMusicPayloadMixin = GENERATE_MUSIC_PAYLOAD_MODULE.GenerateMusicPayloadMixin + + +class _Host(GenerateMusicPayloadMixin): + """Minimal host providing state required by payload assembly tests.""" + + def __init__(self): + """Initialize deterministic sample rate state.""" + self.sample_rate = 48000 + + +class GenerateMusicPayloadMixinTests(unittest.TestCase): + """Verify payload builder output structure and tensor routing.""" + + def test_build_success_payload_contains_audio_and_extra_outputs(self): + """It assembles audios and extra_outputs with CPU tensors and metadata.""" + host = _Host() + outputs = { + "target_latents_input": torch.ones(1, 4, 3), + "src_latents": torch.ones(1, 4, 3), + "chunk_masks": torch.ones(1, 4), + "latent_masks": torch.ones(1, 4), + "spans": [(0, 4)], + "encoder_hidden_states": torch.ones(1, 2, 3), + "encoder_attention_mask": torch.ones(1, 2), + "context_latents": torch.ones(1, 4, 3), + "lyric_token_idss": torch.ones(1, 2, dtype=torch.long), + } + pred_wavs = torch.ones(1, 2, 8) + pred_latents_cpu = torch.ones(1, 4, 3) + time_costs = {"total_time_cost": 2.0} + progress_calls = [] + + def _progress(value, desc=None): + """Capture progress updates for assertions.""" + progress_calls.append((value, desc)) + + payload = host._build_generate_music_success_payload( + outputs=outputs, + pred_wavs=pred_wavs, + pred_latents_cpu=pred_latents_cpu, + time_costs=time_costs, + seed_value_for_ui=7, + actual_batch_size=1, + progress=_progress, + ) + + self.assertTrue(payload["success"]) + self.assertEqual(payload["error"], None) + self.assertEqual(len(payload["audios"]), 1) + self.assertEqual(payload["audios"][0]["sample_rate"], 48000) + self.assertEqual(payload["extra_outputs"]["seed_value"], 7) + self.assertEqual(payload["extra_outputs"]["pred_latents"].device.type, "cpu") + self.assertEqual(progress_calls[0][0], 0.99) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/generation/handler/generate_music_test.py b/acestep/core/generation/handler/generate_music_test.py new file mode 100644 index 00000000..78738a0d --- /dev/null +++ b/acestep/core/generation/handler/generate_music_test.py @@ -0,0 +1,197 @@ +"""Tests for extracted ``generate_music`` orchestration behavior. + +The module loads ``acestep.core.generation.handler.generate_music`` directly +from file to avoid package import side effects and validates orchestration +ordering, readiness short-circuiting, and failure payload handling. 
+""" + +import importlib.util +import sys +import types +import unittest +from pathlib import Path +from typing import Any, Dict + +import torch + + +def _load_generate_music_module(): + """Load ``generate_music.py`` from disk for isolated mixin tests. + + Returns: + types.ModuleType: Loaded module object for + ``acestep.core.generation.handler.generate_music``. + + Raises: + FileNotFoundError: If the target module file is missing. + ImportError: If module loading fails. + """ + repo_root = Path(__file__).resolve().parents[4] + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + package_paths = { + "acestep": repo_root / "acestep", + "acestep.core": repo_root / "acestep" / "core", + "acestep.core.generation": repo_root / "acestep" / "core" / "generation", + "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler", + } + for package_name, package_path in package_paths.items(): + if package_name in sys.modules: + continue + package_module = types.ModuleType(package_name) + package_module.__path__ = [str(package_path)] + sys.modules[package_name] = package_module + module_path = Path(__file__).with_name("generate_music.py") + spec = importlib.util.spec_from_file_location( + "acestep.core.generation.handler.generate_music", + module_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +GENERATE_MUSIC_MODULE = _load_generate_music_module() +GenerateMusicMixin = GENERATE_MUSIC_MODULE.GenerateMusicMixin + + +class _Host(GenerateMusicMixin): + """Minimal host implementing ``generate_music`` helper dependencies. + + The host captures helper calls in ``self.calls`` and returns deterministic + payloads so tests can assert orchestration sequencing and return behavior. 
+ """ + + def __init__(self): + """Initialize deterministic state and stub payloads for orchestration tests.""" + self.model = object() + self.vae = object() + self.text_tokenizer = object() + self.text_encoder = object() + self.calls: Dict[str, Any] = {} + self._final_payload = {"audios": [{"tensor": torch.zeros(1, 4), "sample_rate": 48000}], "success": True} + self._readiness_error = { + "audios": [], + "status_message": "not ready", + "extra_outputs": {}, + "success": False, + "error": "Model not fully initialized", + } + + def _resolve_generate_music_progress(self, progress): + """Return provided callback or deterministic no-op callback.""" + self.calls["_resolve_generate_music_progress"] = bool(progress) + if progress is not None: + return progress + + def _noop(*_args, **_kwargs): + """Ignore progress updates in tests.""" + return None + + return _noop + + def _validate_generate_music_readiness(self): + """Return deterministic readiness error payload.""" + self.calls["_validate_generate_music_readiness"] = True + return self._readiness_error + + def _resolve_generate_music_task(self, **kwargs): + """Capture task resolution args and return deterministic task/instruction.""" + self.calls["_resolve_generate_music_task"] = kwargs + return kwargs["task_type"], kwargs["instruction"] + + def _prepare_generate_music_runtime(self, **kwargs): + """Capture runtime args and return deterministic runtime state.""" + self.calls["_prepare_generate_music_runtime"] = kwargs + return { + "actual_batch_size": 1, + "actual_seed_list": [77], + "seed_value_for_ui": 77, + "audio_duration": kwargs["audio_duration"], + "repainting_end": kwargs["repainting_end"], + } + + def _prepare_reference_and_source_audio(self, **kwargs): + """Capture audio-prepare args and return deterministic prepared state.""" + self.calls["_prepare_reference_and_source_audio"] = kwargs + return [[torch.zeros(2, 10)]], None, None + + def _prepare_generate_music_service_inputs(self, **kwargs): + """Capture service-input args and return deterministic payload.""" + self.calls["_prepare_generate_music_service_inputs"] = kwargs + return {"should_return_intermediate": True} + + def _run_generate_music_service_with_progress(self, **kwargs): + """Capture service execution args and return deterministic model outputs.""" + self.calls["_run_generate_music_service_with_progress"] = kwargs + return { + "outputs": { + "target_latents": torch.ones(1, 4, 3), + "time_costs": {"total_time_cost": 1.0, "diffusion_per_step_time_cost": 0.1}, + }, + "infer_steps_for_progress": 8, + } + + def _prepare_generate_music_decode_state(self, **kwargs): + """Capture decode-state args and return deterministic latents/costs.""" + self.calls["_prepare_generate_music_decode_state"] = kwargs + return torch.ones(1, 4, 3), {"total_time_cost": 1.0} + + def _decode_generate_music_pred_latents(self, **kwargs): + """Capture decode args and return deterministic decode outputs.""" + self.calls["_decode_generate_music_pred_latents"] = kwargs + return torch.ones(1, 2, 8), torch.ones(1, 4, 3), {"total_time_cost": 2.0} + + def _build_generate_music_success_payload(self, **kwargs): + """Capture payload-builder args and return deterministic success payload.""" + self.calls["_build_generate_music_success_payload"] = kwargs + return self._final_payload + + +class GenerateMusicMixinTests(unittest.TestCase): + """Verify top-level ``generate_music`` orchestration behavior.""" + + def test_generate_music_returns_success_payload_from_builder(self): + """It executes helper stages and returns 
the payload builder result.""" + host = _Host() + out = host.generate_music( + captions="cap", + lyrics="lyr", + inference_steps=8, + guidance_scale=6.5, + use_random_seed=False, + seed=77, + task_type="text2music", + ) + self.assertEqual(out, host._final_payload) + self.assertEqual(host.calls["_prepare_generate_music_runtime"]["seed"], 77) + self.assertEqual(host.calls["_run_generate_music_service_with_progress"]["guidance_scale"], 6.5) + self.assertEqual(host.calls["_prepare_generate_music_decode_state"]["infer_steps_for_progress"], 8) + + def test_generate_music_returns_readiness_error_when_components_missing(self): + """It short-circuits with readiness payload when required models are missing.""" + host = _Host() + host.model = None + out = host.generate_music(captions="cap", lyrics="lyr") + self.assertEqual(out, host._readiness_error) + self.assertTrue(host.calls["_validate_generate_music_readiness"]) + self.assertNotIn("_prepare_generate_music_runtime", host.calls) + + def test_generate_music_returns_error_payload_on_exception(self): + """It catches orchestration errors and returns standardized failure payload.""" + host = _Host() + + def _raise_error(**_kwargs): + """Raise deterministic runtime failure for exception-path validation.""" + raise RuntimeError("boom") + + host._prepare_reference_and_source_audio = _raise_error + out = host.generate_music(captions="cap", lyrics="lyr") + self.assertFalse(out["success"]) + self.assertEqual(out["error"], "boom") + self.assertIn("Error: boom", out["status_message"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/handler.py b/acestep/handler.py index 02cf3879..ec43bdf2 100644 --- a/acestep/handler.py +++ b/acestep/handler.py @@ -8,16 +8,12 @@ # Disable tokenizers parallelism to avoid fork warning os.environ["TOKENIZERS_PARALLELISM"] = "false" -import traceback import threading -from typing import Optional, Dict, Any, Tuple, List, Union +from typing import Optional import torch -import time -from loguru import logger import warnings -from acestep.constants import DEFAULT_DIT_INSTRUCTION, SFT_GEN_PROMPT from acestep.core.generation.handler import ( AudioCodesMixin, BatchPrepMixin, @@ -27,7 +23,10 @@ ConditioningTargetMixin, ConditioningTextMixin, DiffusionMixin, + GenerateMusicDecodeMixin, GenerateMusicExecuteMixin, + GenerateMusicMixin, + GenerateMusicPayloadMixin, GenerateMusicRequestMixin, InitServiceMixin, IoAudioMixin, @@ -54,7 +53,6 @@ ServiceGenerateExecuteMixin, ServiceGenerateOutputsMixin, ) -from acestep.gpu_config import get_effective_free_vram_gb warnings.filterwarnings("ignore") @@ -62,6 +60,9 @@ class AceStepHandler( DiffusionMixin, + GenerateMusicMixin, + GenerateMusicDecodeMixin, + GenerateMusicPayloadMixin, GenerateMusicExecuteMixin, GenerateMusicRequestMixin, AudioCodesMixin, @@ -166,342 +167,3 @@ def __init__(self): self.mlx_vae = None self.use_mlx_vae = False - def generate_music( - self, - captions: str, - lyrics: str, - bpm: Optional[int] = None, - key_scale: str = "", - time_signature: str = "", - vocal_language: str = "en", - inference_steps: int = 8, - guidance_scale: float = 7.0, - use_random_seed: bool = True, - seed: Optional[Union[str, float, int]] = -1, - reference_audio=None, - audio_duration: Optional[float] = None, - batch_size: Optional[int] = None, - src_audio=None, - audio_code_string: Union[str, List[str]] = "", - repainting_start: float = 0.0, - repainting_end: Optional[float] = None, - instruction: str = DEFAULT_DIT_INSTRUCTION, - audio_cover_strength: float = 1.0, - 
cover_noise_strength: float = 0.0, - task_type: str = "text2music", - use_adg: bool = False, - cfg_interval_start: float = 0.0, - cfg_interval_end: float = 1.0, - shift: float = 1.0, - infer_method: str = "ode", - use_tiled_decode: bool = True, - timesteps: Optional[List[float]] = None, - latent_shift: float = 0.0, - latent_rescale: float = 1.0, - progress=None - ) -> Dict[str, Any]: - """ - Main interface for music generation - - Returns: - Dictionary containing: - - audios: List of audio dictionaries with path, key, params - - generation_info: Markdown-formatted generation information - - status_message: Status message - - extra_outputs: Dictionary with latents, masks, time_costs, etc. - - success: Whether generation completed successfully - - error: Error message if generation failed - """ - progress = self._resolve_generate_music_progress(progress) - - if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None: - readiness_error = self._validate_generate_music_readiness() - return readiness_error - - task_type, instruction = self._resolve_generate_music_task( - task_type=task_type, - audio_code_string=audio_code_string, - instruction=instruction, - ) - - logger.info("[generate_music] Starting generation...") - if progress: - progress(0.51, desc="Preparing inputs...") - logger.info("[generate_music] Preparing inputs...") - - runtime = self._prepare_generate_music_runtime( - batch_size=batch_size, - audio_duration=audio_duration, - repainting_end=repainting_end, - seed=seed, - use_random_seed=use_random_seed, - ) - actual_batch_size = runtime["actual_batch_size"] - actual_seed_list = runtime["actual_seed_list"] - seed_value_for_ui = runtime["seed_value_for_ui"] - audio_duration = runtime["audio_duration"] - repainting_end = runtime["repainting_end"] - - try: - refer_audios, processed_src_audio, audio_error = self._prepare_reference_and_source_audio( - reference_audio=reference_audio, - src_audio=src_audio, - audio_code_string=audio_code_string, - actual_batch_size=actual_batch_size, - task_type=task_type, - ) - if audio_error is not None: - return audio_error - - service_inputs = self._prepare_generate_music_service_inputs( - actual_batch_size=actual_batch_size, - processed_src_audio=processed_src_audio, - audio_duration=audio_duration, - captions=captions, - lyrics=lyrics, - vocal_language=vocal_language, - instruction=instruction, - bpm=bpm, - key_scale=key_scale, - time_signature=time_signature, - task_type=task_type, - audio_code_string=audio_code_string, - repainting_start=repainting_start, - repainting_end=repainting_end, - ) - service_run = self._run_generate_music_service_with_progress( - progress=progress, - actual_batch_size=actual_batch_size, - audio_duration=audio_duration, - inference_steps=inference_steps, - timesteps=timesteps, - service_inputs=service_inputs, - refer_audios=refer_audios, - guidance_scale=guidance_scale, - actual_seed_list=actual_seed_list, - audio_cover_strength=audio_cover_strength, - cover_noise_strength=cover_noise_strength, - use_adg=use_adg, - cfg_interval_start=cfg_interval_start, - cfg_interval_end=cfg_interval_end, - shift=shift, - infer_method=infer_method, - ) - outputs = service_run["outputs"] - infer_steps_for_progress = service_run["infer_steps_for_progress"] - - logger.info("[generate_music] Model generation completed. 
Decoding latents...") - pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim] - time_costs = outputs["time_costs"] - time_costs["offload_time_cost"] = self.current_offload_cost - per_step = time_costs.get("diffusion_per_step_time_cost") - if isinstance(per_step, (int, float)) and per_step > 0: - self._last_diffusion_per_step_sec = float(per_step) - self._update_progress_estimate( - per_step_sec=float(per_step), - infer_steps=infer_steps_for_progress, - batch_size=actual_batch_size, - duration_sec=audio_duration if audio_duration and audio_duration > 0 else None, - ) - if self.debug_stats: - logger.debug( - f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} " - f"{pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}" - ) - else: - logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype}") - logger.debug(f"[generate_music] time_costs: {time_costs}") - - if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any(): - raise RuntimeError( - "Generation produced NaN or Inf latents. " - "This usually indicates a checkpoint/config mismatch " - "or unsupported quantization/backend combination. " - "Try running with --backend pt or verify your model checkpoints match this release." - ) - if pred_latents.numel() > 0 and pred_latents.abs().sum() == 0: - raise RuntimeError( - "Generation produced zero latents. " - "This usually indicates a checkpoint/config mismatch or unsupported setup." - ) - - if progress: - progress(0.8, desc="Decoding audio...") - logger.info("[generate_music] Decoding latents with VAE...") - - # Apply latent shift and rescale before VAE decode (for anti-clipping control) - if latent_shift != 0.0 or latent_rescale != 1.0: - logger.info(f"[generate_music] Applying latent post-processing: shift={latent_shift}, rescale={latent_rescale}") - if self.debug_stats: - logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") - pred_latents = pred_latents * latent_rescale + latent_shift - if self.debug_stats: - logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") - - # Decode latents to audio - start_time = time.time() - with torch.inference_mode(): - with self._load_model_context("vae"): - # Move pred_latents to CPU early to save VRAM (will be used in extra_outputs later) - pred_latents_cpu = pred_latents.detach().cpu() - - # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length] - pred_latents_for_decode = pred_latents.transpose(1, 2).contiguous() - # Ensure input is in VAE's dtype - pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype) - - # Release original pred_latents to free VRAM before VAE decode - del pred_latents - self._empty_cache() - - logger.debug(f"[generate_music] Before VAE decode: allocated={self._memory_allocated()/1024**3:.2f}GB, max={self._max_memory_allocated()/1024**3:.2f}GB") - - # When native MLX VAE is active, bypass VRAM checks and CPU - # offload entirely; MLX uses unified memory, not PyTorch VRAM. 
- _using_mlx_vae = self.use_mlx_vae and self.mlx_vae is not None - _vae_cpu = False - - if not _using_mlx_vae: - # Check effective free VRAM and auto-enable CPU decode if extremely tight - import os as _os - _vae_cpu = _os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes") - if not _vae_cpu: - # MPS (Apple Silicon) uses unified memory; get_effective_free_vram_gb() - # relies on CUDA and always returns 0 on Mac, which would incorrectly - # force VAE decode onto the CPU. Skip the auto-CPU logic for MPS. - if self.device == "mps": - logger.info("[generate_music] MPS device: skipping VRAM check (unified memory), keeping VAE on MPS") - else: - _effective_free = get_effective_free_vram_gb() - logger.info(f"[generate_music] Effective free VRAM before VAE decode: {_effective_free:.2f} GB") - # If less than 0.5 GB free, VAE decode on GPU will almost certainly OOM - if _effective_free < 0.5: - logger.warning(f"[generate_music] Only {_effective_free:.2f} GB free VRAM; auto-enabling CPU VAE decode") - _vae_cpu = True - if _vae_cpu: - logger.info("[generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...") - _vae_device = next(self.vae.parameters()).device - self.vae = self.vae.cpu() - pred_latents_for_decode = pred_latents_for_decode.cpu() - self._empty_cache() - - if use_tiled_decode: - logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") - pred_wavs = self.tiled_decode(pred_latents_for_decode) # [batch, channels, samples] - elif _using_mlx_vae: - # Direct decode via native MLX (no tiling needed) - try: - pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) - except Exception as exc: - logger.warning(f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch") - decoder_output = self.vae.decode(pred_latents_for_decode) - pred_wavs = decoder_output.sample - del decoder_output - else: - decoder_output = self.vae.decode(pred_latents_for_decode) - pred_wavs = decoder_output.sample - del decoder_output - - if _vae_cpu: - logger.info("[generate_music] VAE decode on CPU complete, restoring to GPU...") - self.vae = self.vae.to(_vae_device) - if pred_wavs.device.type != 'cpu': - pass # already on right device - # pred_wavs stays on CPU - fine for audio post-processing - - logger.debug(f"[generate_music] After VAE decode: allocated={self._memory_allocated()/1024**3:.2f}GB, max={self._max_memory_allocated()/1024**3:.2f}GB") - - # Release pred_latents_for_decode after decode - del pred_latents_for_decode - - # Cast output to float32 for audio processing/saving (in-place if possible) - if pred_wavs.dtype != torch.float32: - pred_wavs = pred_wavs.float() - - # Anti-clipping normalization: only scale if peak exceeds [-1, 1]. - peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True) - if torch.any(peak > 1.0): - pred_wavs = pred_wavs / peak.clamp(min=1.0) - self._empty_cache() - end_time = time.time() - time_costs["vae_decode_time_cost"] = end_time - start_time - time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"] - - # Update offload cost one last time to include VAE offloading - time_costs["offload_time_cost"] = self.current_offload_cost - - logger.info("[generate_music] VAE decode completed. 
Preparing audio tensors...") - if progress: - progress(0.99, desc="Preparing audio data...") - - # Prepare audio tensors (no file I/O here, no UUID generation) - # pred_wavs is already [batch, channels, samples] format - # Move to CPU and convert to float32 for return - audio_tensors = [] - - for i in range(actual_batch_size): - # Extract audio tensor: [channels, samples] format, CPU, float32 - audio_tensor = pred_wavs[i].cpu() - audio_tensors.append(audio_tensor) - - status_message = "Generation completed successfully!" - logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.") - - # Extract intermediate information from outputs - src_latents = outputs.get("src_latents") # [batch, T, D] - target_latents_input = outputs.get("target_latents_input") # [batch, T, D] - chunk_masks = outputs.get("chunk_masks") # [batch, T] - spans = outputs.get("spans", []) # List of tuples - latent_masks = outputs.get("latent_masks") # [batch, T] - - # Extract condition tensors for LRC timestamp generation - encoder_hidden_states = outputs.get("encoder_hidden_states") - encoder_attention_mask = outputs.get("encoder_attention_mask") - context_latents = outputs.get("context_latents") - lyric_token_idss = outputs.get("lyric_token_idss") - - # Move all tensors to CPU to save VRAM (detach to release computation graph) - extra_outputs = { - "pred_latents": pred_latents_cpu, # Already moved to CPU earlier to save VRAM during VAE decode - "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None, - "src_latents": src_latents.detach().cpu() if src_latents is not None else None, - "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None, - "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None, - "spans": spans, - "time_costs": time_costs, - "seed_value": seed_value_for_ui, - # Condition tensors for LRC timestamp generation - "encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None, - "encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None, - "context_latents": context_latents.detach().cpu() if context_latents is not None else None, - "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None, - } - - # Build audios list with tensor data (no file paths, no UUIDs, handled outside) - audios = [] - for idx, audio_tensor in enumerate(audio_tensors): - audio_dict = { - "tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32 - "sample_rate": self.sample_rate, - } - audios.append(audio_dict) - - return { - "audios": audios, - "status_message": status_message, - "extra_outputs": extra_outputs, - "success": True, - "error": None, - } - - except Exception as e: - error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" - logger.exception("[generate_music] Generation failed") - return { - "audios": [], - "status_message": error_msg, - "extra_outputs": {}, - "success": False, - "error": str(e), - } - From 91e5d89df7c12c9c2486a069f4dead33c4223ba6 Mon Sep 17 00:00:00 2001 From: Rich Date: Wed, 18 Feb 2026 09:01:30 +0000 Subject: [PATCH 2/2] fix(handler): always restore VAE device after CPU decode path --- .../handler/generate_music_decode.py | 39 ++++++------ .../handler/generate_music_decode_test.py | 59 +++++++++++++++++++ 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/acestep/core/generation/handler/generate_music_decode.py 
b/acestep/core/generation/handler/generate_music_decode.py index 79527202..ffaaebbe 100644 --- a/acestep/core/generation/handler/generate_music_decode.py +++ b/acestep/core/generation/handler/generate_music_decode.py @@ -131,6 +131,7 @@ def _decode_generate_music_pred_latents( ) using_mlx_vae = self.use_mlx_vae and self.mlx_vae is not None vae_cpu = False + vae_device = None if not using_mlx_vae: vae_cpu = os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes") if not vae_cpu: @@ -157,26 +158,30 @@ def _decode_generate_music_pred_latents( self.vae = self.vae.cpu() pred_latents_for_decode = pred_latents_for_decode.cpu() self._empty_cache() - if use_tiled_decode: - logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") - pred_wavs = self.tiled_decode(pred_latents_for_decode) - elif using_mlx_vae: - try: - pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) - except Exception as exc: - logger.warning( - f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch" - ) + try: + if use_tiled_decode: + logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") + pred_wavs = self.tiled_decode(pred_latents_for_decode) + elif using_mlx_vae: + try: + pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) + except Exception as exc: + logger.warning( + f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch" + ) + decoder_output = self.vae.decode(pred_latents_for_decode) + pred_wavs = decoder_output.sample + del decoder_output + else: decoder_output = self.vae.decode(pred_latents_for_decode) pred_wavs = decoder_output.sample del decoder_output - else: - decoder_output = self.vae.decode(pred_latents_for_decode) - pred_wavs = decoder_output.sample - del decoder_output - if vae_cpu: - logger.info("[generate_music] VAE decode on CPU complete, restoring to GPU...") - self.vae = self.vae.to(vae_device) + finally: + if vae_cpu and vae_device is not None: + logger.info("[generate_music] Restoring VAE to original device after CPU decode path...") + self.vae = self.vae.to(vae_device) + pred_latents_for_decode = pred_latents_for_decode.to(vae_device) + self._empty_cache() logger.debug( "[generate_music] After VAE decode: " f"allocated={self._memory_allocated()/1024**3:.2f}GB, " diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py index f14bddfe..82cd04fe 100644 --- a/acestep/core/generation/handler/generate_music_decode_test.py +++ b/acestep/core/generation/handler/generate_music_decode_test.py @@ -191,6 +191,65 @@ def _progress(value, desc=None): self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6) self.assertEqual(host.progress_calls[0][0], 0.8) + def test_decode_pred_latents_restores_vae_device_on_decode_error(self): + """It restores VAE device in the CPU-offload path even when decode raises.""" + + class _FailingVae(_FakeVae): + """VAE double that raises during decode and records transfer calls.""" + + def __init__(self): + """Initialize transfer call trackers for restoration assertions.""" + super().__init__() + self.cpu_calls = 0 + self.to_calls = [] + + def decode(self, latents: torch.Tensor): + """Raise decode error to exercise restoration in finally branch.""" + _ = latents + raise RuntimeError("decode failed") + + def cpu(self): + """Record explicit CPU transfer and return self.""" + self.cpu_calls += 1 + return self + + def to(self, *args, **kwargs): + """Record restore transfer target and 
return self."""
+                self.to_calls.append((args, kwargs))
+                return self
+
+        class _FailingHost(_Host):
+            """Host variant that forces non-MLX VAE decode and tracks cache clears."""
+
+            def __init__(self):
+                """Set non-MLX state so CPU offload path is exercised deterministically."""
+                super().__init__()
+                self.use_mlx_vae = False
+                self.mlx_vae = None
+                self.vae = _FailingVae()
+                self.empty_cache_calls = 0
+
+            def _empty_cache(self):
+                """Count cache-clear calls to verify finally cleanup runs."""
+                self.empty_cache_calls += 1
+
+        host = _FailingHost()
+        pred_latents = torch.ones(1, 4, 3)
+        time_costs = {"total_time_cost": 1.0}
+
+        with patch.dict(GENERATE_MUSIC_DECODE_MODULE.os.environ, {"ACESTEP_VAE_ON_CPU": "1"}, clear=False):
+            with self.assertRaisesRegex(RuntimeError, "decode failed"):
+                host._decode_generate_music_pred_latents(
+                    pred_latents=pred_latents,
+                    progress=None,
+                    use_tiled_decode=False,
+                    time_costs=time_costs,
+                )
+
+        self.assertEqual(host.vae.cpu_calls, 1)
+        self.assertEqual(len(host.vae.to_calls), 1)
+        self.assertGreaterEqual(host.empty_cache_calls, 3)
+
 if __name__ == "__main__":
     unittest.main()