Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions src/gt4py/next/factory_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import functools
from typing import Any, Callable, overload

import factory


class Factory(factory.Factory):
"""`
Factory that defaults all ``factory.Trait`` params to ``False``.

Ensures that every ``factory.Trait`` declared in ``Params`` is present in
the keyword arguments passed to ``create``, defaulting to ``False`` when
not explicitly provided. This allows trait-dependent declarations to
always access the trait flag, even when the caller does not mention it.
"""

@classmethod
def create(cls, **kwargs: Any) -> Any:
# adjust keyword arguments so that traits options are available even when not given
# explicitly
for name, param in cls._meta.parameters.items():
if isinstance(param, factory.Trait):
kwargs.setdefault(name, False)
return super().create(**kwargs)


class DynamicTransformer(factory.declarations.BaseDeclaration):
CAPTURE_OVERRIDES = True
UNROLL_CONTEXT_BEFORE_EVALUATION = False

def __init__(self, default: Any, *, transform: Callable[[Any, Any], Any]) -> None:
super().__init__()
self.default = default
self.transform = transform

def evaluate_pre(self, instance: Any, step: Any, overrides: dict[str, Any]) -> Any:
# The call-time value, if present, is set under the "" key.
value_or_declaration = overrides.pop("", self.default)

if isinstance(value_or_declaration, factory.Transformer.Force):
bypass_transform = True
value_or_declaration = value_or_declaration.forced_value
else:
bypass_transform = False

value = self._unwrap_evaluate_pre(
value_or_declaration,
instance=instance,
step=step,
overrides=overrides,
)
if bypass_transform:
return value

transform = self._unwrap_evaluate_pre(
self.transform,
instance=instance,
step=step,
overrides=overrides,
)

return transform(instance, value)


@overload
def dynamic_transformer(func: Callable[[Any, Any], Any], *, default: Any) -> DynamicTransformer: ...


@overload
def dynamic_transformer(
*, default: Any
) -> Callable[[Callable[[Any, Any], Any]], DynamicTransformer]: ...


def dynamic_transformer(
func: Callable[[Any, Any], Any] | None = None, *, default: Any
) -> DynamicTransformer | Callable[[Callable[[Any, Any], Any]], DynamicTransformer]:
"""
Decorator that creates a factory field whose value is always passed through a transform.

Works like ``factory.Transformer`` but the transform function receives the
full factory instance, so it can read other parameters/traits. The
*default* argument provides the base value (may be a factory declaration
such as ``factory.SelfAttribute``). Use ``factory.Transformer.Force`` to
bypass the transform at call-time.

Example:
>>> import dataclasses, factory
>>> @dataclasses.dataclass
... class Person:
... name: str
... nickname: str
>>> class PersonFactory(factory.Factory):
... class Meta:
... model = Person
...
... name = "Joe"
...
... @dynamic_transformer(default=factory.SelfAttribute(".name"))
... def nickname(self, nickname):
... return f"{nickname}y"
>>> PersonFactory().nickname
'Joey'
>>> PersonFactory(name="John").nickname
'Johny'
>>> PersonFactory(name=factory.Transformer.Force("John")).name
'John'
"""
if func is None:
return functools.partial(dynamic_transformer, default=default)
return DynamicTransformer(default=default, transform=func)
164 changes: 73 additions & 91 deletions src/gt4py/next/program_processors/runners/dace/workflow/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,75 +9,22 @@
from __future__ import annotations

import warnings
from typing import Any, Final
from typing import Callable, Final

import factory

import gt4py.next.allocators as next_allocators
from gt4py._core import definitions as core_defs
from gt4py.next import backend, common, config
from gt4py.next import backend, common, config, factory_utils
from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory


class DaCeBackendFactory(factory.Factory):
class DaCeBackendFactory(factory_utils.Factory):
"""
Workflow factory for the GTIR-DaCe backend.
Configurable factory for DaCe backend.

Several parameters are inherithed from `backend.Backend`, see below the specific ones.

Args:
auto_optimize: Enables the SDFG transformation pipeline.
"""

class Meta:
model = backend.Backend

class Params:
name_device = "cpu"
name_cached = ""
name_postfix = ""
gpu = factory.Trait(
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA,
name_device="gpu",
)
cached = factory.Trait(
executor=factory.LazyAttribute(
lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function)
),
name_cached="_cached",
)
device_type = core_defs.DeviceType.CPU
hash_function = stages.compilation_hash
otf_workflow = factory.SubFactory(
DaCeWorkflowFactory,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
)
auto_optimize = factory.Trait(name_postfix="_opt")

name = factory.LazyAttribute(
lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}"
)

