Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions moshi/moshi/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dataclasses import dataclass
import sys
from .utils.logging_interface import LoggingProvider
from .utils.logging import configure_logging


def colorize(text, color):
Expand All @@ -26,6 +28,20 @@ def make_log(level: str, msg: str) -> str:
return prefix + " " + msg


class ClientLoggingProvider(LoggingProvider):
"""High-level implementation of the LoggingProvider interface."""
def colorize(self, text: str, color: str) -> str:
return colorize(text, color)

def make_log(self, level: str, msg: str) -> str:
return make_log(level, msg)


def setup_client_logging():
"""Wires up the high-level logging implementation to the low-level utility."""
configure_logging(ClientLoggingProvider())


class RawPrinter:
def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
self.stream = stream
Expand Down
6 changes: 4 additions & 2 deletions moshi/moshi/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@
import sphn
from huggingface_hub import hf_hub_download

from .client_utils import make_log
from .client_utils import setup_client_logging
from .models import loaders, LMGen, MimiModel
from .models.lm import load_audio as lm_load_audio
from .models.lm import _iterate_audio as lm_iterate_audio
from .models.lm import encode_from_sphn as lm_encode_from_sphn
from .utils.logging import print_log


def log(level: str, msg: str):
print(make_log(level, msg))
print_log(level, msg)


def seed_all(seed: int):
Expand Down Expand Up @@ -320,6 +321,7 @@ def run_inference(

def main():
"""Parse CLI args and run offline inference."""
setup_client_logging()
parser = argparse.ArgumentParser(
description="Offline inference from WAV input using Moshi server components."
)
Expand Down
3 changes: 2 additions & 1 deletion moshi/moshi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import torch
import random

from .client_utils import make_log, colorize
from .client_utils import setup_client_logging
from .models import loaders, MimiModel, LMModel, LMGen
from .utils.connection import create_ssl_context, get_lan_ip
from .utils.logging import setup_logger, ColorizedLog
Expand Down Expand Up @@ -355,6 +355,7 @@ def _get_static_path(static: Optional[str]) -> Optional[str]:


def main():
setup_client_logging()
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8998, type=int)
Expand Down
26 changes: 23 additions & 3 deletions moshi/moshi/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,26 @@
import random
import string
from typing import Optional
from ..client_utils import make_log, colorize
from .logging_interface import LoggingProvider


_provider: Optional[LoggingProvider] = None


def configure_logging(provider: LoggingProvider):
"""Explicitly wire the logging provider implementation."""
global _provider
_provider = provider


def _get_provider() -> LoggingProvider:
"""Internal helper to get the active logging provider."""
if _provider is None:
raise RuntimeError(
"Logging provider not configured. Call configure_logging() "
"during application startup."
)
return _provider


def random_id(n=4):
Expand Down Expand Up @@ -52,7 +71,8 @@ def setup_logger(name: str, log_file=None, level=logging.INFO):


def print_log(level: str, msg: str, prefix: Optional[str] = None, info_color: Optional[str] = None):
colorized_msg = make_log(level, msg) if info_color is None or level != "info" else colorize(msg, info_color)
provider = _get_provider()
colorized_msg = provider.make_log(level, msg) if info_color is None or level != "info" else provider.colorize(msg, info_color)
if prefix is None:
print(colorized_msg)
else:
Expand All @@ -71,5 +91,5 @@ def log(self, level: str, msg: str):
def randomize(cls):
cid = random_id()
color = random.choice(["91", "92", "93", "94", "95", "96", "97"])
prefix = colorize(f"[{cid}] ", color)
prefix = _get_provider().colorize(f"[{cid}] ", color)
return cls(prefix=prefix, info_color=color)
15 changes: 15 additions & 0 deletions moshi/moshi/utils/logging_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

from typing import Protocol

class LoggingProvider(Protocol):
"""Abstraction for logging capabilities required by low-level modules."""

def colorize(self, text: str, color: str) -> str:
"""Colorize text with ANSI codes."""
...

def make_log(self, level: str, msg: str) -> str:
"""Create a colorized log message with a level prefix."""
...