From 318b36ea3196443eb68f9789fd969c8052ab3eb3 Mon Sep 17 00:00:00 2001 From: chuxij Date: Thu, 19 Feb 2026 05:45:58 +0000 Subject: [PATCH] fix: pass project_root instead of checkpoint dir to initialize_service MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit init_service_wrapper was passing the checkpoint dropdown value (the full checkpoints directory path, e.g. '/checkpoints') as the project_root argument to initialize_service(). This caused initialize_service to append 'checkpoints' again, resulting in '/checkpoints/checkpoints' — triggering unnecessary model re-downloads into the nested directory. Fix: derive project_root from __file__ (consistent with the LLM init path that was already correct) instead of using the checkpoint dropdown value. Adds unit tests verifying the project_root is never the checkpoints dir. --- .../gradio/events/generation/service_init.py | 15 +- .../events/generation/service_init_test.py | 129 ++++++++++++++++++ 2 files changed, 139 insertions(+), 5 deletions(-) create mode 100644 acestep/ui/gradio/events/generation/service_init_test.py diff --git a/acestep/ui/gradio/events/generation/service_init.py b/acestep/ui/gradio/events/generation/service_init.py index 6fc0c4da..42d0abbb 100644 --- a/acestep/ui/gradio/events/generation/service_init.py +++ b/acestep/ui/gradio/events/generation/service_init.py @@ -77,18 +77,23 @@ def init_service_wrapper( f"(VRAM too low for KV cache), falling back to {backend}" ) + # Derive project_root from the checkpoint path (which is the checkpoints + # directory itself, e.g. "/checkpoints"). Passing it directly as + # project_root would cause initialize_service to append "checkpoints" again, + # resulting in "/checkpoints/checkpoints". + current_file = os.path.abspath(__file__) + # This file is in acestep/ui/gradio/events/generation/ + project_root = os.path.dirname(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(current_file)))))) + status, enable = dit_handler.initialize_service( - checkpoint, config_path, device, + project_root, config_path, device, use_flash_attention=use_flash_attention, compile_model=compile_model, offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu, quantization=quant_value, use_mlx_dit=mlx_dit, ) if init_llm: - current_file = os.path.abspath(__file__) - # This file is in acestep/ui/gradio/events/generation/ - project_root = os.path.dirname(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(current_file)))))) checkpoint_dir = os.path.join(project_root, "checkpoints") lm_status, lm_success = llm_handler.initialize( diff --git a/acestep/ui/gradio/events/generation/service_init_test.py b/acestep/ui/gradio/events/generation/service_init_test.py new file mode 100644 index 00000000..2d7837fc --- /dev/null +++ b/acestep/ui/gradio/events/generation/service_init_test.py @@ -0,0 +1,129 @@ +"""Unit tests for service_init.init_service_wrapper checkpoint path handling.""" + +import os +import unittest +from unittest.mock import MagicMock, patch + + +class InitServiceWrapperPathTests(unittest.TestCase): + """Verify init_service_wrapper passes project_root (not checkpoint dir) to initialize_service.""" + + def _import_module(self): + """Import service_init lazily to avoid heavy transitive imports.""" + from acestep.ui.gradio.events.generation import service_init + return service_init + + @patch("acestep.ui.gradio.events.generation.service_init.get_global_gpu_config") + def test_passes_project_root_not_checkpoint_dir(self, mock_gpu_config): + """init_service_wrapper must NOT pass the checkpoint dropdown value as project_root. + + The checkpoint dropdown returns the full checkpoints directory path + (e.g. ``/checkpoints``). Passing it directly as ``project_root`` + causes initialize_service to append ``checkpoints`` again, yielding + ``/checkpoints/checkpoints``. + """ + module = self._import_module() + + # Stub GPU config + mock_gpu_config.return_value = MagicMock( + available_lm_models=["acestep-5Hz-lm-1.7B"], + lm_backend_restriction=None, + tier="tier6", + gpu_memory_gb=24.0, + max_duration_with_lm=600, + max_duration_without_lm=600, + max_batch_size_with_lm=4, + max_batch_size_without_lm=8, + ) + + dit_handler = MagicMock() + dit_handler.initialize_service.return_value = ("ok", True) + dit_handler.model = MagicMock() + dit_handler.is_turbo_model.return_value = True + + llm_handler = MagicMock() + llm_handler.llm_initialized = False + + # Simulate the checkpoint dropdown value: full path to checkpoints dir + checkpoint_value = "/some/project/checkpoints" + + module.init_service_wrapper( + dit_handler, + llm_handler, + checkpoint_value, + "acestep-v15-turbo", + "cpu", + False, # init_llm + None, # lm_model_path + "vllm", # backend + False, # use_flash_attention + False, # offload_to_cpu + False, # offload_dit_to_cpu + False, # compile_model + False, # quantization + ) + + # The first positional arg to initialize_service must be the project root, + # NOT the checkpoints directory. + call_args = dit_handler.initialize_service.call_args + actual_project_root = call_args[0][0] + + # It should be computed from __file__, not from the checkpoint dropdown. + # Critically, it must NOT end with "checkpoints". + self.assertFalse( + actual_project_root.rstrip("/").endswith("checkpoints"), + f"project_root must not be the checkpoints dir, got: {actual_project_root}", + ) + + @patch("acestep.ui.gradio.events.generation.service_init.get_global_gpu_config") + def test_project_root_is_consistent_with_checkpoint_dir(self, mock_gpu_config): + """The project_root passed to initialize_service should be the parent of checkpoints.""" + module = self._import_module() + + mock_gpu_config.return_value = MagicMock( + available_lm_models=[], + lm_backend_restriction=None, + tier="tier6", + gpu_memory_gb=24.0, + max_duration_with_lm=600, + max_duration_without_lm=600, + max_batch_size_with_lm=4, + max_batch_size_without_lm=8, + ) + + dit_handler = MagicMock() + dit_handler.initialize_service.return_value = ("ok", True) + dit_handler.model = MagicMock() + dit_handler.is_turbo_model.return_value = True + + llm_handler = MagicMock() + llm_handler.llm_initialized = False + + module.init_service_wrapper( + dit_handler, + llm_handler, + "/any/path/checkpoints", # checkpoint dropdown value (unused now) + "acestep-v15-turbo", + "cpu", + False, None, "vllm", False, False, False, False, False, + ) + + call_args = dit_handler.initialize_service.call_args + actual_project_root = call_args[0][0] + + # The project_root + "checkpoints" should form a valid checkpoints path + expected_checkpoints = os.path.join(actual_project_root, "checkpoints") + self.assertTrue( + os.path.isabs(expected_checkpoints) or actual_project_root, + "project_root should be a meaningful path", + ) + # It should NOT contain double "checkpoints" + self.assertNotIn( + "checkpoints/checkpoints", + expected_checkpoints, + f"Double nesting detected: {expected_checkpoints}", + ) + + +if __name__ == "__main__": + unittest.main()