Refactor(handler): decompose generate_music orchestration (part 18) #626
ChuxiJ merged 2 commits into ace-step:main
Conversation
📝 Walkthrough

Extracts the previous monolithic `generate_music` orchestration from `acestep/handler.py` into focused mixins that the handler composes.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Handler as AceStepHandler (GenerateMusicMixin)
    participant Service as GenerationService
    participant Decoder as GenerateMusicDecodeMixin/VAE
    participant Payload as GenerateMusicPayloadMixin
    Client->>Handler: generate_music(captions, lyrics, params)
    Handler->>Handler: validate readiness, prepare inputs & seeds
    Handler->>Service: run_service(inputs, progress_cb)
    Service-->>Handler: outputs, latents, metadata
    Handler->>Decoder: _prepare_generate_music_decode_state(outputs, ...)
    Decoder->>Decoder: validate & post-process latents
    Decoder->>Decoder: decode latents -> waveforms
    Decoder-->>Handler: pred_wavs, pred_latents_cpu, time_costs
    Handler->>Payload: _build_generate_music_success_payload(outputs, pred_wavs, ...)
    Payload-->>Handler: success payload
    Handler-->>Client: payload
```
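For orientation, here is a minimal sketch of how the extracted pieces could compose back into the public handler. The module paths and mixin names are the ones used in this PR; the exact base-class list and ordering of the real `AceStepHandler` are assumptions.

```python
# Hypothetical composition sketch, not the shipped handler definition.
from acestep.core.generation.handler.generate_music import GenerateMusicMixin
from acestep.core.generation.handler.generate_music_decode import GenerateMusicDecodeMixin
from acestep.core.generation.handler.generate_music_payload import GenerateMusicPayloadMixin


class AceStepHandler(GenerateMusicMixin, GenerateMusicDecodeMixin, GenerateMusicPayloadMixin):
    """generate_music orchestration now comes entirely from the composed mixins."""
```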
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
7d96537 to daa0e17
🧹 Nitpick comments (2)
acestep/core/generation/handler/generate_music.py (1)
22-55: Add type hints for untyped parameters in the public entry point. This keeps the public API consistent with the repository's typing guidance.
Suggested update
```diff
-from typing import Any, Dict, List, Optional, Union
+from collections.abc import Callable
+from typing import Any, Dict, List, Optional, Union
@@
-    reference_audio=None,
+    reference_audio: Optional[Any] = None,
@@
-    src_audio=None,
+    src_audio: Optional[Any] = None,
@@
-    progress=None,
+    progress: Optional[Callable[..., Any]] = None,
```

As per coding guidelines, "Add type hints for new/modified functions when practical in Python".
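Pulled together, the typed entry point could read roughly as below. Only `reference_audio`, `src_audio`, and `progress` come from the diff above; the other parameters are illustrative placeholders, not the actual signature.

```python
from collections.abc import Callable
from typing import Any, Dict, Optional


class GenerateMusicMixin:
    def generate_music(
        self,
        captions: str,  # illustrative; not part of the diff above
        lyrics: str,    # illustrative; not part of the diff above
        reference_audio: Optional[Any] = None,
        src_audio: Optional[Any] = None,
        progress: Optional[Callable[..., Any]] = None,
    ) -> Dict[str, Any]:
        """Typed public entry point (signature sketch only)."""
        raise NotImplementedError
```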
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/core/generation/handler/generate_music.py` around lines 22-55: The public entry point generate_music has several untyped parameters; add explicit type hints for reference_audio, src_audio, and progress (and any other parameters lacking concrete types) to match the repository typing guidance, e.g., annotate reference_audio and src_audio with appropriate Optional[Any] or a more specific AudioType, and annotate progress with Optional[ProgressCallback] or Optional[Callable[[...], None]] (or the project's progress interface); update the function signature in generate_music accordingly and ensure any imports/types used for these annotations are added or referenced so the signature is fully typed while preserving the existing defaults and return type Dict[str, Any].

acestep/core/generation/handler/generate_music_payload.py (1)
11-20: Add explicit type hints for tensor/progress parameters. This keeps the new mixin API self-documenting and aligned with the repo's typing expectations.
Suggested update
```diff
-from typing import Any, Dict
+from collections.abc import Callable
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch
@@
-    pred_wavs,
-    pred_latents_cpu,
+    pred_wavs: "torch.Tensor",
+    pred_latents_cpu: "torch.Tensor",
@@
-    progress: Any,
+    progress: Optional[Callable[..., Any]],
```

As per coding guidelines, "Add type hints for new/modified functions when practical in Python".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/core/generation/handler/generate_music_payload.py` around lines 11 - 20, The function _build_generate_music_success_payload should add explicit type hints for the tensor and progress parameters: annotate pred_wavs and pred_latents_cpu as torch.Tensor (or Optional[torch.Tensor] if they can be None) and annotate progress with a concrete type such as tqdm.std.tqdm or Optional[tqdm.std.tqdm] (or a Callable progress callback type if that matches usage); update the signature to import and use typing (Optional/Union) and torch/tqdm as needed and adjust any callers or imports accordingly to keep the module type-correct.
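A consolidated sketch of the annotated helper signature described above; the parameter list is partial and based only on the names discussed in this comment and the sequence diagram, not on the merged code.

```python
from collections.abc import Callable
from typing import Any, Dict, Optional

import torch


class GenerateMusicPayloadMixin:
    def _build_generate_music_success_payload(
        self,
        outputs: Dict[str, Any],  # assumed from the sequence diagram; other params omitted
        pred_wavs: torch.Tensor,
        pred_latents_cpu: torch.Tensor,
        progress: Optional[Callable[..., Any]],
    ) -> Dict[str, Any]:
        """Assemble the success payload (typed-signature sketch only)."""
        raise NotImplementedError
```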
daa0e17 to ca3fd1b
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/core/generation/handler/generate_music_decode.py`:
- Around line 154-159: The VAE is moved to CPU when vae_cpu is true (see the
vae_cpu branch that sets vae_device = next(self.vae.parameters()).device,
self.vae = self.vae.cpu(), pred_latents_for_decode =
pred_latents_for_decode.cpu(), and calls self._empty_cache()), but if the
subsequent decode raises the original device restoration (which uses vae_device)
will be skipped; wrap the decode logic in a try/finally: record the original
device via vae_device before moving the VAE, move to CPU as now, run the decode
in the try block, and in the finally always restore self.vae back to vae_device
(and move pred_latents_for_decode back if needed) and call self._empty_cache()
so the VAE is returned to its original device even on exceptions.
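In code, the requested pattern is roughly the following. This is a sketch using the attribute names from the comment (`self.vae`, `pred_latents_for_decode`, `self._empty_cache`); the wrapping function name is hypothetical and not part of the codebase.

```python
import torch


def _decode_on_cpu_with_restore(self, pred_latents_for_decode: torch.Tensor) -> torch.Tensor:
    """Sketch of the requested pattern: offload, decode, and always restore the VAE."""
    vae_device = next(self.vae.parameters()).device  # remember the original device first
    self.vae = self.vae.cpu()
    pred_latents_for_decode = pred_latents_for_decode.cpu()
    self._empty_cache()
    try:
        decoder_output = self.vae.decode(pred_latents_for_decode)
        return decoder_output.sample
    finally:
        # Runs on success and on exception, so the VAE always returns to its device.
        self.vae = self.vae.to(vae_device)
        self._empty_cache()
```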
Actionable comments posted: 4
| """Tests for extracted ``generate_music`` decode helper mixin behavior.""" | ||
|
|
||
| import importlib.util | ||
| import types | ||
| import sys | ||
| import unittest | ||
| from contextlib import contextmanager | ||
| from pathlib import Path | ||
| from unittest.mock import patch | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| def _load_generate_music_decode_module(): | ||
| """Load ``generate_music_decode.py`` from disk and return its module object. | ||
|
|
||
| Raises ``FileNotFoundError`` or ``ImportError`` when loading fails. | ||
| """ | ||
| repo_root = Path(__file__).resolve().parents[4] | ||
| if str(repo_root) not in sys.path: | ||
| sys.path.insert(0, str(repo_root)) | ||
| package_paths = { | ||
| "acestep": repo_root / "acestep", | ||
| "acestep.core": repo_root / "acestep" / "core", | ||
| "acestep.core.generation": repo_root / "acestep" / "core" / "generation", | ||
| "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler", | ||
| } | ||
| for package_name, package_path in package_paths.items(): | ||
| if package_name in sys.modules: | ||
| continue | ||
| package_module = types.ModuleType(package_name) | ||
| package_module.__path__ = [str(package_path)] | ||
| sys.modules[package_name] = package_module | ||
| module_path = Path(__file__).with_name("generate_music_decode.py") | ||
| spec = importlib.util.spec_from_file_location( | ||
| "acestep.core.generation.handler.generate_music_decode", | ||
| module_path, | ||
| ) | ||
| module = importlib.util.module_from_spec(spec) | ||
| assert spec.loader is not None | ||
| spec.loader.exec_module(module) | ||
| return module | ||
|
|
||
|
|
||
| GENERATE_MUSIC_DECODE_MODULE = _load_generate_music_decode_module() | ||
| GenerateMusicDecodeMixin = GENERATE_MUSIC_DECODE_MODULE.GenerateMusicDecodeMixin | ||
|
|
||
|
|
||
| class _FakeDecodeOutput: | ||
| """Minimal VAE decode output container exposing ``sample`` attribute.""" | ||
|
|
||
| def __init__(self, sample: torch.Tensor): | ||
| """Store decoded sample tensor for mixin decode flow.""" | ||
| self.sample = sample | ||
|
|
||
|
|
||
| class _FakeVae: | ||
| """Minimal VAE stand-in with dtype, decode, and parameter iteration hooks.""" | ||
|
|
||
| def __init__(self): | ||
| """Initialize deterministic dtype/device state for decode tests.""" | ||
| self.dtype = torch.float32 | ||
| self._param = torch.nn.Parameter(torch.zeros(1)) | ||
|
|
||
| def decode(self, latents: torch.Tensor): | ||
| """Return deterministic decoded waveform output.""" | ||
| return _FakeDecodeOutput(torch.ones(latents.shape[0], 2, 8)) | ||
|
|
||
| def parameters(self): | ||
| """Yield one parameter so `.device` lookups remain valid.""" | ||
| yield self._param | ||
|
|
||
| def cpu(self): | ||
| """Return self for test-only CPU transfer calls.""" | ||
| return self | ||
|
|
||
| def to(self, *_args, **_kwargs): | ||
| """Return self for test-only device transfer calls.""" | ||
| return self | ||
|
|
||
|
|
||
| class _Host(GenerateMusicDecodeMixin): | ||
| """Minimal decode-mixin host exposing deterministic state for assertions.""" | ||
|
|
||
| def __init__(self): | ||
| """Initialize deterministic runtime state for decode tests.""" | ||
| self.current_offload_cost = 0.25 | ||
| self.debug_stats = False | ||
| self._last_diffusion_per_step_sec = None | ||
| self.estimate_calls = [] | ||
| self.progress_calls = [] | ||
| self.device = "cpu" | ||
| self.use_mlx_vae = True | ||
| self.mlx_vae = object() | ||
| self.vae = _FakeVae() | ||
|
|
||
| def _update_progress_estimate(self, **kwargs): | ||
| """Capture estimate updates for assertions.""" | ||
| self.estimate_calls.append(kwargs) | ||
|
|
||
| @contextmanager | ||
| def _load_model_context(self, _model_name): | ||
| """Provide no-op model context manager for decode tests.""" | ||
| yield | ||
|
|
||
| def _empty_cache(self): | ||
| """Provide no-op cache clear helper for decode tests.""" | ||
| return None | ||
|
|
||
| def _memory_allocated(self): | ||
| """Return deterministic allocated-memory value for debug logging.""" | ||
| return 0.0 | ||
|
|
||
| def _max_memory_allocated(self): | ||
| """Return deterministic max-memory value for debug logging.""" | ||
| return 0.0 | ||
|
|
||
| def _mlx_vae_decode(self, latents): | ||
| """Return deterministic decoded waveform for MLX decode branch.""" | ||
| _ = latents | ||
| return torch.ones(1, 2, 8) | ||
|
|
||
| def tiled_decode(self, latents): | ||
| """Return deterministic decoded waveform for tiled decode branch.""" | ||
| _ = latents | ||
| return torch.ones(1, 2, 8) | ||
|
|
||
|
|
||
| class GenerateMusicDecodeMixinTests(unittest.TestCase): | ||
| """Verify decode-state preparation and latent decode helper behavior.""" | ||
|
|
||
| def test_prepare_decode_state_updates_progress_estimates(self): | ||
| """It updates timing fields and progress estimate metadata for valid latents.""" | ||
| host = _Host() | ||
| outputs = { | ||
| "target_latents": torch.ones(1, 4, 3), | ||
| "time_costs": {"total_time_cost": 1.0, "diffusion_per_step_time_cost": 0.2}, | ||
| } | ||
| pred_latents, time_costs = host._prepare_generate_music_decode_state( | ||
| outputs=outputs, | ||
| infer_steps_for_progress=8, | ||
| actual_batch_size=1, | ||
| audio_duration=12.0, | ||
| latent_shift=0.0, | ||
| latent_rescale=1.0, | ||
| ) | ||
| self.assertEqual(tuple(pred_latents.shape), (1, 4, 3)) | ||
| self.assertEqual(time_costs["offload_time_cost"], 0.25) | ||
| self.assertEqual(host._last_diffusion_per_step_sec, 0.2) | ||
| self.assertEqual(host.estimate_calls[0]["infer_steps"], 8) | ||
|
|
||
| def test_prepare_decode_state_raises_for_nan_latents(self): | ||
| """It raises runtime error when diffusion latents contain NaN values.""" | ||
| host = _Host() | ||
| outputs = { | ||
| "target_latents": torch.tensor([[[float("nan")]]]), | ||
| "time_costs": {"total_time_cost": 1.0}, | ||
| } | ||
| with self.assertRaises(RuntimeError): | ||
| host._prepare_generate_music_decode_state( | ||
| outputs=outputs, | ||
| infer_steps_for_progress=8, | ||
| actual_batch_size=1, | ||
| audio_duration=None, | ||
| latent_shift=0.0, | ||
| latent_rescale=1.0, | ||
| ) | ||
|
|
||
| def test_decode_pred_latents_updates_decode_time_and_returns_cpu_latents(self): | ||
| """It decodes latents and updates decode timing metrics in time_costs.""" | ||
| host = _Host() | ||
| pred_latents = torch.ones(1, 4, 3) | ||
| time_costs = {"total_time_cost": 1.0} | ||
|
|
||
| def _progress(value, desc=None): | ||
| """Capture progress updates for assertions.""" | ||
| host.progress_calls.append((value, desc)) | ||
|
|
||
| with patch.object(GENERATE_MUSIC_DECODE_MODULE.time, "time", side_effect=[10.0, 11.5]): | ||
| pred_wavs, pred_latents_cpu, updated_costs = host._decode_generate_music_pred_latents( | ||
| pred_latents=pred_latents, | ||
| progress=_progress, | ||
| use_tiled_decode=False, | ||
| time_costs=time_costs, | ||
| ) | ||
|
|
||
| self.assertEqual(tuple(pred_wavs.shape), (1, 2, 8)) | ||
| self.assertEqual(pred_latents_cpu.device.type, "cpu") | ||
| self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6) | ||
| self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6) | ||
| self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6) | ||
| self.assertEqual(host.progress_calls[0][0], 0.8) | ||
|
|
||
| def test_decode_pred_latents_restores_vae_device_on_decode_error(self): | ||
| """It restores VAE device in the CPU-offload path even when decode raises.""" | ||
|
|
||
| class _FailingVae(_FakeVae): | ||
| """VAE double that raises during decode and records transfer calls.""" | ||
|
|
||
| def __init__(self): | ||
| """Initialize transfer call trackers for restoration assertions.""" | ||
| super().__init__() | ||
| self.cpu_calls = 0 | ||
| self.to_calls = [] | ||
|
|
||
| def decode(self, latents: torch.Tensor): | ||
| """Raise decode error to exercise restoration in finally branch.""" | ||
| _ = latents | ||
| raise RuntimeError("decode failed") | ||
|
|
||
| def cpu(self): | ||
| """Record explicit CPU transfer and return self.""" | ||
| self.cpu_calls += 1 | ||
| return self | ||
|
|
||
| def to(self, *args, **kwargs): | ||
| """Record restore transfer target and return self.""" | ||
| self.to_calls.append((args, kwargs)) | ||
| return self | ||
|
|
||
| class _FailingHost(_Host): | ||
| """Host variant that forces non-MLX VAE decode and tracks cache clears.""" | ||
|
|
||
| def __init__(self): | ||
| """Set non-MLX state so CPU offload path is exercised deterministically.""" | ||
| super().__init__() | ||
| self.use_mlx_vae = False | ||
| self.mlx_vae = None | ||
| self.vae = _FailingVae() | ||
| self.empty_cache_calls = 0 | ||
|
|
||
| def _empty_cache(self): | ||
| """Count cache-clear calls to verify finally cleanup runs.""" | ||
| self.empty_cache_calls += 1 | ||
|
|
||
| host = _FailingHost() | ||
| pred_latents = torch.ones(1, 4, 3) | ||
| time_costs = {"total_time_cost": 1.0} | ||
|
|
||
| with patch.dict(GENERATE_MUSIC_DECODE_MODULE.os.environ, {"ACESTEP_VAE_ON_CPU": "1"}, clear=False): | ||
| with self.assertRaisesRegex(RuntimeError, "decode failed"): | ||
| host._decode_generate_music_pred_latents( | ||
| pred_latents=pred_latents, | ||
| progress=None, | ||
| use_tiled_decode=False, | ||
| time_costs=time_costs, | ||
| ) | ||
|
|
||
| self.assertEqual(host.vae.cpu_calls, 1) | ||
| self.assertEqual(len(host.vae.to_calls), 1) | ||
| self.assertGreaterEqual(host.empty_cache_calls, 2) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Module exceeds 200 LOC cap.
This test module is ~255 lines, which is over the 200 LOC hard cap. Please split the fakes/loader helpers into a smaller helper module or shared fixture to bring it under the limit. As per coding guidelines: Target module size: optimal <= 150 LOC, hard cap 200 LOC.
🧰 Tools
🪛 Ruff (0.15.1)
[warning] 209-209: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/core/generation/handler/generate_music_decode_test.py` around lines 1
- 255, The test file is over the 200 LOC cap—extract the reusable test helpers
and fakes into a new helper module (e.g.,
tests/helpers/generate_music_fixtures.py) and import them into this test;
specifically move _load_generate_music_decode_module, _FakeDecodeOutput,
_FakeVae, _Host and any package-setup code (the package_paths/sys.path
manipulation) into that helper module, leaving only the test class
GenerateMusicDecodeMixinTests and its test methods plus minimal imports in
generate_music_decode_test.py; update generate_music_decode_test.py to import
GenerateMusicDecodeMixin (or the helper-provided fixtures) and reference the
same symbols so the tests run unchanged and the file size falls under the 200
LOC cap.
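One way the split could look is sketched below. The helper module path `tests/helpers/generate_music_fixtures.py` is the reviewer's example, not an existing file, and the import list mirrors the fakes defined in the current test module.

```python
# generate_music_decode_test.py after the split (sketch only).
import unittest

from tests.helpers.generate_music_fixtures import (  # hypothetical helper module
    GENERATE_MUSIC_DECODE_MODULE,
    GenerateMusicDecodeMixin,
    _FakeVae,
    _Host,
)


class GenerateMusicDecodeMixinTests(unittest.TestCase):
    """Same four test methods as today, now importing their fixtures."""
```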
```python
def _load_generate_music_decode_module():
    """Load ``generate_music_decode.py`` from disk and return its module object.

    Raises ``FileNotFoundError`` or ``ImportError`` when loading fails.
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    package_paths = {
        "acestep": repo_root / "acestep",
        "acestep.core": repo_root / "acestep" / "core",
        "acestep.core.generation": repo_root / "acestep" / "core" / "generation",
        "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
    }
    for package_name, package_path in package_paths.items():
        if package_name in sys.modules:
            continue
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_path)]
        sys.modules[package_name] = package_module
    module_path = Path(__file__).with_name("generate_music_decode.py")
    spec = importlib.util.spec_from_file_location(
        "acestep.core.generation.handler.generate_music_decode",
        module_path,
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module
```
🧩 Analysis chain
🏁 Script executed:
```shell
rg -n -C2 --type=py -g '*_test.py' -e '^\s*(import|from)\s+acestep\b'
```

Repository: ace-step/ACE-Step-1.5
Length of output: 15331

🏁 Script executed:

```shell
# Check the actual test file to see if tearDownModule already exists and how the function is used
cat -n acestep/core/generation/handler/generate_music_decode_test.py
```

Repository: ace-step/ACE-Step-1.5
Length of output: 11481

🏁 Script executed:

```shell
# Also check if there's any cleanup in conftest.py or other test setup files
find . -name "conftest.py" -o -name "test_*.py" | head -20
```

Repository: ace-step/ACE-Step-1.5
Length of output: 47
Clean up stub packages and sys.path modifications after module tests complete.
_load_generate_music_decode_module() injects package stubs and modifies sys.path at module import time (line 45) without cleanup. In a shared unittest process, stub acestep packages can mask the real package for other tests that follow. Confirmed: 20+ other test files import from acestep, making contamination possible.
Add module-level tracking lists and a tearDownModule() function to remove injected entries:
Cleanup pattern

```diff
+_TEMP_SYS_MODULES = []
+_TEMP_SYS_PATHS = []
+
 def _load_generate_music_decode_module():
     """Load ``generate_music_decode.py`` from disk and return its module object.
@@
     repo_root = Path(__file__).resolve().parents[4]
     if str(repo_root) not in sys.path:
         sys.path.insert(0, str(repo_root))
+        _TEMP_SYS_PATHS.append(str(repo_root))
@@
     for package_name, package_path in package_paths.items():
         if package_name in sys.modules:
             continue
         package_module = types.ModuleType(package_name)
         package_module.__path__ = [str(package_path)]
         sys.modules[package_name] = package_module
+        _TEMP_SYS_MODULES.append(package_name)
@@
     spec.loader.exec_module(module)
     return module
+
+
+def tearDownModule():
+    for name in _TEMP_SYS_MODULES:
+        sys.modules.pop(name, None)
+    for path in _TEMP_SYS_PATHS:
+        if path in sys.path:
+            sys.path.remove(path)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
_TEMP_SYS_MODULES = []
_TEMP_SYS_PATHS = []


def _load_generate_music_decode_module():
    """Load ``generate_music_decode.py`` from disk and return its module object.

    Raises ``FileNotFoundError`` or ``ImportError`` when loading fails.
    """
    repo_root = Path(__file__).resolve().parents[4]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
        _TEMP_SYS_PATHS.append(str(repo_root))
    package_paths = {
        "acestep": repo_root / "acestep",
        "acestep.core": repo_root / "acestep" / "core",
        "acestep.core.generation": repo_root / "acestep" / "core" / "generation",
        "acestep.core.generation.handler": repo_root / "acestep" / "core" / "generation" / "handler",
    }
    for package_name, package_path in package_paths.items():
        if package_name in sys.modules:
            continue
        package_module = types.ModuleType(package_name)
        package_module.__path__ = [str(package_path)]
        sys.modules[package_name] = package_module
        _TEMP_SYS_MODULES.append(package_name)
    module_path = Path(__file__).with_name("generate_music_decode.py")
    spec = importlib.util.spec_from_file_location(
        "acestep.core.generation.handler.generate_music_decode",
        module_path,
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def tearDownModule():
    for name in _TEMP_SYS_MODULES:
        sys.modules.pop(name, None)
    for path in _TEMP_SYS_PATHS:
        if path in sys.path:
            sys.path.remove(path)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/core/generation/handler/generate_music_decode_test.py` around lines
14 - 42, The helper _load_generate_music_decode_module currently injects stub
packages into sys.modules and prepends repo_root to sys.path but never cleans
them up; add module-level trackers (e.g., a list inserted_sys_path and a list
injected_packages) and record the repo_root string and each package_name you add
from package_paths when creating package_module, then implement tearDownModule()
to undo those changes: remove the repo_root entry from sys.path if it was
inserted, and for each package_name in injected_packages remove it from
sys.modules (only if the value is the stub you created), ensuring you don't
disturb real packages; reference _load_generate_music_decode_module,
package_paths, sys.path, sys.modules, and implement tearDownModule to perform
the cleanup after tests.
| """Decode/validation helpers for ``generate_music`` orchestration.""" | ||
|
|
||
| import os | ||
| import time | ||
| from typing import Any, Dict, Optional, Tuple | ||
|
|
||
| import torch | ||
| from loguru import logger | ||
|
|
||
| from acestep.gpu_config import get_effective_free_vram_gb | ||
|
|
||
|
|
||
| class GenerateMusicDecodeMixin: | ||
| """Validate generated latents and decode them into waveform tensors.""" | ||
|
|
||
| def _prepare_generate_music_decode_state( | ||
| self, | ||
| outputs: Dict[str, Any], | ||
| infer_steps_for_progress: int, | ||
| actual_batch_size: int, | ||
| audio_duration: Optional[float], | ||
| latent_shift: float, | ||
| latent_rescale: float, | ||
| ) -> Tuple[torch.Tensor, Dict[str, Any]]: | ||
| """Collect decode inputs and validate raw diffusion latents. | ||
|
|
||
| Args: | ||
| outputs: ``service_generate`` output payload containing target latents and timings. | ||
| infer_steps_for_progress: Effective diffusion step count for estimates. | ||
| actual_batch_size: Effective generation batch size. | ||
| audio_duration: Optional generation duration in seconds. | ||
| latent_shift: Additive latent post-processing shift. | ||
| latent_rescale: Multiplicative latent post-processing scale. | ||
|
|
||
| Returns: | ||
| Tuple containing validated ``pred_latents`` and mutable ``time_costs``. | ||
|
|
||
| Raises: | ||
| RuntimeError: If latents contain NaN/Inf values or collapse to all zeros. | ||
| """ | ||
| logger.info("[generate_music] Model generation completed. Decoding latents...") | ||
| pred_latents = outputs["target_latents"] | ||
| time_costs = outputs["time_costs"] | ||
| time_costs["offload_time_cost"] = self.current_offload_cost | ||
|
|
||
| per_step = time_costs.get("diffusion_per_step_time_cost") | ||
| if isinstance(per_step, (int, float)) and per_step > 0: | ||
| self._last_diffusion_per_step_sec = float(per_step) | ||
| self._update_progress_estimate( | ||
| per_step_sec=float(per_step), | ||
| infer_steps=infer_steps_for_progress, | ||
| batch_size=actual_batch_size, | ||
| duration_sec=audio_duration if audio_duration and audio_duration > 0 else None, | ||
| ) | ||
|
|
||
| if self.debug_stats: | ||
| logger.debug( | ||
| f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} " | ||
| f"{pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} " | ||
| f"{pred_latents.std()=}" | ||
| ) | ||
| else: | ||
| logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype}") | ||
| logger.debug(f"[generate_music] time_costs: {time_costs}") | ||
|
|
||
| if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any(): | ||
| raise RuntimeError( | ||
| "Generation produced NaN or Inf latents. " | ||
| "This usually indicates a checkpoint/config mismatch " | ||
| "or unsupported quantization/backend combination. " | ||
| "Try running with --backend pt or verify your model checkpoints match this release." | ||
| ) | ||
| if pred_latents.numel() > 0 and pred_latents.abs().sum() == 0: | ||
| raise RuntimeError( | ||
| "Generation produced zero latents. " | ||
| "This usually indicates a checkpoint/config mismatch or unsupported setup." | ||
| ) | ||
| if latent_shift != 0.0 or latent_rescale != 1.0: | ||
| logger.info( | ||
| f"[generate_music] Applying latent post-processing: shift={latent_shift}, " | ||
| f"rescale={latent_rescale}" | ||
| ) | ||
| if self.debug_stats: | ||
| logger.debug( | ||
| f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, " | ||
| f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, " | ||
| f"std={pred_latents.std():.4f}" | ||
| ) | ||
| pred_latents = pred_latents * latent_rescale + latent_shift | ||
| if self.debug_stats: | ||
| logger.debug( | ||
| f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, " | ||
| f"max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, " | ||
| f"std={pred_latents.std():.4f}" | ||
| ) | ||
| return pred_latents, time_costs | ||
|
|
||
| def _decode_generate_music_pred_latents( | ||
| self, | ||
| pred_latents: torch.Tensor, | ||
| progress: Any, | ||
| use_tiled_decode: bool, | ||
| time_costs: Dict[str, Any], | ||
| ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: | ||
| """Decode predicted latents and update decode timing metrics. | ||
|
|
||
| Args: | ||
| pred_latents: Predicted latent tensor shaped ``[batch, frames, dim]``. | ||
| progress: Optional progress callback. | ||
| use_tiled_decode: Whether tiled VAE decode should be used. | ||
| time_costs: Mutable time-cost payload from service generation. | ||
|
|
||
| Returns: | ||
| Tuple of decoded waveforms, CPU latents, and updated time-cost payload. | ||
| """ | ||
| if progress: | ||
| progress(0.8, desc="Decoding audio...") | ||
| logger.info("[generate_music] Decoding latents with VAE...") | ||
| start_time = time.time() | ||
| with torch.inference_mode(): | ||
| with self._load_model_context("vae"): | ||
| pred_latents_cpu = pred_latents.detach().cpu() | ||
| pred_latents_for_decode = pred_latents.transpose(1, 2).contiguous().to(self.vae.dtype) | ||
| del pred_latents | ||
| self._empty_cache() | ||
|
|
||
| logger.debug( | ||
| "[generate_music] Before VAE decode: " | ||
| f"allocated={self._memory_allocated()/1024**3:.2f}GB, " | ||
| f"max={self._max_memory_allocated()/1024**3:.2f}GB" | ||
| ) | ||
| using_mlx_vae = self.use_mlx_vae and self.mlx_vae is not None | ||
| vae_cpu = False | ||
| vae_device = None | ||
| if not using_mlx_vae: | ||
| vae_cpu = os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes") | ||
| if not vae_cpu: | ||
| if self.device == "mps": | ||
| logger.info( | ||
| "[generate_music] MPS device: skipping VRAM check " | ||
| "(unified memory), keeping VAE on MPS" | ||
| ) | ||
| else: | ||
| effective_free = get_effective_free_vram_gb() | ||
| logger.info( | ||
| "[generate_music] Effective free VRAM before VAE decode: " | ||
| f"{effective_free:.2f} GB" | ||
| ) | ||
| if effective_free < 0.5: | ||
| logger.warning( | ||
| "[generate_music] Only " | ||
| f"{effective_free:.2f} GB free VRAM; auto-enabling CPU VAE decode" | ||
| ) | ||
| vae_cpu = True | ||
| if vae_cpu: | ||
| logger.info("[generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...") | ||
| vae_device = next(self.vae.parameters()).device | ||
| self.vae = self.vae.cpu() | ||
| pred_latents_for_decode = pred_latents_for_decode.cpu() | ||
| self._empty_cache() | ||
| try: | ||
| if use_tiled_decode: | ||
| logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") | ||
| pred_wavs = self.tiled_decode(pred_latents_for_decode) | ||
| elif using_mlx_vae: | ||
| try: | ||
| pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) | ||
| except Exception as exc: | ||
| logger.warning( | ||
| f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch" | ||
| ) | ||
| decoder_output = self.vae.decode(pred_latents_for_decode) | ||
| pred_wavs = decoder_output.sample | ||
| del decoder_output | ||
| else: | ||
| decoder_output = self.vae.decode(pred_latents_for_decode) | ||
| pred_wavs = decoder_output.sample | ||
| del decoder_output | ||
| finally: | ||
| if vae_cpu and vae_device is not None: | ||
| logger.info("[generate_music] Restoring VAE to original device after CPU decode path...") | ||
| self.vae = self.vae.to(vae_device) | ||
| pred_latents_for_decode = pred_latents_for_decode.to(vae_device) | ||
| self._empty_cache() | ||
| logger.debug( | ||
| "[generate_music] After VAE decode: " | ||
| f"allocated={self._memory_allocated()/1024**3:.2f}GB, " | ||
| f"max={self._max_memory_allocated()/1024**3:.2f}GB" | ||
| ) | ||
| del pred_latents_for_decode | ||
| if pred_wavs.dtype != torch.float32: | ||
| pred_wavs = pred_wavs.float() | ||
| peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True) | ||
| if torch.any(peak > 1.0): | ||
| pred_wavs = pred_wavs / peak.clamp(min=1.0) | ||
| self._empty_cache() | ||
| end_time = time.time() | ||
| time_costs["vae_decode_time_cost"] = end_time - start_time | ||
| time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"] | ||
| time_costs["offload_time_cost"] = self.current_offload_cost | ||
| return pred_wavs, pred_latents_cpu, time_costs |
Module exceeds 200 LOC cap.
This file is ~201 lines, which crosses the 200 LOC hard cap. Consider extracting CPU/VRAM selection and post‑decode normalization into helpers to bring the module under the limit. As per coding guidelines: Target module size: optimal <= 150 LOC, hard cap 200 LOC.
🧰 Tools
🪛 Ruff (0.15.1)
[warning] 67-72: Avoid specifying long messages outside the exception class (TRY003)
[warning] 74-77: Avoid specifying long messages outside the exception class (TRY003)
[warning] 168-168: Do not catch blind exception: Exception (BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/core/generation/handler/generate_music_decode.py` around lines 1 -
201, The module is over the 200-LOC cap due to a large
_decode_generate_music_pred_latents implementation; extract the CPU/VRAM
selection and the post-decode normalization into two helper methods to reduce
size. Specifically, move the logic that determines using_mlx_vae / vae_cpu /
vae_device and the code that moves the VAE to CPU and restores it into a new
private helper (e.g., _select_vae_device_for_decode(self,
pred_latents_for_decode) -> Tuple[pred_latents_for_decode, vae_cpu, vae_device])
and move the peak clamping + dtype coercion into another helper (e.g.,
_post_decode_normalize(self, pred_wavs) -> pred_wavs), update
_decode_generate_music_pred_latents to call these helpers, preserve all side
effects (self.vae device moves, calls to self._empty_cache(), and updates to
time_costs) and keep existing logging and exception fallback behavior intact.
Ensure any new helpers are placed in the same class and imported symbols remain
unchanged so callers and tests need no updates.
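A sketch of the two helpers this comment proposes is shown below. The method names follow the reviewer's suggestion, the bodies are simplified (the MLX check and the free-VRAM auto-detection are omitted), and this is not the merged implementation.

```python
import os
from typing import Optional, Tuple

import torch


class GenerateMusicDecodeMixin:  # the helpers would live on the existing mixin
    def _select_vae_device_for_decode(
        self, pred_latents_for_decode: torch.Tensor
    ) -> Tuple[torch.Tensor, bool, Optional[torch.device]]:
        """Decide whether to decode on CPU and move the VAE/latents there (sketch)."""
        vae_cpu = os.environ.get("ACESTEP_VAE_ON_CPU", "0").lower() in ("1", "true", "yes")
        vae_device = None
        if vae_cpu:
            vae_device = next(self.vae.parameters()).device
            self.vae = self.vae.cpu()
            pred_latents_for_decode = pred_latents_for_decode.cpu()
            self._empty_cache()
        return pred_latents_for_decode, vae_cpu, vae_device

    def _post_decode_normalize(self, pred_wavs: torch.Tensor) -> torch.Tensor:
        """Coerce decoded audio to float32 and clamp peaks above 1.0 (sketch)."""
        if pred_wavs.dtype != torch.float32:
            pred_wavs = pred_wavs.float()
        peak = pred_wavs.abs().amax(dim=[1, 2], keepdim=True)
        if torch.any(peak > 1.0):
            pred_wavs = pred_wavs / peak.clamp(min=1.0)
        return pred_wavs
```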
```python
                    if vae_cpu and vae_device is not None:
                        logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
                        self.vae = self.vae.to(vae_device)
                        pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
                        self._empty_cache()
```
Don’t move pred_latents_for_decode back to GPU in the CPU-offload cleanup.
Line 183 re-allocates the full latent tensor on GPU after a low‑VRAM offload, which can OOM and even mask the original decode error. The tensor is deleted immediately after, so the move is unnecessary—only the VAE needs restoring.
✅ Proposed fix

```diff
                     if vae_cpu and vae_device is not None:
                         logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
                         self.vae = self.vae.to(vae_device)
-                        pred_latents_for_decode = pred_latents_for_decode.to(vae_device)
                         self._empty_cache()
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
                    if vae_cpu and vae_device is not None:
                        logger.info("[generate_music] Restoring VAE to original device after CPU decode path...")
                        self.vae = self.vae.to(vae_device)
                        self._empty_cache()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/core/generation/handler/generate_music_decode.py` around lines 180 -
184, The code currently moves pred_latents_for_decode back to the GPU during the
CPU-offload cleanup which can OOM; in the restore block inside the CPU decode
path (the section checking vae_cpu and vae_device), only move self.vae back to
vae_device and do NOT call pred_latents_for_decode.to(vae_device). Remove the
line that re-allocates pred_latents_for_decode on the VAE restore path
(referencing pred_latents_for_decode and self.vae) so the large latent tensor
remains on CPU until it is deleted, then call self._empty_cache() as before.
Summary
- Extract the `generate_music` orchestration from `acestep/handler.py` into focused mixins.
- Preserve `AceStepHandler.generate_music` behavior by composing the new mixins in handler inheritance.

Changes

- New `generate_music` orchestration mixin module.
- Updated `AceStepHandler` mixin composition.
- Removed the previous `generate_music` implementation from `acestep/handler.py`.

Validation

- `python acestep/core/generation/handler/generate_music_test.py`
- `python acestep/core/generation/handler/generate_music_decode_test.py`
- `python acestep/core/generation/handler/generate_music_payload_test.py`
- `python acestep/core/generation/handler/mlx_dit_init_test.py`
- `python acestep/core/generation/handler/mlx_vae_init_test.py`
- `python acestep/core/generation/handler/mlx_vae_native_test.py`

Summary by CodeRabbit
New Features
Improved Decoding
Tests
Refactor