diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index aff0867e9a..89660b11d3 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -13,6 +13,7 @@ from __future__ import annotations import abc +import contextlib import dataclasses import functools import types @@ -44,7 +45,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, compiled_program, options, toolchain from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -53,6 +54,34 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +ProgramCallMetricsCollector = metrics.make_collector( + level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC +) + + +@hook_machinery.context_hook +def program_call_context( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of a program call.""" + return ProgramCallMetricsCollector() + + +@hook_machinery.context_hook +def embedded_program_call_context( + program: Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + """Hook called at the beginning and end of an embedded program call.""" + return metrics.metrics_context(f"{program.__name__}<{getattr(program.backend, 'name', '')}>") + + @dataclasses.dataclass(frozen=True) class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]): """ @@ -162,11 +191,6 @@ def compile( return self -program_call_metrics_collector = metrics.make_collector( - level=metrics.MINIMAL, metric_name=metrics.TOTAL_METRIC -) - - # TODO(tehrengruber): Decide if and how programs can call other 
programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -333,7 +357,13 @@ def __call__( offset_provider = {} enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit - with program_call_metrics_collector(): + with program_call_context( + program=self, + args=args, + offset_provider=offset_provider, + enable_jit=enable_jit, + kwargs=kwargs, + ): if __debug__: # TODO: remove or make dependency on self.past_stage optional past_process_args._validate_args( @@ -355,15 +385,9 @@ def __call__( stacklevel=2, ) - # Metrics source key needs to be set here. Embedded programs - # don't have variants so there's no other place to do it. - if metrics.is_level_enabled(metrics.MINIMAL): - metrics.set_current_source_key( - f"{self.__name__}<{getattr(self.backend, 'name', '')}>" - ) - with next_embedded.context.update(offset_provider=offset_provider): - self.definition_stage.definition(*args, **kwargs) + with embedded_program_call_context(self, args, offset_provider, kwargs): + self.definition_stage.definition(*args, **kwargs) try: diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index d6dbddd7c0..da05c68669 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -90,6 +90,7 @@ class PASTProgramDef: PASTProgramDef, arguments.CompileTimeArgs ] +DSLDefinition = DSLFieldOperatorDef | DSLProgramDef DSLDefinitionT = TypeVar("DSLDefinitionT", DSLFieldOperatorDef, DSLProgramDef) diff --git a/src/gt4py/next/instrumentation/hook_machinery.py b/src/gt4py/next/instrumentation/hook_machinery.py new file mode 100644 index 0000000000..15475efc54 --- /dev/null +++ b/src/gt4py/next/instrumentation/hook_machinery.py @@ -0,0 +1,201 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ast +import collections.abc +import contextlib +import dataclasses +import inspect +import textwrap +import types +import typing +import warnings +from collections.abc import Callable +from typing import Generic, ParamSpec, TypeVar + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _get_unique_name(func: Callable) -> str: + """Generate a unique name for a callable object.""" + return ( + f"{func.__module__}.{getattr(func, '__qualname__', func.__class__.__qualname__)}#{id(func)}" + ) + + +def _is_empty_function(func: Callable) -> bool: + """Check if a callable object is empty (i.e., contains no statements).""" + try: + callable_src = ( + inspect.getsource(func) + if isinstance(func, types.FunctionType) + else inspect.getsource(func.__call__) # type: ignore[operator] # asserted above + ) + callable_ast = ast.parse(textwrap.dedent(callable_src)) + return all( + isinstance(st, ast.Pass) + or (isinstance(st, ast.Expr) and isinstance(st.value, ast.Constant)) + for st in typing.cast(ast.FunctionDef, callable_ast.body[0]).body + ) + except Exception: + return False + + +@dataclasses.dataclass(slots=True) +class _BaseHook(Generic[T, P]): + """Base class to define callback registration functionality for all hook types.""" + + definition: Callable[P, T] + registry: dict[str, Callable[P, T]] = dataclasses.field(default_factory=dict, kw_only=True) + callbacks: tuple[Callable[P, T], ...] = dataclasses.field(default=(), init=False) + + @property + def __doc__(self) -> str | None: # type: ignore[override] + return self.definition.__doc__ + + def __post_init__(self) -> None: + # As an optimization to avoid an empty function call if no callbacks are + # registered, we only add the original definitions to the list of callables + # if it contains a non-empty definition. 
+ if not _is_empty_function(self.definition): + self.callbacks = (self.definition,) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + raise NotImplementedError("This method should be implemented by subclasses.") + + def register( + self, callback: Callable[P, T], *, name: str | None = None, index: int | None = None + ) -> None: + """ + Register a callback to the hook. + + Args: + callback: The callable to register. + name: An optional name for the callback. If not provided, a unique name will be generated. + index: An optional index at which to insert the callback (not counting the original + definition). If not provided, the callback will be appended to the end of the list. + """ + + callable_signature = inspect.signature(callback) + hook_signature = inspect.signature(self.definition) + + signature_mismatch = len(callable_signature.parameters) != len( + hook_signature.parameters + ) or any( + # Remove the annotation before comparison to avoid false mismatches + actual_param.replace(annotation="") != expected_param.replace(annotation="") + for actual_param, expected_param in zip( + callable_signature.parameters.values(), hook_signature.parameters.values() + ) + ) + if signature_mismatch: + raise ValueError( + f"Callback signature {callable_signature} does not match hook signature {hook_signature}" + ) + try: + callable_typing = typing.get_type_hints(callback) + hook_typing = typing.get_type_hints(self.definition) + if not all( + callable_typing[arg_key] == arg_typing + for arg_key, arg_typing in hook_typing.items() + ): + warnings.warn( + f"Callback annotations {callable_typing} does not match expected hook annotations {hook_typing}", + stacklevel=2, + ) + except Exception: + # Ignore issues while checking type hints (e.g., forward references + # or missing imports); failure here should not prevent hook registration. 
+ pass + + name = name or _get_unique_name(callback) + + if index is None: + self.callbacks += (callback,) + else: + if self.callbacks and self.callbacks[0] is self.definition: + index += 1 # The original definition should always go first + self.callbacks = (*self.callbacks[:index], callback, *self.callbacks[index:]) + + self.registry[name] = callback + + def remove(self, callback: str | Callable[P, T]) -> None: + """ + Remove a registered callback from the hook. + + Args: + callback: The callable object to remove or its registered name. + """ + if isinstance(callback, str): + name = callback + if name not in self.registry: + raise KeyError(f"No callback registered under the name '{name}'") + else: + name = _get_unique_name(callback) + if name not in self.registry: + raise KeyError(f"Callback object {callback} not found in registry") + + callback = self.registry.pop(name) + assert callback in self.callbacks + self.callbacks = tuple(cb for cb in self.callbacks if cb is not callback) + + +@dataclasses.dataclass(slots=True) +class EventHook(_BaseHook[None, P]): + """Event hook specification.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + for func in self.callbacks: + func(*args, **kwargs) + + +def event_hook(definition: Callable[P, None]) -> EventHook[P]: + """Decorator to create an EventHook from a function definition.""" + return EventHook(definition) + + +@dataclasses.dataclass(slots=True) +class ContextHook( + contextlib.AbstractContextManager, _BaseHook[contextlib.AbstractContextManager, P] +): + """ + Context hook specification. + + This hook type is used to define context managers that can be stacked together. 
+ """ + + ctx_managers: collections.abc.Sequence[contextlib.AbstractContextManager] = dataclasses.field( + default=(), init=False + ) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> contextlib.AbstractContextManager: + self.ctx_managers = [func(*args, **kwargs) for func in self.callbacks] + return self + + def __enter__(self) -> None: + for ctx_manager in self.ctx_managers: + ctx_manager.__enter__() + + def __exit__( + self, + type_: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + for ctx_manager in reversed(self.ctx_managers): + ctx_manager.__exit__(type_, exc_value, traceback) + self.ctx_managers = () + + +def context_hook(definition: Callable[P, contextlib.AbstractContextManager]) -> ContextHook[P]: + """Decorator to create a ContextHook from a function definition.""" + return ContextHook(definition) diff --git a/src/gt4py/next/instrumentation/hooks.py b/src/gt4py/next/instrumentation/hooks.py new file mode 100644 index 0000000000..97d7f3a080 --- /dev/null +++ b/src/gt4py/next/instrumentation/hooks.py @@ -0,0 +1,18 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from gt4py.next.ffront.decorator import ( + embedded_program_call_context as embedded_program_call_context, + program_call_context as program_call_context, +) +from gt4py.next.otf.compiled_program import ( + compile_variant_hook as compile_variant_hook, + compiled_program_call_context as compiled_program_call_context, +) diff --git a/src/gt4py/next/instrumentation/metrics.py b/src/gt4py/next/instrumentation/metrics.py index 36b0f5567b..908f99ed61 100644 --- a/src/gt4py/next/instrumentation/metrics.py +++ b/src/gt4py/next/instrumentation/metrics.py @@ -12,6 +12,7 @@ import contextlib import contextvars import dataclasses +import functools import itertools import json import numbers @@ -150,28 +151,21 @@ def is_current_source_key_set() -> bool: return _source_key_cvar.get(_NO_KEY_SET_MARKER_) is not _NO_KEY_SET_MARKER_ -def get_current_source_key() -> str: - """Retrieve the current source key for metrics collection (it must be set).""" - return _source_key_cvar.get() - - -def set_current_source_key(key: str) -> Source: +def set_current_source_key(key: str) -> None: """ Set the current source key for metrics collection. It must be called only when no source key is set (or the same key is already set). - - Args: - key: The source key to set. - - Returns: - The `Source` object associated with the given key. """ assert _source_key_cvar.get(_NO_KEY_SET_MARKER_) in {key, _NO_KEY_SET_MARKER_}, ( "A different source key has been already set." 
) _source_key_cvar.set(key) - return sources[key] + + +def get_current_source_key() -> str: + """Retrieve the current source key for metrics collection (it must be set).""" + return _source_key_cvar.get() def get_current_source() -> Source: @@ -210,14 +204,25 @@ def __enter__(self) -> None: def __exit__( self, exc_type_: type[BaseException] | None, - value: BaseException | None, + exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: if self.previous_cvar_token is not None: _source_key_cvar.reset(self.previous_cvar_token) +class SourceKeySetterAtEnter(SourceKeyContextManager): + def __exit__( + self, + exc_type_: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + pass + + metrics_context = SourceKeyContextManager +metrics_setter_at_enter = SourceKeySetterAtEnter @dataclasses.dataclass(slots=True) @@ -280,7 +285,7 @@ def __enter__(self) -> None: def __exit__( self, exc_type_: type[BaseException] | None, - value: BaseException | None, + exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: if self.previous_cvar_token is not None: @@ -292,6 +297,7 @@ def __exit__( _source_key_cvar.reset(self.previous_cvar_token) +@functools.cache def make_collector( level: int, metric_name: str, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index c8399ea2ce..5791d02784 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -124,6 +124,12 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) +ArgStaticDescriptorsContext: TypeAlias = dict[str, MaybeNestedInTuple[ArgStaticDescriptor | None]] +ArgStaticDescriptorsContextsByType: TypeAlias = Mapping[ + type[ArgStaticDescriptor], ArgStaticDescriptorsContext +] + + @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time 
compilation.""" @@ -136,10 +142,7 @@ class CompileTimeArgs: #: If an argument or element of an argument has no descriptor, the respective value is `None`. #: E.g., for a tuple argument `a` with type `ts.TupleTupe(types=[field_t, int32_t])` a possible # context would be `{"a": (FieldDomainDescriptor(...), None)}`. - argument_descriptor_contexts: Mapping[ - type[ArgStaticDescriptor], - dict[str, MaybeNestedInTuple[ArgStaticDescriptor | None]], - ] + argument_descriptor_contexts: ArgStaticDescriptorsContextsByType @property def offset_provider_type(self) -> common.OffsetProviderType: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index bbb10c4610..2371b953dc 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -9,6 +9,7 @@ from __future__ import annotations import concurrent.futures +import contextlib import dataclasses import functools import itertools @@ -25,7 +26,7 @@ type_specifications as ts_ffront, type_translation, ) -from gt4py.next.instrumentation import metrics +from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -34,17 +35,80 @@ T = TypeVar("T") ScalarOrTupleOfScalars: TypeAlias = xtyping.MaybeNestedInTuple[core_defs.Scalar] + +#: Content of the key: (*hashable_arg_descriptors, id(offset_provider), concrete_instantation_if_generic) CompiledProgramsKey: TypeAlias = tuple[tuple[Hashable, ...], int, None | str] -ArgumentDescriptors: TypeAlias = dict[ + +ArgStaticDescriptorsByType: TypeAlias = dict[ type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] ] -ArgumentDescriptorContext: TypeAlias = dict[ - str, xtyping.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] -] -ArgumentDescriptorContexts: TypeAlias = dict[ - type[arguments.ArgStaticDescriptor], - 
ArgumentDescriptorContext, -] + + +def _make_pool_root( + program_definition: ffront_stages.DSLDefinition, backend: gtx_backend.Backend +) -> tuple[str, str]: + return (program_definition.definition.__name__, backend.name) + + +@functools.cache +def _metrics_prefix_from_pool_root(root: tuple[str, str]) -> str: + """Generate a metrics prefix from a compiled programs pool root.""" + return f"{root[0]}<{root[1]}>" + + +@hook_machinery.event_hook +def compile_variant_hook( + program_definition: ffront_stages.DSLDefinition, + backend: gtx_backend.Backend, + offset_provider: common.OffsetProviderType | common.OffsetProvider, + argument_descriptors: ArgStaticDescriptorsByType, + key: CompiledProgramsKey, +) -> None: + """Callback hook invoked before compiling a program variant.""" + + if metrics.is_any_level_enabled(): + # Create a new metrics entity for this compiled program variant and + # attach relevant metadata to it. + source_key = f"{_metrics_prefix_from_pool_root(_make_pool_root(program_definition, backend))}[{hash(key)}]" + assert source_key not in metrics.sources, ( + "The key for the program variant being compiled is already set!!" + ) + + metrics.sources[source_key].metadata |= dict( + name=program_definition.definition.__name__, + backend=backend.name, + compiled_program_pool_key=hash(key), + **{ + f"{eve_utils.CaseStyleConverter.convert(key.__name__, 'pascal', 'snake')}s": value + for key, value in argument_descriptors.items() + }, + ) + + +@hook_machinery.context_hook +def compiled_program_call_context( + compiled_program: stages.CompiledProgram, + args: tuple[Any, ...], + kwargs: dict[str, Any], + offset_provider: common.OffsetProvider, + root: tuple[str, str], + key: CompiledProgramsKey, +) -> contextlib.AbstractContextManager: + """ + Hook called at the beginning and end of a compiled program call. + + Args: + compiled_program: The compiled program being called. + args: The arguments with which the program is called. 
+ kwargs: The keyword arguments with which the program is called. + offset_provider: The offset provider passed to the program. + root: The root of the compiled programs pool this program belongs to, i.e. a tuple of + (program name, backend name). + key: The key of the compiled program in the compiled programs pool. + + """ + return metrics.metrics_setter_at_enter(f"{_metrics_prefix_from_pool_root(root)}[{hash(key)}]") + # TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle. _async_compilation_pool: concurrent.futures.Executor | None = None @@ -122,11 +186,11 @@ def _make_argument_descriptors( argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]], args: tuple[Any], kwargs: dict[str, Any], -) -> ArgumentDescriptors: +) -> ArgStaticDescriptorsByType: """Given a set of runtime arguments construct all argument descriptors from them.""" func_type = program_type.definition params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) - descriptors: ArgumentDescriptors = {} + descriptors: ArgStaticDescriptorsByType = {} for descriptor_cls, exprs in argument_descriptor_mapping.items(): descriptors[descriptor_cls] = {} for expr in exprs: @@ -137,8 +201,8 @@ def _make_argument_descriptors( def _convert_to_argument_descriptor_context( - func_type: ts.FunctionType, argument_descriptors: ArgumentDescriptors -) -> ArgumentDescriptorContexts: + func_type: ts.FunctionType, argument_descriptors: ArgStaticDescriptorsByType +) -> arguments.ArgStaticDescriptorsContextsByType: """ Given argument descriptors, i.e., a mapping from an expr to a descriptor, transform them into a context of argument descriptors in which we can evaluate expressions. 
@@ -158,9 +222,9 @@ def _convert_to_argument_descriptor_context( >>> contexts[arguments.StaticArg] {'inp1': (None, StaticArg(value=1)), 'inp2': None} """ - descriptor_contexts: ArgumentDescriptorContexts = {} + descriptor_contexts: arguments.ArgStaticDescriptorsContextsByType = {} for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): - context: ArgumentDescriptorContext = _make_param_context_from_func_type( + context: arguments.ArgStaticDescriptorsContext = _make_param_context_from_func_type( func_type, lambda x: None ) # convert tuples to list such that we can alter the context easily @@ -190,14 +254,13 @@ def _convert_to_argument_descriptor_context( )(v) for k, v in context.items() } - descriptor_contexts[descriptor_cls] = context + descriptor_contexts[descriptor_cls] = context # type: ignore[index] # Hard to understand, it looks like a mypy bug return descriptor_contexts def _validate_argument_descriptors( - program_type: ts_ffront.ProgramType, - all_descriptors: ArgumentDescriptors, + program_type: ts_ffront.ProgramType, all_descriptors: ArgStaticDescriptorsByType ) -> None: for descriptors in all_descriptors.values(): for expr, descriptor in descriptors.items(): @@ -231,15 +294,19 @@ class CompiledProgramsPool(Generic[ffront_stages.DSLDefinitionT]): #: Note: The list is not ordered. 
argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] | None - # cache the compiled programs - compiled_programs: dict[ - CompiledProgramsKey, - stages.CompiledProgram | concurrent.futures.Future[stages.CompiledProgram], + # store for the compiled programs + compiled_programs: dict[CompiledProgramsKey, stages.CompiledProgram] = dataclasses.field( + default_factory=dict, init=False + ) + + # store for the async compilation jobs + _compilation_jobs: dict[ + CompiledProgramsKey, concurrent.futures.Future[stages.CompiledProgram] ] = dataclasses.field(default_factory=dict, init=False) @functools.cached_property - def _primitive_values_extractor(self) -> Callable | None: - return arguments.make_primitive_value_args_extractor(self.program_type.definition) + def root(self) -> tuple[str, str]: + return _make_pool_root(self.definition_stage, self.backend) def __post_init__(self) -> None: # TODO(havogt): We currently don't support pos_only or kw_only args at the program level. @@ -294,23 +361,12 @@ def __call__( ) try: - program = self.compiled_programs[key] - if metrics.is_level_enabled(metrics.MINIMAL): - metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) - - program(*args, **kwargs, offset_provider=offset_provider) # type: ignore[operator] # the Future case is handled below - - except TypeError as e: - if "program" in locals() and isinstance(program, concurrent.futures.Future): - # 'Future' objects are not callable so they will generate a TypeError. - # Here we resolve the future and call it again. 
- program = self._resolve_future(key) - program(*args, **kwargs, offset_provider=offset_provider) - else: - raise e + compiled_program = self.compiled_programs[key] except KeyError as e: - if enable_jit: + if self._finish_compilation_job(key): + compiled_program = self.compiled_programs[key] + elif enable_jit: assert self.argument_descriptor_mapping is not None self._compile_variant( argument_descriptors=_make_argument_descriptors( @@ -331,7 +387,18 @@ def __call__( enable_jit=False, **canonical_kwargs, ) # passing `enable_jit=False` because a cache miss should be a hard-error in this call` - raise RuntimeError("No program compiled for this set of static arguments.") from e + + else: + raise RuntimeError("No program compiled for this set of static arguments.") from e + + with compiled_program_call_context( + compiled_program, args, kwargs, offset_provider, self.root, key + ): + compiled_program(*args, **kwargs, offset_provider=offset_provider) + + @functools.cached_property + def _primitive_values_extractor(self) -> Callable | None: + return arguments.make_primitive_value_args_extractor(self.program_type.definition) @functools.cached_property def _is_generic(self) -> bool: @@ -359,12 +426,6 @@ def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]: self.program_type, name=self.definition_stage.definition.__name__ ) - @functools.cached_property - def _metrics_key_from_pool_key(self) -> Callable[[CompiledProgramsKey], str]: - prefix = f"{self.definition_stage.definition.__name__}<{self.backend.name}>" - - return lambda key: f"{prefix}[{hash(key)}]" - @functools.cached_property def _argument_descriptor_cache_key_from_args( self, @@ -388,7 +449,7 @@ def _argument_descriptor_cache_key_from_args( def _argument_descriptor_cache_key_from_descriptors( self, - argument_descriptor_contexts: ArgumentDescriptorContexts, + argument_descriptor_contexts: arguments.ArgStaticDescriptorsContextsByType, ) -> tuple: """ Given a set of argument descriptors deduce 
the cache key used to retrieve the instance @@ -412,7 +473,7 @@ def _argument_descriptor_cache_key_from_descriptors( return tuple(elements) def _initialize_argument_descriptor_mapping( - self, argument_descriptors: ArgumentDescriptors + self, argument_descriptors: ArgStaticDescriptorsByType ) -> None: if self.argument_descriptor_mapping is None: self.argument_descriptor_mapping = { @@ -449,9 +510,22 @@ def _validate_argument_descriptor_mapping(self) -> None: location=None, ) + def _is_existing_key(self, key: CompiledProgramsKey) -> bool: + return key in self.compiled_programs or key in self._compilation_jobs + + def _finish_compilation_job(self, key: CompiledProgramsKey) -> bool: + if key not in self._compilation_jobs: + return False + + compiled_program_future = self._compilation_jobs.pop(key) + assert isinstance(compiled_program_future, concurrent.futures.Future) + assert key not in self.compiled_programs + self.compiled_programs[key] = compiled_program_future.result() + return True + def _compile_variant( self, - argument_descriptors: ArgumentDescriptors, + argument_descriptors: ArgStaticDescriptorsByType, offset_provider: common.OffsetProviderType | common.OffsetProvider, #: tuple consisting of the types of the positional and keyword arguments. 
arg_specialization_info: tuple[tuple[ts.TypeSpec, ...], dict[str, ts.TypeSpec]] @@ -481,22 +555,9 @@ def _compile_variant( ) assert call_key is None or call_key == key - if key in self.compiled_programs: + if self._is_existing_key(key): raise ValueError(f"Program with key {key} already exists.") - # If we are collecting metrics, create a new metrics entity for this compiled program - if metrics.is_level_enabled(metrics.MINIMAL): - metrics_source = metrics.set_current_source_key(self._metrics_key_from_pool_key(key)) - metrics_source.metadata |= dict( - name=self.definition_stage.definition.__name__, - backend=self.backend.name, - compiled_program_pool_key=hash(key), - **{ - f"{eve_utils.CaseStyleConverter.convert(key.__name__, 'pascal', 'snake')}s": value - for key, value in argument_descriptors.items() - }, - ) - if arg_specialization_info: arg_types, kwarg_types = arg_specialization_info else: @@ -520,11 +581,18 @@ def _compile_variant( compile_call = functools.partial( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args ) + compile_variant_hook( + self.definition_stage, + self.backend, + offset_provider=offset_provider, + argument_descriptors=argument_descriptors, + key=key, + ) + if _async_compilation_pool is None: - # synchronous compilation self.compiled_programs[key] = compile_call() else: - self.compiled_programs[key] = _async_compilation_pool.submit(compile_call) + self._compilation_jobs[key] = _async_compilation_pool.submit(compile_call) # TODO(tehrengruber): Rework the interface to allow precompilation with compile time # domains and of scans. 
@@ -559,10 +627,3 @@ def compile( }, offset_provider=offset_provider, ) - - def _resolve_future(self, key: CompiledProgramsKey) -> stages.CompiledProgram: - program = self.compiled_programs[key] - assert isinstance(program, concurrent.futures.Future) - result = program.result() - self.compiled_programs[key] = result - return result diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 6fd710624d..4d2e4866a4 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,7 +9,8 @@ from __future__ import annotations import dataclasses -from typing import Any, Generic, Optional, Protocol, TypeVar +from collections.abc import Callable +from typing import Generic, Optional, Protocol, TypeAlias, TypeVar from gt4py.eve import utils from gt4py.next import common @@ -134,10 +135,7 @@ class BuildSystemProject(Protocol[SrcL_co, SettingT_co, TgtL_co]): def build(self) -> None: ... -class CompiledProgram(Protocol): - """Executable python representation of a program.""" - - def __call__(self, *args: Any, **kwargs: Any) -> None: ... 
+CompiledProgram: TypeAlias = Callable def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 6ebaab04b0..7686ae097c 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -24,7 +24,7 @@ from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon -class CompiledDaceProgram(stages.CompiledProgram): +class CompiledDaceProgram: sdfg_program: dace.CompiledSDFG # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 2e8a90be9c..df8faf405a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -137,7 +137,7 @@ def fencil_generator( if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") - return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) + return _FENCIL_CACHE[cache_key] # A CompiledProgram is just a Callable ir = transforms(ir, offset_provider=offset_provider) diff --git a/src/gt4py/next/typing.py b/src/gt4py/next/typing.py index 696b32e477..4ada4b4e23 100644 --- a/src/gt4py/next/typing.py +++ b/src/gt4py/next/typing.py @@ -11,16 +11,23 @@ from gt4py._core.definitions import Scalar from gt4py.next import allocators, backend from gt4py.next.common import OffsetProvider -from gt4py.next.ffront import decorator +from gt4py.next.ffront import decorator, stages as ffront_stages +from gt4py.next.otf import compiled_program _ONLY_FOR_TYPING: Final[str] = "only for typing" # TODO(havogt): alternatively we could introduce Protocols -GTEntryPoint: 
TypeAlias = Annotated[decorator.GTEntryPoint, _ONLY_FOR_TYPING] +DSLDefinition: TypeAlias = Annotated[ffront_stages.DSLDefinition, _ONLY_FOR_TYPING] + Program: TypeAlias = Annotated[decorator.Program, _ONLY_FOR_TYPING] FieldOperator: TypeAlias = Annotated[decorator.FieldOperator, _ONLY_FOR_TYPING] +GTEntryPoint: TypeAlias = Annotated[decorator.GTEntryPoint, _ONLY_FOR_TYPING] + +CompiledProgramsKey: TypeAlias = Annotated[compiled_program.CompiledProgramsKey, _ONLY_FOR_TYPING] + Backend: TypeAlias = Annotated[backend.Backend, _ONLY_FOR_TYPING] + FieldBufferAllocationUtil: TypeAlias = Annotated[ allocators.FieldBufferAllocationUtil, _ONLY_FOR_TYPING ] diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py new file mode 100644 index 0000000000..cf7abc3464 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/instrumentation_tests/test_hooks.py @@ -0,0 +1,247 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +from collections.abc import Callable +from typing import Any + +import pytest + +import gt4py.next as gtx +from gt4py.next import common, Dims, gtfn_cpu, typing as gtx_typing +from gt4py.next.instrumentation import hooks + +try: + from gt4py.next.program_processors.runners import dace as dace_backends + + BACKENDS = [None, gtfn_cpu, dace_backends.run_dace_cpu_cached] +except ImportError: + BACKENDS = [None, gtfn_cpu] + + +IDim = gtx.Dimension("IDim") + + +@gtx.field_operator +def fop(cond: bool, a: gtx.Field[gtx.Dims[IDim], float], b: gtx.Field[gtx.Dims[IDim], float]): + return a if cond else b + + +@gtx.program +def prog( + cond: bool, + a: gtx.Field[gtx.Dims[IDim], gtx.float64], + b: gtx.Field[gtx.Dims[IDim], gtx.float64], + out: gtx.Field[gtx.Dims[IDim], gtx.float64], +): + fop(cond, a, b, out=out) + + +callback_results = [] +embedded_callback_results = [] +compiled_callback_results = [] + + +@contextlib.contextmanager +def custom_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + enable_jit: bool, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + callback_results.append(("enter-program", None)) + + yield + + callback_results.append( + ( + "custom_program_callback", + { + "program": program.__name__, + "args": args, + "offset_provider": offset_provider.keys(), + "enable_jit": enable_jit, + "kwargs": kwargs.keys(), + }, + ) + ) + + +@contextlib.contextmanager +def custom_embedded_program_callback( + program: gtx_typing.Program, + args: tuple[Any, ...], + offset_provider: common.OffsetProvider, + kwargs: dict[str, Any], +) -> contextlib.AbstractContextManager: + embedded_callback_results.append(("enter-embedded-program", None)) + + yield + + embedded_callback_results.append( + ( + "custom_embedded_program_callback", + { + "program": program.__name__, + "args": args, + 
"offset_provider": offset_provider.keys(), + "kwargs": kwargs.keys(), + }, + ) + ) + + +@contextlib.contextmanager +def custom_compiled_program_callback( + compiled_program: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], + offset_provider: common.OffsetProvider, + root: tuple[str, str], + key: gtx_typing.CompiledProgramsKey, +) -> contextlib.AbstractContextManager: + compiled_callback_results.append(("enter-compiled-program", None)) + + yield + + compiled_callback_results.append( + ( + "custom_compiled_program_callback", + { + "program": compiled_program, + "args": args, + "kwargs": kwargs, + "offset_provider": offset_provider.keys(), + "root": root, + "key": key, + }, + ) + ) + + +@pytest.mark.parametrize("backend", BACKENDS, ids=lambda b: getattr(b, "name", str(b))) +def test_program_call_hooks(backend: gtx_typing.Backend): + size = 10 + a_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) + b_field = gtx.full([(IDim, size)], 1, dtype=gtx.float64) + out_field = gtx.empty([(IDim, size)], dtype=gtx.float64) + + test_program = prog.with_backend(backend) + + # Run the program without hooks + callback_results.clear() + embedded_callback_results.clear() + test_program(True, a_field, b_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() + assert compiled_callback_results == [] + compiled_callback_results.clear() + + # Add hooks and run the program again + hooks.program_call_context.register(custom_program_callback) + hooks.embedded_program_call_context.register(custom_embedded_program_callback) + hooks.compiled_program_call_context.register(custom_compiled_program_callback) + test_program(True, a_field, b_field, out=out_field) + + # Check that the callbacks were called + assert len(callback_results) == 2 + assert callback_results[0] == ("enter-program", None) + + hook_name, hook_call_info = 
callback_results[1] + assert hook_name == "custom_program_callback" + assert hook_call_info["program"] == test_program.__name__ + + if backend is None: + # The embedded program call hook should have also been called + # with the embedded backend + assert len(embedded_callback_results) == 2 + assert embedded_callback_results[0] == ("enter-embedded-program", None) + + hook_name, hook_call_info = embedded_callback_results[1] + assert hook_name == "custom_embedded_program_callback" + assert hook_call_info["program"] == prog.__name__ + + assert len(compiled_callback_results) == 0 + + else: + # The compiled program call hook should have also been called + # with the compiled backends + assert len(compiled_callback_results) == 2 + assert compiled_callback_results[0] == ("enter-compiled-program", None) + + hook_name, hook_call_info = compiled_callback_results[1] + assert hook_name == "custom_compiled_program_callback" + assert ( + hook_call_info["program"] + == test_program._compiled_programs.compiled_programs[hook_call_info["key"]] + ) + + assert len(embedded_callback_results) == 0 + + callback_results.clear() + embedded_callback_results.clear() + compiled_callback_results.clear() + + # Remove hooks and call the program again + hooks.program_call_context.remove(custom_program_callback) + hooks.embedded_program_call_context.remove(custom_embedded_program_callback) + hooks.compiled_program_call_context.remove(custom_compiled_program_callback) + test_program(True, a_field, b_field, out=out_field) + + # Callbacks should not have been called + assert callback_results == [] + callback_results.clear() + assert embedded_callback_results == [] + embedded_callback_results.clear() + assert compiled_callback_results == [] + compiled_callback_results.clear() + + +@pytest.mark.parametrize( + "backend", [b for b in BACKENDS if b is not None], ids=lambda b: getattr(b, "name", str(b)) +) +def test_compile_variant_hook(backend: gtx_typing.Backend): + def custom_compile_variant_hook( + 
program_definition: gtx_typing.DSLDefinition, + backend: gtx_typing.Backend, + offset_provider: common.OffsetProviderType | common.OffsetProvider, + argument_descriptors: dict[type, dict[str, Any]], + key: gtx_typing.CompiledProgramsKey, + ) -> None: + callback_results.append( + ( + "custom_compile_variant_hook", + { + "program_definition": program_definition, + "backend": backend.name, + "argument_descriptors": { + k.__name__: [*v.keys()] for k, v in argument_descriptors.items() + }, + "key": key, + }, + ) + ) + + callback_results.clear() + hooks.compile_variant_hook.register(custom_compile_variant_hook) + testee = prog.with_backend(backend).compile(cond=[True], offset_provider={}) + hooks.compile_variant_hook.remove(custom_compile_variant_hook) + + assert len(callback_results) == 1, f"{callback_results=}" + hook_name, hook_call_info = callback_results[0] + assert hook_name == "custom_compile_variant_hook" + assert hook_call_info["program_definition"] == prog.definition_stage + assert hook_call_info["backend"] == backend.name + assert hook_call_info["argument_descriptors"] == {"StaticArg": ["cond"]} diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py new file mode 100644 index 0000000000..73a7f2f646 --- /dev/null +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_hook_machinery.py @@ -0,0 +1,314 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import contextlib +import dataclasses + +import pytest + +from gt4py.next.instrumentation.hook_machinery import ( + EventHook, + ContextHook, + _get_unique_name, + _is_empty_function, +) + + +def test_get_unique_name(): + def func1(): + pass + + def func2(): + pass + + assert _get_unique_name(func1) != _get_unique_name(func2) + + class A: + def __call__(self): ... + + assert _get_unique_name(A) == _get_unique_name(A) + + a1, a2 = A(), A() + + assert (a1_name := _get_unique_name(a1)) != (a2_name := _get_unique_name(a2)) + assert _get_unique_name(a1) == a1_name + assert _get_unique_name(a2) == a2_name + + +def test_empty_function(): + def empty(): + pass + + assert _is_empty_function(empty) is True + + def non_empty(): + return 1 + + assert _is_empty_function(non_empty) is False + + def with_docstring(): + """This is a docstring.""" + + assert _is_empty_function(with_docstring) is True + + def with_ellipsis(): ... + + assert _is_empty_function(with_ellipsis) is True + + class A: + def __call__(self): ... 
+ + assert _is_empty_function(A()) is True + + +class TestEventHook: + def test_event_hook_call_with_no_callbacks(self): + @EventHook + def hook(x: int) -> None: + pass + + hook(42) # Should not raise + + def test_event_hook_call_with_callbacks(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(x) + + def callback2(x: int) -> None: + results.append(x * 2) + + hook.register(callback1) + hook.register(callback2) + hook(5) + + assert results == [5, 10] + + def test_event_hook_register_with_signature_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def bad_callback(x: int, y: int) -> None: + pass + + with pytest.raises(ValueError, match="Callback signature"): + hook.register(bad_callback) + + def test_event_hook_register_with_annotation_mismatch(self): + @EventHook + def hook(x: int) -> None: + pass + + def weird_callback(x: str) -> None: + pass + + with pytest.warns(UserWarning, match="Callback annotations"): + hook.register(weird_callback) + + def test_event_hook_register_with_name(self): + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + pass + + hook.register(callback, name="my_callback") + + assert "my_callback" in hook.registry + + def test_event_hook_register_with_index(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback1(x: int) -> None: + results.append(1) + + def callback2(x: int) -> None: + results.append(2) + + hook.register(callback1) + hook.register(callback2, index=0) + hook(0) + + assert results == [2, 1] + + def test_event_hook_remove_by_name(self): + results = [] + + @EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback, name="test_cb") + hook(42) + assert results == [42] + + hook.remove("test_cb") + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_by_callback(self): + results = [] + + 
@EventHook + def hook(x: int) -> None: + pass + + def callback(x: int) -> None: + results.append(x) + + hook.register(callback) + hook(42) + assert results == [42] + + hook.remove(callback) + results = [] + hook(42) + + assert results == [] + + def test_event_hook_remove_nonexistent_raises(self): + @EventHook + def hook(x: int) -> None: + pass + + with pytest.raises(KeyError): + hook.remove("nonexistent") + + +class TestContextHook: + def test_context_hook_basic(self): + enter_called = [] + exit_called = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(): + enter_called.append(True) + yield + exit_called.append(True) + + hook.register(callback) + + with hook(): + assert len(enter_called) == 1 + + assert len(exit_called) == 1 + + def test_context_hook_multiple_callbacks(self): + order = [] + + @ContextHook + def hook() -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback1(): + order.append("enter1") + yield + order.append("exit1") + + @contextlib.contextmanager + def callback2(): + order.append("enter2") + yield + order.append("exit2") + + hook.register(callback1) + hook.register(callback2) + + with hook(): + pass + + # Entry in order, but exit in reverse + assert order == ["enter1", "enter2", "exit2", "exit1"] + + def test_context_hook_with_arguments(self): + results = [] + + @ContextHook + def hook(x: int) -> contextlib.AbstractContextManager: + pass + + @contextlib.contextmanager + def callback(x: int): + results.append(x) + yield + + hook.register(callback) + + with hook(42): + pass + + assert results == [42] + + +def test_context_hook_callback_partial(): + exit_called = [] + + @dataclasses.dataclass(slots=True) + class MyContextCallback: + def __enter__(self): + pass + + def __exit__(self, type_, exc_value, traceback): + exit_called.append(True) + + assert dataclasses.is_dataclass(MyContextCallback) is True + assert 
len(MyContextCallback.__dataclass_fields__.keys()) == 0 + + with MyContextCallback(): + assert len(exit_called) == 0 + + assert len(exit_called) == 1 + assert exit_called[0] == True + + +def test_context_hook_callback_full(): + enter_called = [] + exit_called = [] + + @dataclasses.dataclass(slots=True) + class MyContextCallback: + enter_value: int + exit_value: int + + def __enter__(self): + enter_called.append(self.enter_value) + + def __exit__(self, type_, exc_value, traceback): + exit_called.append(self.exit_value) + + assert dataclasses.is_dataclass(MyContextCallback) is True + assert MyContextCallback.__dataclass_fields__.keys() == {"enter_value", "exit_value"} + + with MyContextCallback(42, 43): + assert len(enter_called) == 1 + assert enter_called[0] == 42 + + assert len(exit_called) == 1 + assert exit_called[0] == 43 diff --git a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py index fdb610ba38..7aee5580f7 100644 --- a/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py +++ b/tests/next_tests/unit_tests/instrumentation_tests/test_metrics.py @@ -26,11 +26,9 @@ def test_set_current_source_key_basic(self): metrics._source_key_cvar.set(metrics._NO_KEY_SET_MARKER_) key = "test_source" - source = metrics.set_current_source_key(key) + metrics.set_current_source_key(key) assert metrics.get_current_source_key() == key - assert metrics.sources[key] == source - assert isinstance(source, metrics.Source) def test_set_current_source_key_same_key_twice(self): """Test setting the same source key twice."""