diff --git a/modellogger/log_config.py b/modellogger/log_config.py index ef733e6..86f6852 100644 --- a/modellogger/log_config.py +++ b/modellogger/log_config.py @@ -1,7 +1,7 @@ import logging import sys import time -from logging import INFO +from typing import Any, Dict, Optional class DefaultFormatter(logging.Formatter): @@ -15,7 +15,7 @@ class DefaultFormatter(logging.Formatter): "RESET": "\x1b[0m", } - def __init__(self, app_name=".", include_colors=True): + def __init__(self, app_name: str = ".", include_colors: bool = True): super().__init__() self.app_name = app_name self.include_colors = include_colors @@ -23,7 +23,7 @@ def __init__(self, app_name=".", include_colors=True): "%(asctime)s - {app_name} - %(name)s - %(levelname)s - %(message)s" ) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: date_format = "%Y-%m-%dT%H:%M:%SZ" log_fmt = self.base_format.format(app_name=self.app_name) @@ -36,12 +36,24 @@ def format(self, record): return formatter.format(record) -def configure_logging(app_name=".", level=INFO, log_file=None): +def configure_logging( + app_name: str = ".", + level: int = logging.INFO, + log_file: Optional[str] = None, +) -> None: + """Configures the root logger. Preserves existing handlers (if any). + + NOTE: if you call this function multiple times, you may end up with + duplicate log messages, since each call adds new handlers to the root logger. + + Args: + app_name: Name of the application to include in log messages. + level: Logging level to set. + log_file: If provided, log to this file instead of stderr. + """ logger = logging.getLogger() logger.setLevel(level) - logger.handlers.clear() - if log_file: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(DefaultFormatter(app_name, include_colors=False)) @@ -52,13 +64,15 @@ def configure_logging(app_name=".", level=INFO, log_file=None): logger.addHandler(console_handler) -def get_logger(name): +def get_logger(name: str) -> logging.Logger: logger = logging.getLogger(name) return logger -def get_config_dict(app_name=".", log_file=None): - config = { +def get_config_dict( + app_name: str = ".", log_file: Optional[str] = None +) -> Dict[str, Any]: + config: Dict[str, Any] = { "version": 1, "disable_existing_loggers": False, "formatters": { diff --git a/modellogger/py.typed b/modellogger/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_log_config.py b/tests/test_log_config.py index 207878a..c99585e 100644 --- a/tests/test_log_config.py +++ b/tests/test_log_config.py @@ -50,7 +50,7 @@ def test_configure_logging(): root_logger = logging.getLogger() assert root_logger.name == "root" assert root_logger.level == logging.INFO - assert isinstance(root_logger.handlers[0], logging.StreamHandler) + assert isinstance(root_logger.handlers[-1], logging.StreamHandler) def test_configure_logging_with_file(tmp_path): @@ -60,7 +60,7 @@ def test_configure_logging_with_file(tmp_path): root_logger = logging.getLogger() assert root_logger.name == "root" assert root_logger.level == logging.INFO - assert isinstance(root_logger.handlers[0], logging.FileHandler) + assert isinstance(root_logger.handlers[-1], logging.FileHandler) def test_get_logger_basic():