Refactor(handler): decompose generate_music orchestration (part 18)#626

Merged
ChuxiJ merged 2 commits into ace-step:main from 1larity:feat/handler-decomp-part-18 on Feb 18, 2026

Conversation

@1larity (Contributor) commented Feb 17, 2026

Summary

  • Extract generate_music orchestration from acestep/handler.py into focused mixins.
  • Preserve the public AceStepHandler.generate_music behavior by composing the new mixins into the handler inheritance chain (see the composition sketch after the Changes list below).
  • Restore part-17 MLX decomposition on this branch via cherry-picked commits, without PR stacking.

Changes

  • Added generate_music orchestration mixin module.
  • Added decode/validation mixin module for latent checks and VAE decode flow.
  • Added success-payload mixin module for response assembly.
  • Updated handler decomposition exports and AceStepHandler mixin composition.
  • Removed inlined generate_music implementation from acestep/handler.py.
  • Added focused unit tests for orchestration, decode helpers, and payload assembly.
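
A minimal sketch of the resulting composition. The three mixin names come from this PR; the base order, the class body, and any other handler bases are assumptions, not the verbatim definition:

    # Hypothetical composition sketch; not the verbatim handler definition.
    from acestep.core.generation.handler import (
        GenerateMusicMixin,
        GenerateMusicDecodeMixin,
        GenerateMusicPayloadMixin,
    )


    class AceStepHandler(
        GenerateMusicMixin,         # public generate_music(...) orchestration entry point
        GenerateMusicDecodeMixin,   # latent validation + VRAM-aware VAE decode helpers
        GenerateMusicPayloadMixin,  # success-payload assembly
        # ... existing handler bases continue unchanged ...
    ):
        """Public handler; generate_music behavior is preserved via mixin composition."""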

Validation

  • python acestep/core/generation/handler/generate_music_test.py
  • python acestep/core/generation/handler/generate_music_decode_test.py
  • python acestep/core/generation/handler/generate_music_payload_test.py
  • python acestep/core/generation/handler/mlx_dit_init_test.py
  • python acestep/core/generation/handler/mlx_vae_init_test.py
  • python acestep/core/generation/handler/mlx_vae_native_test.py

Summary by CodeRabbit

  • New Features

    • Enhanced music generation: richer configuration (tempo, key, structure, decoding options), end-to-end orchestration, progress milestones, and clearer error reporting.
  • Improved Decoding

    • More reliable audio decoding with normalization and memory-aware strategies for consistent waveform outputs.
  • Tests

    • Added comprehensive tests for orchestration, decoding, and final payload assembly.
  • Refactor

    • Music-generation logic split into focused components for orchestration, decoding, and payload construction.

@coderabbitai bot (Contributor) commented Feb 17, 2026

📝 Walkthrough

Extracts the previous monolithic generate_music implementation into three new mixins (orchestration, decoding, payload assembly), adds unit tests for each mixin, updates handler exports and AceStepHandler inheritance to include the mixins, and removes the original generate_music method from acestep/handler.py.

Changes

Handler Module Exports (acestep/core/generation/handler/__init__.py):
Adds exports for GenerateMusicMixin, GenerateMusicDecodeMixin, and GenerateMusicPayloadMixin.

Orchestration Mixin (acestep/core/generation/handler/generate_music.py):
New GenerateMusicMixin implementing generate_music(...) to validate readiness, prepare inputs, call the generation service, orchestrate decoding, handle progress and errors, and build the final payload.

Decoding Mixin (acestep/core/generation/handler/generate_music_decode.py):
New GenerateMusicDecodeMixin with _prepare_generate_music_decode_state and _decode_generate_music_pred_latents for latent validation, optional post-processing, VRAM-aware VAE decode (tiled/MLX/CPU fallback), time accounting, and progress hooks.

Decoding Tests (acestep/core/generation/handler/generate_music_decode_test.py):
Unit tests exercising the decode helpers: progress updates, NaN handling, the decode path, error/cleanup behavior, and time accounting.

Payload Mixin (acestep/core/generation/handler/generate_music_payload.py):
New GenerateMusicPayloadMixin with _build_generate_music_success_payload assembling audios and extra_outputs (moves tensors to CPU, includes metadata, seed, and time_costs).

Payload Tests (acestep/core/generation/handler/generate_music_payload_test.py):
Unit tests verifying payload shape, sample_rate routing, seed inclusion, CPU placement of latents, and progress callback invocation.

Orchestration Tests (acestep/core/generation/handler/generate_music_test.py):
Unit tests for the orchestration flow: success path, readiness short-circuit, and exception-to-error-payload handling.

Handler Refactor (acestep/handler.py):
Removes the previous generate_music method; adds GenerateMusicMixin, GenerateMusicDecodeMixin, and GenerateMusicPayloadMixin to the AceStepHandler bases; updates imports (some constant imports removed).

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant Handler as AceStepHandler\n(GenerateMusicMixin)
    participant Service as GenerationService
    participant Decoder as GenerateMusicDecodeMixin/VAE
    participant Payload as GenerateMusicPayloadMixin

    Client->>Handler: generate_music(captions, lyrics, params)
    Handler->>Handler: validate readiness\nprepare inputs & seeds
    Handler->>Service: run_service(inputs, progress_cb)
    Service-->>Handler: outputs, latents, metadata
    Handler->>Decoder: _prepare_generate_music_decode_state(outputs,...)
    Decoder->>Decoder: validate & post-process latents
    Decoder->>Decoder: decode latents -> waveforms
    Decoder-->>Handler: pred_wavs, pred_latents_cpu, time_costs
    Handler->>Payload: _build_generate_music_success_payload(outputs, pred_wavs, ...)
    Payload-->>Handler: success payload
    Handler-->>Client: payload

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • ChuxiJ

Poem

🐰 I hopped through code to split the song,
Three little mixins hum along,
Orchestrate, decode, then pack the tune,
Tests all pass beneath the moon,
I leave fresh paths for devs to croon. 🎶

🚥 Pre-merge checks: 3 passed

Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
Title Check: ✅ Passed. The title accurately describes the main change: decomposing the generate_music orchestration from the handler into focused mixin modules while maintaining the public API through composition.
Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required 80.00% threshold.


@1larity changed the title from "refactor(handler): decompose generate_music orchestration (part 18)" to "Refactor(handler): decompose generate_music orchestration (part 18)" on Feb 17, 2026
@1larity force-pushed the feat/handler-decomp-part-18 branch from 7d96537 to daa0e17 on February 17, 2026 at 22:21
@1larity marked this pull request as ready for review on February 17, 2026 at 22:25
@coderabbitai bot (Contributor) left a comment

🧹 Nitpick comments (2)
acestep/core/generation/handler/generate_music.py (1)

22-55: Add type hints for untyped parameters in the public entry point.

This keeps the public API consistent with the repository’s typing guidance.

Suggested update
-from typing import Any, Dict, List, Optional, Union
+from collections.abc import Callable
+from typing import Any, Dict, List, Optional, Union
@@
-        reference_audio=None,
+        reference_audio: Optional[Any] = None,
@@
-        src_audio=None,
+        src_audio: Optional[Any] = None,
@@
-        progress=None,
+        progress: Optional[Callable[..., Any]] = None,

As per coding guidelines, "Add type hints for new/modified functions when practical in Python".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music.py` around lines 22 - 55, The
public entry point generate_music has several untyped parameters; add explicit
type hints for reference_audio, src_audio, and progress (and any other
parameters lacking concrete types) to match the repository typing guidance—e.g.,
annotate reference_audio and src_audio with appropriate Optional[Any] or a more
specific AudioType, and annotate progress with Optional[ProgressCallback] or
Optional[Callable[[...], None]] (or the project's progress interface); update
the function signature in generate_music accordingly and ensure any
imports/types used for these annotations are added or referenced so the
signature is fully typed while preserving the existing defaults and return type
Dict[str, Any].
acestep/core/generation/handler/generate_music_payload.py (1)

11-20: Add explicit type hints for tensor/progress parameters.

This keeps the new mixin API self-documenting and aligned with the repo’s typing expectations.

Suggested update
-from typing import Any, Dict
+from collections.abc import Callable
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch
@@
-        pred_wavs,
-        pred_latents_cpu,
+        pred_wavs: "torch.Tensor",
+        pred_latents_cpu: "torch.Tensor",
@@
-        progress: Any,
+        progress: Optional[Callable[..., Any]],

As per coding guidelines, "Add type hints for new/modified functions when practical in Python".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_payload.py` around lines 11 -
20, The function _build_generate_music_success_payload should add explicit type
hints for the tensor and progress parameters: annotate pred_wavs and
pred_latents_cpu as torch.Tensor (or Optional[torch.Tensor] if they can be None)
and annotate progress with a concrete type such as tqdm.std.tqdm or
Optional[tqdm.std.tqdm] (or a Callable progress callback type if that matches
usage); update the signature to import and use typing (Optional/Union) and
torch/tqdm as needed and adjust any callers or imports accordingly to keep the
module type-correct.

@1larity force-pushed the feat/handler-decomp-part-18 branch from daa0e17 to ca3fd1b on February 18, 2026 at 08:48
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/core/generation/handler/generate_music_decode.py`:
- Around line 154-159: The VAE is moved to CPU when vae_cpu is true (see the
vae_cpu branch that sets vae_device = next(self.vae.parameters()).device,
self.vae = self.vae.cpu(), pred_latents_for_decode =
pred_latents_for_decode.cpu(), and calls self._empty_cache()), but if the
subsequent decode raises the original device restoration (which uses vae_device)
will be skipped; wrap the decode logic in a try/finally: record the original
device via vae_device before moving the VAE, move to CPU as now, run the decode
in the try block, and in the finally always restore self.vae back to vae_device
(and move pred_latents_for_decode back if needed) and call self._empty_cache()
so the VAE is returned to its original device even on exceptions.
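
A minimal sketch of the suggested try/finally shape, using the names from the prompt above (the surrounding handler state and decode call are assumptions taken from that description):

    vae_device = next(self.vae.parameters()).device  # record the original device first
    self.vae = self.vae.cpu()
    pred_latents_for_decode = pred_latents_for_decode.cpu()
    self._empty_cache()
    try:
        decoder_output = self.vae.decode(pred_latents_for_decode)
        pred_wavs = decoder_output.sample
    finally:
        # Runs even when decode raises, so the VAE always returns to its original device.
        self.vae = self.vae.to(vae_device)
        self._empty_cache()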

@ChuxiJ merged commit 2af5382 into ace-step:main on Feb 18, 2026 (2 of 3 checks passed)
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4


Comment on lines +1 to +255
"""Tests for extracted ``generate_music`` decode helper mixin behavior."""

import importlib.util
import types
import sys
import unittest
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import torch


def _load_generate_music_decode_module():
"""Load ``generate_music_decode.py`` from disk and return its module object.

Raises ``FileNotFoundError`` or ``ImportError`` when loading fails.
"""
repo_root = Path(__file__).resolve().parents[4]
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
package_paths = {
"acestep": repo_root / "acestep",
"acestep.core": repo_root / "acestep" / "core",
"acestep.core.generation": repo_root / "acestep" / "core" / "generation",
"acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
}
for package_name, package_path in package_paths.items():
if package_name in sys.modules:
continue
package_module = types.ModuleType(package_name)
package_module.__path__ = [str(package_path)]
sys.modules[package_name] = package_module
module_path = Path(__file__).with_name("generate_music_decode.py")
spec = importlib.util.spec_from_file_location(
"acestep.core.generation.handler.generate_music_decode",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module


GENERATE_MUSIC_DECODE_MODULE = _load_generate_music_decode_module()
GenerateMusicDecodeMixin = GENERATE_MUSIC_DECODE_MODULE.GenerateMusicDecodeMixin


class _FakeDecodeOutput:
"""Minimal VAE decode output container exposing ``sample`` attribute."""

def __init__(self, sample: torch.Tensor):
"""Store decoded sample tensor for mixin decode flow."""
self.sample = sample


class _FakeVae:
"""Minimal VAE stand-in with dtype, decode, and parameter iteration hooks."""

def __init__(self):
"""Initialize deterministic dtype/device state for decode tests."""
self.dtype = torch.float32
self._param = torch.nn.Parameter(torch.zeros(1))

def decode(self, latents: torch.Tensor):
"""Return deterministic decoded waveform output."""
return _FakeDecodeOutput(torch.ones(latents.shape[0], 2, 8))

def parameters(self):
"""Yield one parameter so `.device` lookups remain valid."""
yield self._param

def cpu(self):
"""Return self for test-only CPU transfer calls."""
return self

def to(self, *_args, **_kwargs):
"""Return self for test-only device transfer calls."""
return self


class _Host(GenerateMusicDecodeMixin):
"""Minimal decode-mixin host exposing deterministic state for assertions."""

def __init__(self):
"""Initialize deterministic runtime state for decode tests."""
self.current_offload_cost = 0.25
self.debug_stats = False
self._last_diffusion_per_step_sec = None
self.estimate_calls = []
self.progress_calls = []
self.device = "cpu"
self.use_mlx_vae = True
self.mlx_vae = object()
self.vae = _FakeVae()

def _update_progress_estimate(self, **kwargs):
"""Capture estimate updates for assertions."""
self.estimate_calls.append(kwargs)

@contextmanager
def _load_model_context(self, _model_name):
"""Provide no-op model context manager for decode tests."""
yield

def _empty_cache(self):
"""Provide no-op cache clear helper for decode tests."""
return None

def _memory_allocated(self):
"""Return deterministic allocated-memory value for debug logging."""
return 0.0

def _max_memory_allocated(self):
"""Return deterministic max-memory value for debug logging."""
return 0.0

def _mlx_vae_decode(self, latents):
"""Return deterministic decoded waveform for MLX decode branch."""
_ = latents
return torch.ones(1, 2, 8)

def tiled_decode(self, latents):
"""Return deterministic decoded waveform for tiled decode branch."""
_ = latents
return torch.ones(1, 2, 8)


class GenerateMusicDecodeMixinTests(unittest.TestCase):
"""Verify decode-state preparation and latent decode helper behavior."""

def test_prepare_decode_state_updates_progress_estimates(self):
"""It updates timing fields and progress estimate metadata for valid latents."""
host = _Host()
outputs = {
"target_latents": torch.ones(1, 4, 3),
"time_costs": {"total_time_cost": 1.0, "diffusion_per_step_time_cost": 0.2},
}
pred_latents, time_costs = host._prepare_generate_music_decode_state(
outputs=outputs,
infer_steps_for_progress=8,
actual_batch_size=1,
audio_duration=12.0,
latent_shift=0.0,
latent_rescale=1.0,
)
self.assertEqual(tuple(pred_latents.shape), (1, 4, 3))
self.assertEqual(time_costs["offload_time_cost"], 0.25)
self.assertEqual(host._last_diffusion_per_step_sec, 0.2)
self.assertEqual(host.estimate_calls[0]["infer_steps"], 8)

def test_prepare_decode_state_raises_for_nan_latents(self):
"""It raises runtime error when diffusion latents contain NaN values."""
host = _Host()
outputs = {
"target_latents": torch.tensor([[[float("nan")]]]),
"time_costs": {"total_time_cost": 1.0},
}
with self.assertRaises(RuntimeError):
host._prepare_generate_music_decode_state(
outputs=outputs,
infer_steps_for_progress=8,
actual_batch_size=1,
audio_duration=None,
latent_shift=0.0,
latent_rescale=1.0,
)

def test_decode_pred_latents_updates_decode_time_and_returns_cpu_latents(self):
"""It decodes latents and updates decode timing metrics in time_costs."""
host = _Host()
pred_latents = torch.ones(1, 4, 3)
time_costs = {"total_time_cost": 1.0}

def _progress(value, desc=None):
"""Capture progress updates for assertions."""
host.progress_calls.append((value, desc))

with patch.object(GENERATE_MUSIC_DECODE_MODULE.time, "time", side_effect=[10.0, 11.5]):
pred_wavs, pred_latents_cpu, updated_costs = host._decode_generate_music_pred_latents(
pred_latents=pred_latents,
progress=_progress,
use_tiled_decode=False,
time_costs=time_costs,
)

self.assertEqual(tuple(pred_wavs.shape), (1, 2, 8))
self.assertEqual(pred_latents_cpu.device.type, "cpu")
self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6)
self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
self.assertEqual(host.progress_calls[0][0], 0.8)

def test_decode_pred_latents_restores_vae_device_on_decode_error(self):
"""It restores VAE device in the CPU-offload path even when decode raises."""

class _FailingVae(_FakeVae):
"""VAE double that raises during decode and records transfer calls."""

def __init__(self):
"""Initialize transfer call trackers for restoration assertions."""
super().__init__()
self.cpu_calls = 0
self.to_calls = []

def decode(self, latents: torch.Tensor):
"""Raise decode error to exercise restoration in finally branch."""
_ = latents
raise RuntimeError("decode failed")

def cpu(self):
"""Record explicit CPU transfer and return self."""
self.cpu_calls += 1
return self

def to(self, *args, **kwargs):
"""Record restore transfer target and return self."""
self.to_calls.append((args, kwargs))
return self

class _FailingHost(_Host):
"""Host variant that forces non-MLX VAE decode and tracks cache clears."""

def __init__(self):
"""Set non-MLX state so CPU offload path is exercised deterministically."""
super().__init__()
self.use_mlx_vae = False
self.mlx_vae = None
self.vae = _FailingVae()
self.empty_cache_calls = 0

def _empty_cache(self):
"""Count cache-clear calls to verify finally cleanup runs."""
self.empty_cache_calls += 1

host = _FailingHost()
pred_latents = torch.ones(1, 4, 3)
time_costs = {"total_time_cost": 1.0}

with patch.dict(GENERATE_MUSIC_DECODE_MODULE.os.environ, {"ACESTEP_VAE_ON_CPU": "1"}, clear=False):
with self.assertRaisesRegex(RuntimeError, "decode failed"):
host._decode_generate_music_pred_latents(
pred_latents=pred_latents,
progress=None,
use_tiled_decode=False,
time_costs=time_costs,
)

self.assertEqual(host.vae.cpu_calls, 1)
self.assertEqual(len(host.vae.to_calls), 1)
self.assertGreaterEqual(host.empty_cache_calls, 2)


if __name__ == "__main__":
unittest.main()

⚠️ Potential issue | 🟡 Minor

Module exceeds 200 LOC cap.

This test module is ~255 lines, which is over the 200 LOC hard cap. Please split the fakes/loader helpers into a smaller helper module or shared fixture to bring it under the limit. As per coding guidelines: Target module size: optimal <= 150 LOC, hard cap 200 LOC.

🧰 Tools
🪛 Ruff (0.15.1)

[warning] 209-209: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_decode_test.py` around lines 1
- 255, The test file is over the 200 LOC cap—extract the reusable test helpers
and fakes into a new helper module (e.g.,
tests/helpers/generate_music_fixtures.py) and import them into this test;
specifically move _load_generate_music_decode_module, _FakeDecodeOutput,
_FakeVae, _Host and any package-setup code (the package_paths/sys.path
manipulation) into that helper module, leaving only the test class
GenerateMusicDecodeMixinTests and its test methods plus minimal imports in
generate_music_decode_test.py; update generate_music_decode_test.py to import
GenerateMusicDecodeMixin (or the helper-provided fixtures) and reference the
same symbols so the tests run unchanged and the file size falls under the 200
LOC cap.
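
A minimal sketch of the suggested split. The helper-module path, its name, and the re-exported symbols follow the prompt above and are assumptions, not the final layout:

    # tests/helpers/generate_music_fixtures.py (hypothetical new module) would absorb
    # _load_generate_music_decode_module, _FakeDecodeOutput, _FakeVae, _Host, and the
    # sys.path/package stub setup verbatim from the current test file.

    # generate_music_decode_test.py would then reduce to roughly:
    import unittest

    from tests.helpers.generate_music_fixtures import (  # hypothetical import path
        GENERATE_MUSIC_DECODE_MODULE,
        GenerateMusicDecodeMixin,
        _FakeVae,
        _Host,
    )


    class GenerateMusicDecodeMixinTests(unittest.TestCase):
        """Test methods unchanged from the current file."""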

Comment on lines +14 to +42
def _load_generate_music_decode_module():
    """Load ``generate_music_decode.py`` from disk and return its module object.

    Raises ``FileNotFoundError`` or ``ImportError`` when loading fails.
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    package_paths = {
        "acestep": repo_root / "acestep",
        "acestep.core": repo_root / "acestep" / "core",
        "acestep.core.generation": repo_root / "acestep" / "core" / "generation",
        "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
    }
    for package_name, package_path in package_paths.items():
        if package_name in sys.modules:
            continue
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_path)]
        sys.modules[package_name] = package_module
    module_path = Path(__file__).with_name("generate_music_decode.py")
    spec = importlib.util.spec_from_file_location(
        "acestep.core.generation.handler.generate_music_decode",
        module_path,
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

rg -n -C2 --type=py -g '*_test.py' -e '^\s*(import|from)\s+acestep\b'

Repository: ace-step/ACE-Step-1.5

Length of output: 15331


🏁 Script executed:

# Check the actual test file to see if tearDownModule already exists and how the function is used
cat -n acestep/core/generation/handler/generate_music_decode_test.py

Repository: ace-step/ACE-Step-1.5

Length of output: 11481


🏁 Script executed:

# Also check if there's any cleanup in conftest.py or other test setup files
find . -name "conftest.py" -o -name "test_*.py" | head -20

Repository: ace-step/ACE-Step-1.5

Length of output: 47


Clean up stub packages and sys.path modifications after module tests complete.

_load_generate_music_decode_module() injects package stubs and modifies sys.path at module import time (line 45) without cleanup. In a shared unittest process, the stub acestep packages can mask the real package for tests that follow. Verification confirmed that 20+ other test files import from acestep, so contamination is possible.

Add module-level tracking lists and a tearDownModule() function to remove injected entries:

Cleanup pattern
+_TEMP_SYS_MODULES = []
+_TEMP_SYS_PATHS = []
+
 def _load_generate_music_decode_module():
     """Load ``generate_music_decode.py`` from disk and return its module object.
@@
     repo_root = Path(__file__).resolve().parents[4]
     if str(repo_root) not in sys.path:
         sys.path.insert(0, str(repo_root))
+        _TEMP_SYS_PATHS.append(str(repo_root))
@@
     for package_name, package_path in package_paths.items():
         if package_name in sys.modules:
             continue
         package_module = types.ModuleType(package_name)
         package_module.__path__ = [str(package_path)]
         sys.modules[package_name] = package_module
+        _TEMP_SYS_MODULES.append(package_name)
@@
     spec.loader.exec_module(module)
     return module
+
+
+def tearDownModule():
+    for name in _TEMP_SYS_MODULES:
+        sys.modules.pop(name, None)
+    for path in _TEMP_SYS_PATHS:
+        if path in sys.path:
+            sys.path.remove(path)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_TEMP_SYS_MODULES = []
_TEMP_SYS_PATHS = []


def _load_generate_music_decode_module():
    """Load ``generate_music_decode.py`` from disk and return its module object.

    Raises ``FileNotFoundError`` or ``ImportError`` when loading fails.
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
        _TEMP_SYS_PATHS.append(str(repo_root))
    package_paths = {
        "acestep": repo_root / "acestep",
        "acestep.core": repo_root / "acestep" / "core",
        "acestep.core.generation": repo_root / "acestep" / "core" / "generation",
        "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
    }
    for package_name, package_path in package_paths.items():
        if package_name in sys.modules:
            continue
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_path)]
        sys.modules[package_name] = package_module
        _TEMP_SYS_MODULES.append(package_name)
    module_path = Path(__file__).with_name("generate_music_decode.py")
    spec = importlib.util.spec_from_file_location(
        "acestep.core.generation.handler.generate_music_decode",
        module_path,
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def tearDownModule():
    for name in _TEMP_SYS_MODULES:
        sys.modules.pop(name, None)
    for path in _TEMP_SYS_PATHS:
        if path in sys.path:
            sys.path.remove(path)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_decode_test.py` around lines
14 - 42, The helper _load_generate_music_decode_module currently injects stub
packages into sys.modules and prepends repo_root to sys.path but never cleans
them up; add module-level trackers (e.g., a list inserted_sys_path and a list
injected_packages) and record the repo_root string and each package_name you add
from package_paths when creating package_module, then implement tearDownModule()
to undo those changes: remove the repo_root entry from sys.path if it was
inserted, and for each package_name in injected_packages remove it from
sys.modules (only if the value is the stub you created), ensuring you don't
disturb real packages; reference _load_generate_music_decode_module,
package_paths, sys.path, sys.modules, and implement tearDownModule to perform
the cleanup after tests.

Comment on lines +1 to +201
"""Decode/validation helpers for ``generate_music`` orchestration."""

import os
import time
from typing import Any, Dict, Optional, Tuple

import torch
from loguru import logger

from acestep.gpu_config import get_effective_free_vram_gb


class GenerateMusicDecodeMixin:
"""Validate generated latents and decode them into waveform tensors."""

def _prepare_generate_music_decode_state(
self,
outputs: Dict[str, Any],
infer_steps_for_progress: int,
actual_batch_size: int,
audio_duration: Optional[float],
latent_shift: float,
latent_rescale: float,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Collect decode inputs and validate raw diffusion latents.

Args:
outputs: ``service_generate`` output payload containing target latents and timings.
infer_steps_for_progress: Effective diffusion step count for estimates.
actual_batch_size: Effective generation batch size.
audio_duration: Optional generation duration in seconds.
latent_shift: Additive latent post-processing shift.
latent_rescale: Multiplicative latent post-processing scale.

Returns:
Tuple containing validated ``pred_latents`` and mutable ``time_costs``.

Raises:
RuntimeError: If latents contain NaN/Inf values or collapse to all zeros.
"""
logger.info("[generate_music] Model generation completed. Decoding latents...")
pred_latents = outputs["target_latents"]
time_costs = outputs["time_costs"]
time_costs["offload_time_cost"] = self.current_offload_cost

per_step = time_costs.get("diffusion_per_step_time_cost")
if isinstance(per_step, (int, float)) and per_step > 0:
self._last_diffusion_per_step_sec = float(per_step)
self._update_progress_estimate(
per_step_sec=float(per_step),
infer_steps=infer_steps_for_progress,
batch_size=actual_batch_size,
duration_sec=audio_duration if audio_duration and audio_duration > 0 else None,
)

if self.debug_stats:
logger.debug(
f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} "
f"{pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} "
f"{pred_latents.std()=}"
)
else:
logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype}")
logger.debug(f"[generate_music] time_costs: {time_costs}")

if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
raise RuntimeError(
"Generation produced NaN or Inf latents. "
"This usually indicates a checkpoint/config mismatch "
"or unsupported quantization/backend combination. "
"Try running with --backend pt or verify your model checkpoints match this release."
)
if pred_latents.numel() > 0 and pred_latents.abs().sum() == 0:
raise RuntimeError(
"Generation produced zero latents. "
"This usually indicates a checkpoint/config mismatch or unsupported setup."
)
if latent_shift != 0.0 or latent_rescale != 1.0:
logger.info(
f"[generate_music] Applying latent post-processing: shift={latent_shift}, "
f"rescale={latent_rescale}"
)
if self.debug_stats:
logger.debug(
f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, "
f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, "
f"std={pred_latents.std():.4f}"
)
pred_latents = pred_latents * latent_rescale + latent_shift
if self.debug_stats:
logger.debug(
f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, "
f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, "
f"std={pred_latents.std():.4f}"
)
return pred_latents, time_costs

def _decode_generate_music_pred_latents(
self,
pred_latents: torch.Tensor,
progress: Any,
use_tiled_decode: bool,
time_costs: Dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""Decode predicted latents and update decode timing metrics.

Args:
pred_latents: Predicted latent tensor shaped ``[batch, frames, dim]``.
progress: Optional progress callback.
use_tiled_decode: Whether tiled VAE decode should be used.
time_costs: Mutable time-cost payload from service generation.

Returns:
Tuple of decoded waveforms, CPU latents, and updated time-cost payload.
"""
if progress:
progress(0.8, desc="Decoding audio...")
logger.info("[generate_music] Decoding latents with VAE...")
start_time = time.time()
with torch.inference_mode():
with self._load_model_context("vae"):
pred_latents_cpu = pred_latents.detach().cpu()
pred_latents_for_decode = pred_latents.transpose(1, 2).contiguous().to(self.vae.dtype)
del pred_latents
self._empty_cache()

logger.debug(
"[generate_music] Before VAE decode: "
f"allocated={self._memory_allocated()/1024**3:.2f}GB, "
f"max={self._max_memory_allocated()/1024**3:.2f}GB"
)
using_mlx_vae = self.use_mlx_vae and self.mlx_vae is not None
vae_cpu = False
vae_device = None
if not using_mlx_vae:
vae_cpu = os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes")
if not vae_cpu:
if self.device == "mps":
logger.info(
"[generate_music] MPS device: skipping VRAM check "
"(unified memory), keeping VAE on MPS"
)
else:
effective_free = get_effective_free_vram_gb()
logger.info(
"[generate_music] Effective free VRAM before VAE decode: "
f"{effective_free:.2f} GB"
)
if effective_free < 0.5:
logger.warning(
"[generate_music] Only "
f"{effective_free:.2f} GB free VRAM; auto-enabling CPU VAE decode"
)
vae_cpu = True
if vae_cpu:
logger.info("[generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...")
vae_device = next(self.vae.parameters()).device
self.vae = self.vae.cpu()
pred_latents_for_decode = pred_latents_for_decode.cpu()
self._empty_cache()
try:
if use_tiled_decode:
logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
pred_wavs = self.tiled_decode(pred_latents_for_decode)
elif using_mlx_vae:
try:
pred_wavs = self._mlx_vae_decode(pred_latents_for_decode)
except Exception as exc:
logger.warning(
f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch"
)
decoder_output = self.vae.decode(pred_latents_for_decode)
pred_wavs = decoder_output.sample
del decoder_output
else:
decoder_output = self.vae.decode(pred_latents_for_decode)
pred_wavs = decoder_output.sample
del decoder_output
finally:
if vae_cpu and vae_device is not None:
logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
self.vae = self.vae.to(vae_device)
pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
self._empty_cache()
logger.debug(
"[generate_music] After VAE decode: "
f"allocated={self._memory_allocated()/1024**3:.2f}GB, "
f"max={self._max_memory_allocated()/1024**3:.2f}GB"
)
del pred_latents_for_decode
if pred_wavs.dtype != torch.float32:
pred_wavs = pred_wavs.float()
peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True)
if torch.any(peak > 1.0):
pred_wavs = pred_wavs / peak.clamp(min=1.0)
self._empty_cache()
end_time = time.time()
time_costs["vae_decode_time_cost"] = end_time - start_time
time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
time_costs["offload_time_cost"] = self.current_offload_cost
return pred_wavs, pred_latents_cpu, time_costs

⚠️ Potential issue | 🟡 Minor

Module exceeds 200 LOC cap.

This file is ~201 lines, which crosses the 200 LOC hard cap. Consider extracting CPU/VRAM selection and post‑decode normalization into helpers to bring the module under the limit. As per coding guidelines: Target module size: optimal <= 150 LOC, hard cap 200 LOC.

🧰 Tools
🪛 Ruff (0.15.1)

[warning] 67-72: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 74-77: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 168-168: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_decode.py` around lines 1 -
201, The module is over the 200-LOC cap due to a large
_decode_generate_music_pred_latents implementation; extract the CPU/VRAM
selection and the post-decode normalization into two helper methods to reduce
size. Specifically, move the logic that determines using_mlx_vae / vae_cpu /
vae_device and the code that moves the VAE to CPU and restores it into a new
private helper (e.g., _select_vae_device_for_decode(self,
pred_latents_for_decode) -> Tuple[pred_latents_for_decode, vae_cpu, vae_device])
and move the peak clamping + dtype coercion into another helper (e.g.,
_post_decode_normalize(self, pred_wavs) -> pred_wavs), update
_decode_generate_music_pred_latents to call these helpers, preserve all side
effects (self.vae device moves, calls to self._empty_cache(), and updates to
time_costs) and keep existing logging and exception fallback behavior intact.
Ensure any new helpers are placed in the same class and imported symbols remain
unchanged so callers and tests need no updates.
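
A minimal sketch of one of the suggested helpers. The name _post_decode_normalize follows the suggestion above; the body mirrors the existing inline normalization logic and is an illustration, not the final extraction:

    def _post_decode_normalize(self, pred_wavs: torch.Tensor) -> torch.Tensor:
        """Coerce decoded waveforms to float32 and clamp peaks above 1.0."""
        if pred_wavs.dtype != torch.float32:
            pred_wavs = pred_wavs.float()
        # Per-sample peak over channel/time dims; rescale only when clipping would occur.
        peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True)
        if torch.any(peak > 1.0):
            pred_wavs = pred_wavs / peak.clamp(min=1.0)
        return pred_wavs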

Comment on lines +180 to +184
if vae_cpu and vae_device is not None:
    logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
    self.vae = self.vae.to(vae_device)
    pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
self._empty_cache()

⚠️ Potential issue | 🟠 Major

Don’t move pred_latents_for_decode back to GPU in the CPU-offload cleanup.

Line 183 re-allocates the full latent tensor on GPU after a low‑VRAM offload, which can OOM and even mask the original decode error. The tensor is deleted immediately after, so the move is unnecessary—only the VAE needs restoring.

✅ Proposed fix
                 if vae_cpu and vae_device is not None:
                     logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
                     self.vae = self.vae.to(vae_device)
-                    pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
                 self._empty_cache()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if vae_cpu and vae_device is not None:
    logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
    self.vae = self.vae.to(vae_device)
self._empty_cache()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music_decode.py` around lines 180 -
184, The code currently moves pred_latents_for_decode back to the GPU during the
CPU-offload cleanup which can OOM; in the restore block inside the CPU decode
path (the section checking vae_cpu and vae_device), only move self.vae back to
vae_device and do NOT call pred_latents_for_decode.to(vae_device). Remove the
line that re-allocates pred_latents_for_decode on the VAE restore path
(referencing pred_latents_for_decode and self.vae) so the large latent tensor
remains on CPU until it is deleted, then call self._empty_cache() as before.
