Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions acestep/ui/gradio/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["timesig_auto"],
generation_section["vocal_lang_auto"],
generation_section["duration_auto"],
# State-leakage fix: clear stale values on mode switch (indices 42-43)
generation_section["text2music_audio_code_string"],
generation_section["src_audio"],
]
)

Expand Down Expand Up @@ -732,6 +735,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["timesig_auto"],
generation_section["vocal_lang_auto"],
generation_section["duration_auto"],
# State-leakage fix: clear stale values on mode switch (indices 42-43)
generation_section["text2music_audio_code_string"],
generation_section["src_audio"],
]
for btn_idx in range(1, 9):
results_section[f"send_to_remix_btn_{btn_idx}"].click(
Expand Down
20 changes: 18 additions & 2 deletions acestep/ui/gradio/events/generation/mode_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C
previous_mode: The mode that was active before this switch.

Returns:
Tuple of 42 gr.update objects matching the standard mode-change
Tuple of 44 gr.update objects matching the standard mode-change
output list (see event wiring in events/__init__.py).
"""
task_type = MODE_TO_TASK_TYPE.get(mode, "text2music")
Expand Down Expand Up @@ -126,6 +126,20 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C
auto_vocal_lang_update = gr.update()
auto_duration_update = gr.update()

# Clear stale audio codes when leaving Custom mode to prevent
# them from leaking into Remix/other modes (state-leakage bug fix).
if is_custom:
audio_codes_update = gr.update(visible=True)
else:
audio_codes_update = gr.update(value="", visible=False)

# Clear src_audio when entering a mode that doesn't use it
# (Custom, Simple) to prevent stale audio from leaking.
if show_src_audio:
src_audio_update = gr.update()
else:
src_audio_update = gr.update(value=None)

return (
gr.update(visible=show_simple), # 0: simple_mode_group
gr.update(visible=show_custom_group), # 1: custom_mode_group
Expand Down Expand Up @@ -169,6 +183,8 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C
auto_timesig_update, # 39: timesig_auto
auto_vocal_lang_update, # 40: vocal_lang_auto
auto_duration_update, # 41: duration_auto
audio_codes_update, # 42: text2music_audio_code_string
src_audio_update, # 43: src_audio
)


Expand Down Expand Up @@ -301,7 +317,7 @@ def handle_generation_mode_change(mode: str, previous_mode: str, llm_handler=Non
llm_handler: Optional LLM handler.

Returns:
Tuple of 42 updates for UI components.
Tuple of 44 updates for UI components.
"""
return compute_mode_ui_updates(mode, llm_handler, previous_mode=previous_mode)

Expand Down
104 changes: 104 additions & 0 deletions acestep/ui/gradio/events/generation/mode_ui_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Unit tests for mode_ui state-clearing behavior on mode switch.

Verifies that compute_mode_ui_updates correctly clears stale
text2music_audio_code_string and src_audio values when switching
between modes, preventing the state-leakage noise bug.
"""

import unittest
from types import SimpleNamespace

try:
from acestep.ui.gradio.events.generation.mode_ui import compute_mode_ui_updates
_IMPORT_ERROR = None
except Exception as exc: # pragma: no cover - environment dependency guard
compute_mode_ui_updates = None
_IMPORT_ERROR = exc

# Output indices for the two new state-clearing outputs
_IDX_AUDIO_CODES = 42
_IDX_SRC_AUDIO = 43
_EXPECTED_TUPLE_LENGTH = 44


@unittest.skipIf(compute_mode_ui_updates is None,
f"compute_mode_ui_updates import unavailable: {_IMPORT_ERROR}")
class ModeUiStateClearingTests(unittest.TestCase):
"""Tests that mode switches clear stale UI state to prevent noise."""

def test_tuple_length(self):
"""compute_mode_ui_updates should return exactly 44 elements."""
result = compute_mode_ui_updates("Custom")
self.assertEqual(len(result), _EXPECTED_TUPLE_LENGTH)

def test_custom_mode_preserves_audio_codes(self):
"""In Custom mode, audio_codes textbox should be visible but not cleared."""
result = compute_mode_ui_updates("Custom")
codes_update = result[_IDX_AUDIO_CODES]
# Should only set visibility, not clear the value
self.assertTrue(codes_update.get("visible"))
self.assertNotIn("value", codes_update)

def test_remix_mode_clears_audio_codes(self):
"""Switching to Remix should clear the audio_codes textbox value."""
result = compute_mode_ui_updates("Remix", previous_mode="Custom")
codes_update = result[_IDX_AUDIO_CODES]
self.assertEqual(codes_update.get("value"), "")
self.assertFalse(codes_update.get("visible"))

def test_simple_mode_clears_audio_codes(self):
"""Switching to Simple should clear the audio_codes textbox value."""
result = compute_mode_ui_updates("Simple", previous_mode="Custom")
codes_update = result[_IDX_AUDIO_CODES]
self.assertEqual(codes_update.get("value"), "")

def test_repaint_mode_clears_audio_codes(self):
"""Switching to Repaint should clear the audio_codes textbox value."""
result = compute_mode_ui_updates("Repaint", previous_mode="Custom")
codes_update = result[_IDX_AUDIO_CODES]
self.assertEqual(codes_update.get("value"), "")

def test_custom_mode_clears_src_audio(self):
"""Switching to Custom should clear src_audio (no source audio needed)."""
result = compute_mode_ui_updates("Custom", previous_mode="Remix")
src_update = result[_IDX_SRC_AUDIO]
self.assertIsNone(src_update.get("value"))

def test_simple_mode_clears_src_audio(self):
"""Switching to Simple should clear src_audio."""
result = compute_mode_ui_updates("Simple", previous_mode="Remix")
src_update = result[_IDX_SRC_AUDIO]
self.assertIsNone(src_update.get("value"))

def test_remix_mode_preserves_src_audio(self):
"""In Remix mode, src_audio should not be cleared (it's needed)."""
result = compute_mode_ui_updates("Remix")
src_update = result[_IDX_SRC_AUDIO]
# Should be a no-op update (no value key)
self.assertNotIn("value", src_update)

def test_repaint_mode_preserves_src_audio(self):
"""In Repaint mode, src_audio should not be cleared (it's needed)."""
result = compute_mode_ui_updates("Repaint")
src_update = result[_IDX_SRC_AUDIO]
self.assertNotIn("value", src_update)

def test_round_trip_remix_to_custom_clears_both(self):
"""Switching Remix -> Custom should clear both codes and src_audio."""
result = compute_mode_ui_updates("Custom", previous_mode="Remix")
codes_update = result[_IDX_AUDIO_CODES]
src_update = result[_IDX_SRC_AUDIO]
# Custom mode should not clear codes (it uses them)
self.assertTrue(codes_update.get("visible"))
# But src_audio should be cleared
self.assertIsNone(src_update.get("value"))

def test_round_trip_custom_to_remix_clears_codes(self):
"""Switching Custom -> Remix should clear stale audio codes."""
result = compute_mode_ui_updates("Remix", previous_mode="Custom")
codes_update = result[_IDX_AUDIO_CODES]
self.assertEqual(codes_update.get("value"), "")


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions acestep/ui/gradio/events/results/audio_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def send_audio_to_remix(audio_file, lm_metadata, current_lyrics, current_caption
llm_handler: Optional LLM handler.

Returns:
46-tuple of Gradio updates (4 data + 42 mode-UI).
48-tuple of Gradio updates (4 data + 44 mode-UI).
"""
n_outputs = 46
n_outputs = 48
if audio_file is None:
return (gr.skip(),) * n_outputs

Expand Down Expand Up @@ -103,9 +103,9 @@ def send_audio_to_repaint(audio_file, lm_metadata, current_lyrics, current_capti
llm_handler: Optional LLM handler.

Returns:
46-tuple of Gradio updates (4 data + 42 mode-UI).
48-tuple of Gradio updates (4 data + 44 mode-UI).
"""
n_outputs = 46
n_outputs = 48
if audio_file is None:
return (gr.skip(),) * n_outputs

Expand Down
6 changes: 6 additions & 0 deletions acestep/ui/gradio/events/results/generation_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def generate_with_progress(
if task_type == "text2music":
src_audio = None

# Defensive guard: cover/repaint/extract/lego tasks should never use
# stale audio codes from the text2music_audio_code_string textbox.
# Only text2music (Custom mode) with thinking disabled should pass codes.
if task_type != "text2music":
text2music_audio_code_string = ""

gen_params = GenerationParams(
task_type=task_type,
instruction=instruction_display_gen,
Expand Down