executor = factory.LazyAttribute(lambda o: o.otf_workflow)
allocator = next_allocators.StandardCPUFieldBufferAllocator()
transforms = backend.DEFAULT_TRANSFORMS


def make_dace_backend(
gpu: bool,
cached: bool = True,
auto_optimize: bool = True,
async_sdfg_call: bool = True,
optimization_args: dict[str, Any] | None = None,
use_metrics: bool = True,
use_zero_origin: bool = False,
) -> backend.Backend:
"""Customize the dace backend with the given configuration parameters.

Args:
Parameters:
gpu: Enable GPU transformations and code generation.
cached: Cache the lowered SDFG as a JSON file and the compiled programs.
auto_optimize: Enable the SDFG auto-optimize pipeline.
Expand All @@ -99,72 +46,107 @@ def make_dace_backend(
A dace backend with custom configuration for the target device.
"""

# The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()`
# that are derived from the gt4py configuration, and therefore cannot be customized.
gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"}

if optimization_args is None:
optimization_args = {}
elif optimization_args and not auto_optimize:
warnings.warn("Optimizations args given, but auto-optimize is disabled.", stacklevel=2)
elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()):
raise ValueError(
f"The following optimization arguments cannot be overriden: {intersect_args}."
class Meta:
model = backend.Backend

class Params:
gpu = factory.Trait(
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA,
)
cached: bool = True
auto_optimize: bool = True
async_sdfg_call: bool = factory.SelfAttribute(".gpu") # type: ignore[assignment]
use_metrics: bool = True
use_zero_origin: bool = False
device_type = core_defs.DeviceType.CPU
hash_function = stages.compilation_hash

# Set `unit_strides_kind` based on the gt4py env configuration.
optimization_args = optimization_args | {
"unit_strides_kind": common.DimensionKind.HORIZONTAL
if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE
else None
}

return DaCeBackendFactory( # type: ignore[return-value] # factory-boy typing not precise enough
gpu=gpu,
cached=cached,
auto_optimize=auto_optimize,
otf_workflow__cached_translation=cached,
otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False),
otf_workflow__bare_translation__auto_optimize_args=optimization_args,
otf_workflow__bare_translation__use_metrics=use_metrics,
otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin,
@factory_utils.dynamic_transformer(default=factory.Dict({}))
def optimization_args(self, optimization_args: dict) -> dict:
# The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()`
# that are derived from the gt4py configuration, and therefore cannot be customized.
gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"}

if optimization_args and not self.auto_optimize:
warnings.warn(
"Optimizations args given, but auto-optimize is disabled.", stacklevel=2
)
elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()):
raise ValueError(
f"The following optimization arguments cannot be overriden: {intersect_args}."
)

# Set `unit_strides_kind` based on the gt4py env configuration.
return optimization_args | {
"unit_strides_kind": common.DimensionKind.HORIZONTAL
if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE
else None
}

@factory.lazy_attribute
def name(self: factory.builder.Resolver) -> str:
name = "run_dace_"
name += "gpu" if self.gpu else "cpu"
if self.auto_optimize:
name += "_opt"
if self.cached:
name += "_cached"
return name

@factory_utils.dynamic_transformer(
default=factory.SubFactory(
DaCeWorkflowFactory,
device_type=factory.SelfAttribute("..device_type"),
auto_optimize=factory.SelfAttribute("..auto_optimize"),
translation__async_sdfg_call=factory.SelfAttribute("...async_sdfg_call"),
translation__auto_optimize_args=factory.SelfAttribute("...optimization_args"),
translation__use_metrics=factory.SelfAttribute("...use_metrics"),
translation__disable_field_origin_on_program_arguments=factory.SelfAttribute(
"...use_zero_origin"
),
)
)
def executor(self: factory.builder.Resolver, value: Callable) -> Callable:
if self.cached:
return workflow.CachedStep(value, hash_function=self.hash_function)
return value

allocator = next_allocators.StandardCPUFieldBufferAllocator()
transforms = backend.DEFAULT_TRANSFORMS


make_dace_backend = DaCeBackendFactory


run_dace_cpu = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=True,
async_sdfg_call=False,
)
run_dace_cpu_noopt = make_dace_backend(
gpu=False,
cached=False,
auto_optimize=False,
async_sdfg_call=False,
)
run_dace_cpu_cached = make_dace_backend(
gpu=False,
cached=True,
auto_optimize=True,
async_sdfg_call=False,
)

run_dace_gpu = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=True,
async_sdfg_call=True,
)
run_dace_gpu_noopt = make_dace_backend(
gpu=True,
cached=False,
auto_optimize=False,
async_sdfg_call=True,
)
run_dace_gpu_cached = make_dace_backend(
gpu=True,
cached=True,
auto_optimize=True,
async_sdfg_call=True,
)
Loading