Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions modellogger/log_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
import time
from logging import INFO
from typing import Any, Dict, Optional


class DefaultFormatter(logging.Formatter):
Expand All @@ -15,15 +15,15 @@ class DefaultFormatter(logging.Formatter):
"RESET": "\x1b[0m",
}

def __init__(self, app_name=".", include_colors=True):
def __init__(self, app_name: str = ".", include_colors: bool = True):
super().__init__()
self.app_name = app_name
self.include_colors = include_colors
self.base_format = (
"%(asctime)s - {app_name} - %(name)s - %(levelname)s - %(message)s"
)

def format(self, record):
def format(self, record: logging.LogRecord) -> str:
date_format = "%Y-%m-%dT%H:%M:%SZ"
log_fmt = self.base_format.format(app_name=self.app_name)

Expand All @@ -36,12 +36,24 @@ def format(self, record):
return formatter.format(record)


def configure_logging(app_name=".", level=INFO, log_file=None):
def configure_logging(
app_name: str = ".",
level: int = logging.INFO,
log_file: Optional[str] = None,
) -> None:
"""Configures the root logger. Preserves existing handlers (if any).

NOTE: if you call this function multiple times, you may end up with
duplicate log messages, since each call adds new handlers to the root logger.

Args:
app_name: Name of the application to include in log messages.
level: Logging level to set.
log_file: If provided, log to this file instead of stderr.
"""
logger = logging.getLogger()
logger.setLevel(level)

logger.handlers.clear()

if log_file:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(DefaultFormatter(app_name, include_colors=False))
Expand All @@ -52,13 +64,15 @@ def configure_logging(app_name=".", level=INFO, log_file=None):
logger.addHandler(console_handler)


def get_logger(name):
def get_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
return logger


def get_config_dict(app_name=".", log_file=None):
config = {
def get_config_dict(
app_name: str = ".", log_file: Optional[str] = None
) -> Dict[str, Any]:
config: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
Expand Down
Empty file added modellogger/py.typed
Empty file.
4 changes: 2 additions & 2 deletions tests/test_log_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_configure_logging():
root_logger = logging.getLogger()
assert root_logger.name == "root"
assert root_logger.level == logging.INFO
assert isinstance(root_logger.handlers[0], logging.StreamHandler)
assert isinstance(root_logger.handlers[-1], logging.StreamHandler)


def test_configure_logging_with_file(tmp_path):
Expand All @@ -60,7 +60,7 @@ def test_configure_logging_with_file(tmp_path):
root_logger = logging.getLogger()
assert root_logger.name == "root"
assert root_logger.level == logging.INFO
assert isinstance(root_logger.handlers[0], logging.FileHandler)
assert isinstance(root_logger.handlers[-1], logging.FileHandler)


def test_get_logger_basic():
Expand Down