From 63c137e186af1e1807cc07dcfcafbd614b1e21aa Mon Sep 17 00:00:00 2001 From: Stefan Tatschner Date: Mon, 9 Dec 2024 15:22:23 +0100 Subject: [PATCH] feat: Allow per scanner logging config The following setup is now possible when scanners are run in own scripts: import gallia import gallia.log import gallia.command logger = gallia.log.get_logger("gallia") class Scanner1(gallia.command.Script): def main(self) -> None: logger.info(f"hi {self.__class__.__name__}") logger.error("error") logger.warning("warning") logger.notice("notice") logger.info("info") logger.debug("debug") logger.trace("trace") class Scanner2(gallia.command.Script): def main(self) -> None: logger.info(f"hi {self.__class__.__name__}") if __name__ == "__main__": # Each scanner sets up its own logging setup # with a logger called `gallia`. Scanner1().entry_point() Scanner2().entry_point() # Alternatively, a context manager can be used # for more fine grained control. with gallia.log.setup_logging( logger_name="gallia", stderr_level=gallia.log.Loglevel.DEBUG, logfile="test.log.zst", ) as h: Scanner1(logging_handler=h).entry_point() Scanner2(logging_handler=h).entry_point() --- src/gallia/cli/gallia.py | 12 +- src/gallia/command/base.py | 37 ++-- src/gallia/log.py | 331 ++++++++++++++++++++------------ src/gallia/utils.py | 21 -- tests/pytest/test_helpers.py | 2 +- tests/pytest/test_transports.py | 2 +- 6 files changed, 241 insertions(+), 164 deletions(-) diff --git a/src/gallia/cli/gallia.py b/src/gallia/cli/gallia.py index 90f764a51..b184a2170 100644 --- a/src/gallia/cli/gallia.py +++ b/src/gallia/cli/gallia.py @@ -25,9 +25,9 @@ from gallia.plugins.plugin import CommandTree, load_commands, load_plugins from gallia.pydantic_argparse import ArgumentParser from gallia.pydantic_argparse import BaseCommand as PydanticBaseCommand -from gallia.utils import get_log_level +from gallia.log import get_log_level -setup_logging(Loglevel.DEBUG) +setup_logging("gallia", Loglevel.DEBUG) defaults = dict[type, dict[str, Any]] @@ -123,7 +123,6 @@ def get_command(config: BaseCommandConfig) -> BaseCommand: def parse_and_run( commands: type[BaseCommand] | MutableMapping[str, CommandTree | type[BaseCommand]], auto_complete: bool = True, - setup_log: bool = True, top_level_options: Mapping[str, Callable[[], None]] | None = None, show_help_on_zero_args: bool = True, ) -> Never: @@ -136,7 +135,6 @@ def parse_and_run( :param commands: A hierarchy of commands. :param auto_complete: Turns auto-complete functionality on. - :param setup_log: Setup logging according to the parameters in the parsed config. :param top_level_options: Optional top-level actions, such as "--version", given by a mapping of arguments and functions. The program redirects control to the given function, once the program is called with the corresponding argument and terminates after it returns. @@ -182,12 +180,6 @@ def __call__( assert isinstance(config, BaseCommandConfig) - if setup_log: - setup_logging( - level=get_log_level(config.verbose), - no_volatile_info=not config.volatile_info, - ) - sys.exit(get_command(config).entry_point()) diff --git a/src/gallia/command/base.py b/src/gallia/command/base.py index 8731fbe66..180ea802b 100644 --- a/src/gallia/command/base.py +++ b/src/gallia/command/base.py @@ -26,12 +26,12 @@ from gallia.command.config import Field, GalliaBaseModel, Idempotent from gallia.db.handler import DBHandler from gallia.dumpcap import Dumpcap -from gallia.log import add_zst_log_handler, get_logger, tz +from gallia.log import get_logger, tz, LoggingSetupHandler, get_log_level, setup_logging, Loglevel from gallia.power_supply import PowerSupply from gallia.power_supply.uri import PowerSupplyURI from gallia.services.uds.core.exception import UDSException from gallia.transports import BaseTransport, TargetURI -from gallia.utils import camel_to_snake, get_file_log_level +from gallia.utils import camel_to_snake @unique @@ -180,9 +180,11 @@ class BaseCommand(FlockMixin, ABC): #: a log message with level critical is logged. CATCHED_EXCEPTIONS: list[type[Exception]] = [] - log_file_handlers: list[Handler] - - def __init__(self, config: BaseCommandConfig) -> None: + def __init__( + self, + config: BaseCommandConfig = BaseCommandConfig(), + logging_handler: LoggingSetupHandler | None = None, + ) -> None: self.id = camel_to_snake(self.__class__.__name__) self.config = config self.artifacts_dir = Path() @@ -195,7 +197,7 @@ def __init__(self, config: BaseCommandConfig) -> None: ) self._lock_file_fd: int | None = None self.db_handler: DBHandler | None = None - self.log_file_handlers = [] + self.provided_logging_handler = logging_handler @abstractmethod def run(self) -> int: ... @@ -323,15 +325,25 @@ def entry_point(self) -> int: if self.HAS_ARTIFACTS_DIR: self.artifacts_dir = self.prepare_artifactsdir( - self.config.artifacts_base, self.config.artifacts_dir + self.config.artifacts_base, + self.config.artifacts_dir, + ) + + if self.provided_logging_handler is None: + stderr_level = get_log_level(self.config.verbose) + logging_handler = setup_logging( + logger_name="gallia", + stderr_level=stderr_level, + close_on_exit=False, ) - self.log_file_handlers.append( - add_zst_log_handler( + if self.HAS_ARTIFACTS_DIR: + logging_handler.add_zst_file_handler( logger_name="gallia", filepath=self.artifacts_dir.joinpath(FileNames.LOGFILE.value), - file_log_level=get_file_log_level(self.config), + log_level=stderr_level if self.config.trace_log is False else Loglevel.TRACE, ) - ) + else: + logging_handler = self.provided_logging_handler if self.config.hooks: self.run_hook(HookVariant.PRE) @@ -380,6 +392,9 @@ def entry_point(self) -> int: if self._lock_file_fd is not None: self._release_flock() + if self.provided_logging_handler is None: + logging_handler.stop_logging() + return exit_code diff --git a/src/gallia/log.py b/src/gallia/log.py index fdcbdf96f..b6b902f4e 100644 --- a/src/gallia/log.py +++ b/src/gallia/log.py @@ -25,7 +25,7 @@ from pathlib import Path from queue import Queue from types import TracebackType -from typing import TYPE_CHECKING, Any, BinaryIO, Self, TextIO, TypeAlias, cast +from typing import TYPE_CHECKING, Any, BinaryIO, Self, TextIO, TypeAlias, cast, IO, overload, Literal import zstandard @@ -36,42 +36,6 @@ gmt_offset = time.localtime().tm_gmtoff tz = datetime.timezone(datetime.timedelta(seconds=gmt_offset)) - -@unique -class ColorMode(Enum): - """ColorMode is used as an argument to :func:`set_color_mode`.""" - - #: Colors are always turned on. - ALWAYS = "always" - #: Colors are turned off if the target - #: stream (e.g. stderr) is not a tty. - AUTO = "auto" - #: No colors are used. In other words, - #: no ANSI escape codes are included. - NEVER = "never" - - -def resolve_color_mode(mode: ColorMode, stream: TextIO = sys.stderr) -> bool: - """Sets the color mode of the console log handler. - - :param mode: The available options are described in :class:`ColorMode`. - :param stream: Used as a reference for :attr:`ColorMode.AUTO`. - """ - if sys.platform == "win32": - return False - - match mode: - case ColorMode.ALWAYS: - return True - case ColorMode.AUTO: - if os.getenv("NO_COLOR") is not None: - return False - else: - return stream.isatty() - case ColorMode.NEVER: - return False - - # https://stackoverflow.com/a/35804945 def _add_logging_level(level_name: str, level_num: int) -> None: method_name = level_name.lower() @@ -108,6 +72,41 @@ def to_root(message, *args, **kwargs): # type: ignore _add_logging_level("NOTICE", 25) +@unique +class ColorMode(Enum): + """ColorMode is used as an argument to :func:`set_color_mode`.""" + + #: Colors are always turned on. + ALWAYS = "always" + #: Colors are turned off if the target + #: stream (e.g. stderr) is not a tty. + AUTO = "auto" + #: No colors are used. In other words, + #: no ANSI escape codes are included. + NEVER = "never" + + +def resolve_color_mode(mode: ColorMode, stream: TextIO = sys.stderr) -> bool: + """Sets the color mode of the console log handler. + + :param mode: The available options are described in :class:`ColorMode`. + :param stream: Used as a reference for :attr:`ColorMode.AUTO`. + """ + if sys.platform == "win32": + return False + + match mode: + case ColorMode.ALWAYS: + return True + case ColorMode.AUTO: + if os.getenv("NO_COLOR") is not None: + return False + else: + return stream.isatty() + case ColorMode.NEVER: + return False + + @unique class Loglevel(IntEnum): """A wrapper around the constants exposed by python's @@ -228,103 +227,163 @@ def to_level(self) -> Loglevel: raise ValueError("invalid value") +class LoggingSetupHandler: + def __init__(self) -> None: + self.listeners: list[QueueListener] = [] + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.stop_logging() + + def add_stream_handler( + self, + logger_name: str, + level: Loglevel, + stream: IO[str], + volatile_info: bool, + colored: bool, + ) -> None: + queue: Queue[logging.LogRecord] = Queue() + logger = logging.getLogger(logger_name) + logger.addHandler(QueueHandler(queue)) + + handler = logging.StreamHandler(stream) + handler.setLevel(level) + + formatter: logging.Formatter + if stream.isatty(): + formatter = _ConsoleFormatter(colored=colored, volatile_info=volatile_info) + else: + formatter = _StreamFormatter() + + handler.terminator = "" # We manually handle the terminator while formatting + + handler.setFormatter(formatter) + + listener = QueueListener( + queue, + handler, + respect_handler_level=True, + ) + listener.start() + self.listeners.append(listener) + + def add_zst_file_handler( + self, + logger_name: str, + filepath: Path | str, + log_level: Loglevel, + ) -> None: + queue: Queue[Any] = Queue() + logger = get_logger(logger_name) + logger.addHandler(QueueHandler(queue)) + + handler = _ZstdFileHandler( + filepath, + level=log_level, + ) + handler.setLevel(log_level) + handler.setFormatter(_JSONFormatter()) + + queue_listener = QueueListener( + queue, + handler, + respect_handler_level=True, + ) + queue_listener.start() + self.listeners.append(queue_listener) + + def stop_logging(self) -> None: + for listener in self.listeners: + listener.stop() + + +def get_log_level(verbose: int) -> Loglevel: + level = Loglevel.INFO + if verbose == 1: + level = Loglevel.DEBUG + elif verbose >= 2: + level = Loglevel.TRACE + return level + + +@overload def setup_logging( - level: Loglevel | None = None, - color_mode: ColorMode = ColorMode.AUTO, - no_volatile_info: bool = False, - logger_name: str = "gallia", + logger_name: str, + stderr_level: Loglevel | None = ..., + color_mode: ColorMode = ..., + volatile_info: bool = ..., + close_on_exit: Literal[False] = False, + logfile: Path | str | None = ..., + logfile_level: Loglevel = ..., +) -> LoggingSetupHandler: + ... + + +@overload +def setup_logging( + logger_name: str, + stderr_level: Loglevel | None = ..., + color_mode: ColorMode = ..., + volatile_info: bool = ..., + close_on_exit: Literal[True] = ..., + logfile: Path | str| None = ..., + logfile_level: Loglevel = ..., ) -> None: + ... + + +def setup_logging( + logger_name: str, + stderr_level: Loglevel | None = Loglevel.INFO, + color_mode: ColorMode = ColorMode.AUTO, + volatile_info: bool = False, # deprecated: Introduce progress info + close_on_exit: bool = False, + logfile: Path | str| None = None, + logfile_level: Loglevel = Loglevel.INFO, +) -> LoggingSetupHandler | None: """Enable and configure gallia's logging system. If this fuction is not called as early as possible, the logging system is in an undefined state und might not behave as expected. Always use this function to - initialize gallia's logging. For instance, ``setup_logging()`` - initializes a QueueHandler to avoid blocking calls during - logging. - - :param level: The loglevel to enable for the console handler. - If this argument is None, the env variable - ``GALLIA_LOGLEVEL`` (see :doc:`../env`) is read. - :param file_level: The loglevel to enable for the file handler. - :param path: The path to the logfile containing json records. - :param color_mode: The color mode to use for the console. + initialize gallia's logging. """ - if level is None: - # FIXME: why is this here and not in config? - if (raw := os.getenv("GALLIA_LOGLEVEL")) is not None: - level = PenlogPriority.from_str(raw).to_level() - else: - level = Loglevel.DEBUG - # These are slow and not used by gallia. logging.logMultiprocessing = False logging.logThreads = False logging.logProcesses = False logger = logging.getLogger(logger_name) - # LogLevel cannot be 0 (NOTSET), because only the root logger sends it to its handlers then + + # FIXME: Randomly setting loglevels seems wrong. Address this better. logger.setLevel(1) - # Clean up potentially existing handlers and create a new async QueueHandler for stderr output - while len(logger.handlers) > 0: - logger.handlers[0].close() - logger.removeHandler(logger.handlers[0]) - colored = resolve_color_mode(color_mode) - add_stderr_log_handler(logger_name, level, no_volatile_info, colored) + # Clean up potentially existing handlers. + for h in logger.handlers[:]: + logger.removeHandler(h) + h.close() + handler = LoggingSetupHandler() -def add_stderr_log_handler( - logger_name: str, - level: Loglevel, - no_volatile_info: bool, - colored: bool, -) -> None: - queue: Queue[Any] = Queue() - logger = logging.getLogger(logger_name) - logger.addHandler(QueueHandler(queue)) - - stderr_handler = logging.StreamHandler(sys.stderr) - stderr_handler.setLevel(level) - console_formatter = _ConsoleFormatter() - - console_formatter.colored = colored - stderr_handler.terminator = "" # We manually handle the terminator while formatting - if no_volatile_info is False: - console_formatter.volatile_info = True - - stderr_handler.setFormatter(console_formatter) - - queue_listener = QueueListener( - queue, - *[stderr_handler], - respect_handler_level=True, - ) - queue_listener.start() - atexit.register(queue_listener.stop) - - -def add_zst_log_handler( - logger_name: str, filepath: Path, file_log_level: Loglevel -) -> logging.Handler: - queue: Queue[Any] = Queue() - logger = get_logger(logger_name) - logger.addHandler(QueueHandler(queue)) - - zstd_handler = _ZstdFileHandler( - filepath, - level=file_log_level, - ) - zstd_handler.setLevel(file_log_level) - zstd_handler.setFormatter(_JSONFormatter()) - - queue_listener = QueueListener( - queue, - *[zstd_handler], - respect_handler_level=True, - ) - queue_listener.start() - atexit.register(queue_listener.stop) - return zstd_handler + if stderr_level is not None: + colored = resolve_color_mode(color_mode, sys.stderr) + handler.add_stream_handler(logger_name, stderr_level, sys.stderr, volatile_info, colored=colored) + + if logfile: + handler.add_zst_file_handler(logger_name, logfile, logfile_level) + + if close_on_exit: + atexit.register(handler.stop_logging) + return None + + return handler @dataclasses.dataclass @@ -705,9 +764,41 @@ def format(self, record: logging.LogRecord) -> str: return json.dumps(dataclasses.asdict(penlog_record)) +class _StreamFormatter(logging.Formatter): + def __init__(self) -> None: + pass + + def format( + self, + record: logging.LogRecord, + ) -> str: + stacktrace = None + + if record.exc_info: + exc_type, exc_value, exc_traceback = record.exc_info + assert exc_type + assert exc_value + assert exc_traceback + + stacktrace = "\n" + stacktrace += "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + + return _format_record( + dt=datetime.datetime.fromtimestamp(record.created), + name=record.name, + data=record.getMessage(), + levelno=record.levelno, + tags=record.__dict__["tags"] if "tags" in record.__dict__ else None, + stacktrace=stacktrace, + colored=False, + volatile_info=False, + ) + + class _ConsoleFormatter(logging.Formatter): - colored: bool = False - volatile_info: bool = False + def __init__(self, colored: bool, volatile_info: bool) -> None: + self.colored = colored + self.volatile_info = volatile_info # deprecated: will be removed def format( self, @@ -737,7 +828,7 @@ def format( class _ZstdFileHandler(logging.Handler): - def __init__(self, path: Path, level: int | str = logging.NOTSET) -> None: + def __init__(self, path: Path | str, level: int | str = logging.NOTSET) -> None: super().__init__(level) self.file = zstandard.open( filename=path, diff --git a/src/gallia/utils.py b/src/gallia/utils.py index 336bee5bc..4d6792d2c 100644 --- a/src/gallia/utils.py +++ b/src/gallia/utils.py @@ -261,27 +261,6 @@ def dump_args(args: Any) -> dict[str, str | int | float]: return settings -def get_log_level(args: Any) -> Loglevel: - level = Loglevel.INFO - if hasattr(args, "verbose"): - if args.verbose == 1: - level = Loglevel.DEBUG - elif args.verbose >= 2: - level = Loglevel.TRACE - return level - - -def get_file_log_level(args: Any) -> Loglevel: - level = Loglevel.DEBUG - if hasattr(args, "trace_log"): - if args.trace_log: - level = Loglevel.TRACE - elif hasattr(args, "verbose"): - if args.verbose >= 2: - level = Loglevel.TRACE - return level - - CONTEXT_SHARED_VARIABLE = "logger_name" context: contextvars.ContextVar[tuple[str, str | None]] = contextvars.ContextVar( CONTEXT_SHARED_VARIABLE diff --git a/tests/pytest/test_helpers.py b/tests/pytest/test_helpers.py index 57b762120..60c667618 100644 --- a/tests/pytest/test_helpers.py +++ b/tests/pytest/test_helpers.py @@ -11,7 +11,7 @@ ) from gallia.utils import split_host_port -setup_logging() +setup_logging("gallia") def test_split_host_port_v4() -> None: diff --git a/tests/pytest/test_transports.py b/tests/pytest/test_transports.py index 7c2f2eb5b..24ea6f0b4 100644 --- a/tests/pytest/test_transports.py +++ b/tests/pytest/test_transports.py @@ -15,7 +15,7 @@ test_data = [b"hello" b"tcp"] -setup_logging() +setup_logging("gallia") class TCPServer: