diff --git a/acestep/ui/gradio/events/__init__.py b/acestep/ui/gradio/events/__init__.py index 86502ae5..177b4b58 100644 --- a/acestep/ui/gradio/events/__init__.py +++ b/acestep/ui/gradio/events/__init__.py @@ -466,6 +466,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["timesig_auto"], generation_section["vocal_lang_auto"], generation_section["duration_auto"], + # State-leakage fix: clear stale values on mode switch (indices 42-43) + generation_section["text2music_audio_code_string"], + generation_section["src_audio"], ] ) @@ -732,6 +735,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["timesig_auto"], generation_section["vocal_lang_auto"], generation_section["duration_auto"], + # State-leakage fix: clear stale values on mode switch (indices 42-43) + generation_section["text2music_audio_code_string"], + generation_section["src_audio"], ] for btn_idx in range(1, 9): results_section[f"send_to_remix_btn_{btn_idx}"].click( diff --git a/acestep/ui/gradio/events/generation/mode_ui.py b/acestep/ui/gradio/events/generation/mode_ui.py index 981b8973..1b61bba4 100644 --- a/acestep/ui/gradio/events/generation/mode_ui.py +++ b/acestep/ui/gradio/events/generation/mode_ui.py @@ -26,7 +26,7 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C previous_mode: The mode that was active before this switch. Returns: - Tuple of 42 gr.update objects matching the standard mode-change + Tuple of 44 gr.update objects matching the standard mode-change output list (see event wiring in events/__init__.py). """ task_type = MODE_TO_TASK_TYPE.get(mode, "text2music") @@ -126,6 +126,20 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C auto_vocal_lang_update = gr.update() auto_duration_update = gr.update() + # Clear stale audio codes when leaving Custom mode to prevent + # them from leaking into Remix/other modes (state-leakage bug fix). + if is_custom: + audio_codes_update = gr.update(visible=True) + else: + audio_codes_update = gr.update(value="", visible=False) + + # Clear src_audio when entering a mode that doesn't use it + # (Custom, Simple) to prevent stale audio from leaking. + if show_src_audio: + src_audio_update = gr.update() + else: + src_audio_update = gr.update(value=None) + return ( gr.update(visible=show_simple), # 0: simple_mode_group gr.update(visible=show_custom_group), # 1: custom_mode_group @@ -169,6 +183,8 @@ def compute_mode_ui_updates(mode: str, llm_handler=None, previous_mode: str = "C auto_timesig_update, # 39: timesig_auto auto_vocal_lang_update, # 40: vocal_lang_auto auto_duration_update, # 41: duration_auto + audio_codes_update, # 42: text2music_audio_code_string + src_audio_update, # 43: src_audio ) @@ -301,7 +317,7 @@ def handle_generation_mode_change(mode: str, previous_mode: str, llm_handler=Non llm_handler: Optional LLM handler. Returns: - Tuple of 42 updates for UI components. + Tuple of 44 updates for UI components. """ return compute_mode_ui_updates(mode, llm_handler, previous_mode=previous_mode) diff --git a/acestep/ui/gradio/events/generation/mode_ui_test.py b/acestep/ui/gradio/events/generation/mode_ui_test.py new file mode 100644 index 00000000..94533328 --- /dev/null +++ b/acestep/ui/gradio/events/generation/mode_ui_test.py @@ -0,0 +1,104 @@ +"""Unit tests for mode_ui state-clearing behavior on mode switch. + +Verifies that compute_mode_ui_updates correctly clears stale +text2music_audio_code_string and src_audio values when switching +between modes, preventing the state-leakage noise bug. +""" + +import unittest +from types import SimpleNamespace + +try: + from acestep.ui.gradio.events.generation.mode_ui import compute_mode_ui_updates + _IMPORT_ERROR = None +except Exception as exc: # pragma: no cover - environment dependency guard + compute_mode_ui_updates = None + _IMPORT_ERROR = exc + +# Output indices for the two new state-clearing outputs +_IDX_AUDIO_CODES = 42 +_IDX_SRC_AUDIO = 43 +_EXPECTED_TUPLE_LENGTH = 44 + + +@unittest.skipIf(compute_mode_ui_updates is None, + f"compute_mode_ui_updates import unavailable: {_IMPORT_ERROR}") +class ModeUiStateClearingTests(unittest.TestCase): + """Tests that mode switches clear stale UI state to prevent noise.""" + + def test_tuple_length(self): + """compute_mode_ui_updates should return exactly 44 elements.""" + result = compute_mode_ui_updates("Custom") + self.assertEqual(len(result), _EXPECTED_TUPLE_LENGTH) + + def test_custom_mode_preserves_audio_codes(self): + """In Custom mode, audio_codes textbox should be visible but not cleared.""" + result = compute_mode_ui_updates("Custom") + codes_update = result[_IDX_AUDIO_CODES] + # Should only set visibility, not clear the value + self.assertTrue(codes_update.get("visible")) + self.assertNotIn("value", codes_update) + + def test_remix_mode_clears_audio_codes(self): + """Switching to Remix should clear the audio_codes textbox value.""" + result = compute_mode_ui_updates("Remix", previous_mode="Custom") + codes_update = result[_IDX_AUDIO_CODES] + self.assertEqual(codes_update.get("value"), "") + self.assertFalse(codes_update.get("visible")) + + def test_simple_mode_clears_audio_codes(self): + """Switching to Simple should clear the audio_codes textbox value.""" + result = compute_mode_ui_updates("Simple", previous_mode="Custom") + codes_update = result[_IDX_AUDIO_CODES] + self.assertEqual(codes_update.get("value"), "") + + def test_repaint_mode_clears_audio_codes(self): + """Switching to Repaint should clear the audio_codes textbox value.""" + result = compute_mode_ui_updates("Repaint", previous_mode="Custom") + codes_update = result[_IDX_AUDIO_CODES] + self.assertEqual(codes_update.get("value"), "") + + def test_custom_mode_clears_src_audio(self): + """Switching to Custom should clear src_audio (no source audio needed).""" + result = compute_mode_ui_updates("Custom", previous_mode="Remix") + src_update = result[_IDX_SRC_AUDIO] + self.assertIsNone(src_update.get("value")) + + def test_simple_mode_clears_src_audio(self): + """Switching to Simple should clear src_audio.""" + result = compute_mode_ui_updates("Simple", previous_mode="Remix") + src_update = result[_IDX_SRC_AUDIO] + self.assertIsNone(src_update.get("value")) + + def test_remix_mode_preserves_src_audio(self): + """In Remix mode, src_audio should not be cleared (it's needed).""" + result = compute_mode_ui_updates("Remix") + src_update = result[_IDX_SRC_AUDIO] + # Should be a no-op update (no value key) + self.assertNotIn("value", src_update) + + def test_repaint_mode_preserves_src_audio(self): + """In Repaint mode, src_audio should not be cleared (it's needed).""" + result = compute_mode_ui_updates("Repaint") + src_update = result[_IDX_SRC_AUDIO] + self.assertNotIn("value", src_update) + + def test_round_trip_remix_to_custom_clears_both(self): + """Switching Remix -> Custom should clear both codes and src_audio.""" + result = compute_mode_ui_updates("Custom", previous_mode="Remix") + codes_update = result[_IDX_AUDIO_CODES] + src_update = result[_IDX_SRC_AUDIO] + # Custom mode should not clear codes (it uses them) + self.assertTrue(codes_update.get("visible")) + # But src_audio should be cleared + self.assertIsNone(src_update.get("value")) + + def test_round_trip_custom_to_remix_clears_codes(self): + """Switching Custom -> Remix should clear stale audio codes.""" + result = compute_mode_ui_updates("Remix", previous_mode="Custom") + codes_update = result[_IDX_AUDIO_CODES] + self.assertEqual(codes_update.get("value"), "") + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/ui/gradio/events/results/audio_transfer.py b/acestep/ui/gradio/events/results/audio_transfer.py index eb7bffff..dcea1e31 100644 --- a/acestep/ui/gradio/events/results/audio_transfer.py +++ b/acestep/ui/gradio/events/results/audio_transfer.py @@ -73,9 +73,9 @@ def send_audio_to_remix(audio_file, lm_metadata, current_lyrics, current_caption llm_handler: Optional LLM handler. Returns: - 46-tuple of Gradio updates (4 data + 42 mode-UI). + 48-tuple of Gradio updates (4 data + 44 mode-UI). """ - n_outputs = 46 + n_outputs = 48 if audio_file is None: return (gr.skip(),) * n_outputs @@ -103,9 +103,9 @@ def send_audio_to_repaint(audio_file, lm_metadata, current_lyrics, current_capti llm_handler: Optional LLM handler. Returns: - 46-tuple of Gradio updates (4 data + 42 mode-UI). + 48-tuple of Gradio updates (4 data + 44 mode-UI). """ - n_outputs = 46 + n_outputs = 48 if audio_file is None: return (gr.skip(),) * n_outputs diff --git a/acestep/ui/gradio/events/results/generation_progress.py b/acestep/ui/gradio/events/results/generation_progress.py index 73dd739e..56e77ddb 100644 --- a/acestep/ui/gradio/events/results/generation_progress.py +++ b/acestep/ui/gradio/events/results/generation_progress.py @@ -93,6 +93,12 @@ def generate_with_progress( if task_type == "text2music": src_audio = None + # Defensive guard: cover/repaint/extract/lego tasks should never use + # stale audio codes from the text2music_audio_code_string textbox. + # Only text2music (Custom mode) with thinking disabled should pass codes. + if task_type != "text2music": + text2music_audio_code_string = "" + gen_params = GenerationParams( task_type=task_type, instruction=instruction_display_gen,