From c8c83ef823862e4b83ebf0dc0ef4fe0a5fce9876 Mon Sep 17 00:00:00 2001 From: Hami0095 Date: Sat, 24 Jan 2026 00:30:35 +0500 Subject: [PATCH] Refactor logging to follow Dependency Inversion Principle --- moshi/moshi/client_utils.py | 16 ++++++++++++++++ moshi/moshi/offline.py | 6 ++++-- moshi/moshi/server.py | 3 ++- moshi/moshi/utils/logging.py | 26 +++++++++++++++++++++++--- moshi/moshi/utils/logging_interface.py | 15 +++++++++++++++ 5 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 moshi/moshi/utils/logging_interface.py diff --git a/moshi/moshi/client_utils.py b/moshi/moshi/client_utils.py index 244b905..c038912 100644 --- a/moshi/moshi/client_utils.py +++ b/moshi/moshi/client_utils.py @@ -6,6 +6,8 @@ from dataclasses import dataclass import sys +from .utils.logging_interface import LoggingProvider +from .utils.logging import configure_logging def colorize(text, color): @@ -26,6 +28,20 @@ def make_log(level: str, msg: str) -> str: return prefix + " " + msg +class ClientLoggingProvider(LoggingProvider): + """High-level implementation of the LoggingProvider interface.""" + def colorize(self, text: str, color: str) -> str: + return colorize(text, color) + + def make_log(self, level: str, msg: str) -> str: + return make_log(level, msg) + + +def setup_client_logging(): + """Wires up the high-level logging implementation to the low-level utility.""" + configure_logging(ClientLoggingProvider()) + + class RawPrinter: def __init__(self, stream=sys.stdout, err_stream=sys.stderr): self.stream = stream diff --git a/moshi/moshi/offline.py b/moshi/moshi/offline.py index f690620..99d2d01 100644 --- a/moshi/moshi/offline.py +++ b/moshi/moshi/offline.py @@ -52,15 +52,16 @@ import sphn from huggingface_hub import hf_hub_download -from .client_utils import make_log +from .client_utils import setup_client_logging from .models import loaders, LMGen, MimiModel from .models.lm import load_audio as lm_load_audio from .models.lm import _iterate_audio as lm_iterate_audio from .models.lm import encode_from_sphn as lm_encode_from_sphn +from .utils.logging import print_log def log(level: str, msg: str): - print(make_log(level, msg)) + print_log(level, msg) def seed_all(seed: int): @@ -320,6 +321,7 @@ def run_inference( def main(): """Parse CLI args and run offline inference.""" + setup_client_logging() parser = argparse.ArgumentParser( description="Offline inference from WAV input using Moshi server components." ) diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 771f491..59f6ab7 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -45,7 +45,7 @@ import torch import random -from .client_utils import make_log, colorize +from .client_utils import setup_client_logging from .models import loaders, MimiModel, LMModel, LMGen from .utils.connection import create_ssl_context, get_lan_ip from .utils.logging import setup_logger, ColorizedLog @@ -355,6 +355,7 @@ def _get_static_path(static: Optional[str]) -> Optional[str]: def main(): + setup_client_logging() parser = argparse.ArgumentParser() parser.add_argument("--host", default="localhost", type=str) parser.add_argument("--port", default=8998, type=int) diff --git a/moshi/moshi/utils/logging.py b/moshi/moshi/utils/logging.py index 21b8ba5..3e73e5f 100644 --- a/moshi/moshi/utils/logging.py +++ b/moshi/moshi/utils/logging.py @@ -24,7 +24,26 @@ import random import string from typing import Optional -from ..client_utils import make_log, colorize +from .logging_interface import LoggingProvider + + +_provider: Optional[LoggingProvider] = None + + +def configure_logging(provider: LoggingProvider): + """Explicitly wire the logging provider implementation.""" + global _provider + _provider = provider + + +def _get_provider() -> LoggingProvider: + """Internal helper to get the active logging provider.""" + if _provider is None: + raise RuntimeError( + "Logging provider not configured. Call configure_logging() " + "during application startup." + ) + return _provider def random_id(n=4): @@ -52,7 +71,8 @@ def setup_logger(name: str, log_file=None, level=logging.INFO): def print_log(level: str, msg: str, prefix: Optional[str] = None, info_color: Optional[str] = None): - colorized_msg = make_log(level, msg) if info_color is None or level != "info" else colorize(msg, info_color) + provider = _get_provider() + colorized_msg = provider.make_log(level, msg) if info_color is None or level != "info" else provider.colorize(msg, info_color) if prefix is None: print(colorized_msg) else: @@ -71,5 +91,5 @@ def log(self, level: str, msg: str): def randomize(cls): cid = random_id() color = random.choice(["91", "92", "93", "94", "95", "96", "97"]) - prefix = colorize(f"[{cid}] ", color) + prefix = _get_provider().colorize(f"[{cid}] ", color) return cls(prefix=prefix, info_color=color) diff --git a/moshi/moshi/utils/logging_interface.py b/moshi/moshi/utils/logging_interface.py new file mode 100644 index 0000000..adbe8c7 --- /dev/null +++ b/moshi/moshi/utils/logging_interface.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from typing import Protocol + +class LoggingProvider(Protocol): + """Abstraction for logging capabilities required by low-level modules.""" + + def colorize(self, text: str, color: str) -> str: + """Colorize text with ANSI codes.""" + ... + + def make_log(self, level: str, msg: str) -> str: + """Create a colorized log message with a level prefix.""" + ...