diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 90fbdfc505..785cc0b24d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -4,7 +4,7 @@ import typing from gt4py import next as gtx from gt4py.next.otf import toolchain, workflow -from gt4py.next.ffront import stages as ff_stages +from gt4py.next.ffront import field_operator_ast as foast, stages as ff_stages from gt4py import eve ``` @@ -22,8 +22,8 @@ cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( ## Skip Steps / Change Order ```python -DUMMY_FOP = toolchain.CompilableProgram( - data=ff_stages.FieldOperatorDefinition(definition=None), args=None +DUMMY_FOP = toolchain.ConcreteArtifact( + data=ff_stages.DSLFieldOperatorDef(definition=None), args=None ) ``` @@ -57,9 +57,9 @@ class Cpp2BindingsGen: ... class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): translation: workflow.Workflow[ - gtx.otf.stages.CompilableProgram, gtx.otf.stages.ProgramSource + gtx.otf.definitions.CompilableProgramDef, gtx.otf.stages.ProgramSource ] = MyCodeGen() - bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = ( + bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableProject] = ( Cpp2BindingsGen() ) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 33fe77c4a3..c492db0b34 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import typing from typing import Generic from gt4py._core import definitions as core_defs @@ -21,30 +20,34 @@ func_to_past, past_process_args, past_to_itir, + stages as ffront_stages, ) from gt4py.next.ffront.past_passes import linters as past_linters -from gt4py.next.ffront.stages import ( - AOT_DSL_FOP, - AOT_DSL_PRG, - AOT_FOP, - AOT_PRG, - DSL_FOP, - DSL_PRG, - FOP, - PAST_PRG, -) from gt4py.next.iterator import ir as itir -from gt4py.next.otf import arguments, stages, toolchain, workflow +from gt4py.next.otf import arguments, definitions, stages, toolchain, workflow + +def jit_to_aot_args( + inp: arguments.JITArgs, +) -> arguments.CompileTimeArgs: + return arguments.CompileTimeArgs.from_concrete(*inp.args, **inp.kwargs) -IRDefinitionForm: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PAST_PRG | itir.Program -CompilableDefinition: typing.TypeAlias = toolchain.CompilableProgram[ - IRDefinitionForm, arguments.JITArgs | arguments.CompileTimeArgs -] + +def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ + definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.JITArgs], + definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.CompileTimeArgs], +]: + """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" + return toolchain.ArgsOnlyAdapter(jit_to_aot_args) @dataclasses.dataclass(frozen=True) -class Transforms(workflow.MultiWorkflow[CompilableDefinition, stages.CompilableProgram]): +class Transforms( + workflow.MultiWorkflow[ + definitions.ConcreteProgramDef[definitions.IRDefinitionT, definitions.ArgsDefinitionT], + definitions.CompilableProgramDef, + ] +): """ Modular workflow for transformations with access to intermediates. 
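Reviewer note, not part of the patch: for orientation, a hedged sketch of how the renamed `ConcreteArtifact` container used above is meant to flow through a backend. It only uses names that appear in this diff; `my_definition` and `my_backend` are hypothetical placeholders.

```python
# Hedged sketch of the intended data flow, not patch content.
from gt4py.next.otf import arguments, toolchain

# A ConcreteArtifact pairs an IR-level definition with its arguments.
aot_input = toolchain.ConcreteArtifact(
    data=my_definition,  # hypothetical: a DSL/FOAST/PAST definition or itir.Program
    args=arguments.CompileTimeArgs.empty(),
)

# Backend.transforms produces a CompilableProgramDef and Backend.executor compiles it,
# mirroring Backend.compile further down in this file.
compilable = my_backend.transforms(aot_input)
compiled = my_backend.executor(compilable)
```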
@@ -60,44 +63,44 @@ class Transforms(workflow.MultiWorkflow[CompilableDefinition, stages.CompilableP """ aotify_args: workflow.Workflow[ - toolchain.CompilableProgram[IRDefinitionForm, arguments.JITArgs], - toolchain.CompilableProgram[IRDefinitionForm, arguments.CompileTimeArgs], - ] = dataclasses.field(default_factory=arguments.adapted_jit_to_aot_args_factory) + definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.JITArgs], + definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.CompileTimeArgs], + ] = dataclasses.field(default_factory=adapted_jit_to_aot_args_factory) - func_to_foast: workflow.Workflow[AOT_DSL_FOP, AOT_FOP] = dataclasses.field( - default_factory=func_to_foast.adapted_func_to_foast_factory - ) + func_to_foast: workflow.Workflow[ + ffront_stages.ConcreteDSLFieldOperatorDef, ffront_stages.ConcreteFOASTOperatorDef + ] = dataclasses.field(default_factory=func_to_foast.adapted_func_to_foast_factory) - func_to_past: workflow.Workflow[AOT_DSL_PRG, AOT_PRG] = dataclasses.field( - default_factory=func_to_past.adapted_func_to_past_factory - ) + func_to_past: workflow.Workflow[ + ffront_stages.ConcreteDSLProgramDef, ffront_stages.ConcretePASTProgramDef + ] = dataclasses.field(default_factory=func_to_past.adapted_func_to_past_factory) - foast_to_itir: workflow.Workflow[AOT_FOP, itir.FunctionDefinition] = dataclasses.field( - default_factory=foast_to_gtir.adapted_foast_to_gtir_factory - ) + foast_to_itir: workflow.Workflow[ + ffront_stages.ConcreteFOASTOperatorDef, itir.FunctionDefinition + ] = dataclasses.field(default_factory=foast_to_gtir.adapted_foast_to_gtir_factory) - field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( - default_factory=foast_to_past.operator_to_program_factory - ) + field_view_op_to_prog: workflow.Workflow[ + ffront_stages.ConcreteFOASTOperatorDef, ffront_stages.ConcretePASTProgramDef + ] = dataclasses.field(default_factory=foast_to_past.operator_to_program_factory) - past_lint: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field( - default_factory=past_linters.adapted_linter_factory - ) + past_lint: workflow.Workflow[ + ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef + ] = dataclasses.field(default_factory=past_linters.adapted_linter_factory) - field_view_prog_args_transform: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field( - default_factory=past_process_args.transform_program_args_factory - ) + field_view_prog_args_transform: workflow.Workflow[ + ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef + ] = dataclasses.field(default_factory=past_process_args.transform_program_args_factory) - past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_gtir_factory - ) + past_to_itir: workflow.Workflow[ + ffront_stages.ConcretePASTProgramDef, definitions.CompilableProgramDef + ] = dataclasses.field(default_factory=past_to_itir.past_to_gtir_factory) - def step_order(self, inp: CompilableDefinition) -> list[str]: + def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: steps: list[str] = [] if isinstance(inp.args, arguments.JITArgs): steps.append("aotify_args") match inp.data: - case DSL_FOP(): + case ffront_stages.DSLFieldOperatorDef(): steps.extend( [ "func_to_foast", @@ -107,7 +110,7 @@ def step_order(self, inp: CompilableDefinition) -> list[str]: "past_to_itir", ] ) - case FOP(): + case ffront_stages.FOASTOperatorDef(): steps.extend( [ "field_view_op_to_prog", @@ 
-116,11 +119,16 @@ def step_order(self, inp: CompilableDefinition) -> list[str]: "past_to_itir", ] ) - case DSL_PRG(): + case ffront_stages.DSLProgramDef(): steps.extend( - ["func_to_past", "past_lint", "field_view_prog_args_transform", "past_to_itir"] + [ + "func_to_past", + "past_lint", + "field_view_prog_args_transform", + "past_to_itir", + ] ) - case PAST_PRG(): + case ffront_stages.PASTProgramDef(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) case itir.Program(): pass @@ -139,15 +147,15 @@ def step_order(self, inp: CompilableDefinition) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompiledProgram] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transforms: workflow.Workflow[CompilableDefinition, stages.CompilableProgram] + transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] def compile( - self, program: IRDefinitionForm, compile_time_args: arguments.CompileTimeArgs + self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs ) -> stages.CompiledProgram: return self.executor( - self.transforms(toolchain.CompilableProgram(data=program, args=compile_time_args)) + self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) @property diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a857b4e700..aff0867e9a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -19,7 +19,7 @@ import typing import warnings from collections.abc import Callable -from typing import Any, Generic, Optional, Sequence, TypeVar +from typing import Any, Generic, Optional, Sequence, TypeAlias from gt4py import eve from gt4py._core import definitions as core_defs @@ -52,20 +52,16 @@ DEFAULT_BACKEND: next_backend.Backend | None = None -ProgramLikeDefinitionT = TypeVar( - "ProgramLikeDefinitionT", ffront_stages.ProgramDefinition, ffront_stages.FieldOperatorDefinition -) - @dataclasses.dataclass(frozen=True) -class _ProgramLikeMixin(Generic[ProgramLikeDefinitionT]): +class _CompilableGTEntryPointMixin(Generic[ffront_stages.DSLDefinitionT]): """ Mixing used by program and program-like objects. Contains functionality and configuration options common to all kinds of program-likes. """ - definition_stage: ProgramLikeDefinitionT + definition_stage: ffront_stages.DSLDefinitionT backend: Optional[next_backend.Backend] compilation_options: options.CompilationOptions @@ -174,7 +170,7 @@ def compile( # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) -class Program(_ProgramLikeMixin[ffront_stages.ProgramDefinition]): +class Program(_CompilableGTEntryPointMixin[ffront_stages.DSLProgramDef]): """ Construct a program object from a PAST node. 
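Reviewer note, not part of the patch: a hedged sketch of where the renamed stages surface on a `Program`, mirroring the doctests updated elsewhere in this diff; `IDim` and `copy` are the usual doctest placeholders.

```python
# Hedged sketch, not patch content: the renamed stages as seen from user code.
import gt4py.next as gtx

IDim = gtx.Dimension("IDim")


@gtx.field_operator
def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]:
    return a


@gtx.program
def copy_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]):
    copy(a, out=out)


# After this patch the DSL definition stage is a DSLProgramDef ...
assert isinstance(copy_program.definition_stage, gtx.ffront.stages.DSLProgramDef)
# ... and the derived stage is a PASTProgramDef (formerly PastProgramDefinition).
assert isinstance(copy_program.past_stage, gtx.ffront.stages.PASTProgramDef)
```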
@@ -202,7 +198,7 @@ def from_function( grid_type: common.GridType | None = None, **compilation_options: Unpack[options.CompilationOptionsArgs], ) -> Program: - program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) + program_def = ffront_stages.DSLProgramDef(definition=definition, grid_type=grid_type) return cls( definition_stage=program_def, backend=backend, @@ -215,7 +211,7 @@ def __gt_type__(self) -> ts_ffront.ProgramType: # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self) -> None: - no_args_past = toolchain.CompilableProgram( + no_args_past = toolchain.ConcreteArtifact( self.past_stage, arguments.CompileTimeArgs.empty() ) _ = self._frontend_transforms.past_lint(no_args_past).data @@ -238,9 +234,9 @@ def definition(self) -> types.FunctionType: return self.definition_stage.definition @functools.cached_property - def past_stage(self) -> ffront_stages.PAST_PRG: + def past_stage(self) -> ffront_stages.PASTProgramDef: # backwards compatibility for backends that do not support the full toolchain - no_args_def = toolchain.CompilableProgram( + no_args_def = toolchain.ConcreteArtifact( self.definition_stage, arguments.CompileTimeArgs.empty() ) return self._frontend_transforms.func_to_past(no_args_def).data @@ -260,8 +256,8 @@ def _all_closure_vars(self) -> dict[str, Any]: @functools.cached_property def gtir(self) -> itir.Program: - no_args_past = toolchain.CompilableProgram( - data=ffront_stages.PastProgramDefinition( + no_args_past = toolchain.ConcreteArtifact( + data=ffront_stages.PASTProgramDef( past_node=self.past_stage.past_node, closure_vars=self.past_stage.closure_vars, grid_type=self.definition_stage.grid_type, @@ -503,13 +499,8 @@ def program_inner(definition: types.FunctionType) -> Program: return program_inner if definition is None else program_inner(definition) -OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) - - @dataclasses.dataclass(frozen=True) -class FieldOperator( - _ProgramLikeMixin[ffront_stages.FieldOperatorDefinition], GTCallable, Generic[OperatorNodeT] -): +class FieldOperator(_CompilableGTEntryPointMixin[ffront_stages.DSLFieldOperatorDef], GTCallable): """ Construct a field operator object from a FOAST node. 
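Reviewer note, not part of the patch: with the `OperatorNodeT` parameter gone, the field/scan distinction lives only in the FOAST node. A hedged sketch using names from this diff; `my_dsl_function` is a hypothetical placeholder.

```python
# Hedged sketch, not patch content.
from gt4py.next.ffront import field_operator_ast as foast
from gt4py.next.ffront.decorator import FieldOperator

fop = FieldOperator.from_function(
    my_dsl_function,  # hypothetical DSL definition function
    backend=None,
    operator_node_cls=foast.FieldOperator,  # scan operators pass foast.ScanOperator here
)

# The FOAST stage carries one of the OperatorNode alternatives
# (FieldOperator | ScanOperator) introduced in field_operator_ast.py.
assert isinstance(fop.foast_stage.foast_node, foast.OperatorNode)
```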
@@ -537,12 +528,12 @@ def from_function( backend: Optional[next_backend.Backend], grid_type: Optional[common.GridType] = None, *, - operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, # type: ignore[assignment] # TODO(ricoh): understand why mypy complains + operator_node_cls: type[foast.OperatorNode] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, **compilation_options: Unpack[options.CompilationOptionsArgs], - ) -> FieldOperator[OperatorNodeT]: + ) -> FieldOperator: return cls( - definition_stage=ffront_stages.FieldOperatorDefinition( + definition_stage=ffront_stages.DSLFieldOperatorDef( definition=definition, grid_type=grid_type, node_class=operator_node_cls, @@ -558,9 +549,9 @@ def __post_init__(self) -> None: _ = self.foast_stage @functools.cached_property - def foast_stage(self) -> ffront_stages.FoastOperatorDefinition: + def foast_stage(self) -> ffront_stages.FOASTOperatorDef: return self._frontend_transforms.func_to_foast( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( data=self.definition_stage, args=arguments.CompileTimeArgs.empty() ) ).data @@ -653,6 +644,9 @@ def __call__(self, *args: Any, enable_jit: bool | None = None, **kwargs: Any) -> return embedded_operators.field_operator_call(op, args, kwargs) +GTEntryPoint: TypeAlias = Program | FieldOperator + + # TODO(tehrengruber): This class does not follow the Liskov-Substitution principle as it doesn't # have a field operator definition. Currently implementation is merely a hack to keep the only # test relying on this working. Revisit. @@ -666,7 +660,7 @@ class FieldOperatorFromFoast(FieldOperator): This class provides the appropriate toolchain entry points. """ - foast_stage: ffront_stages.FoastOperatorDefinition + foast_stage: ffront_stages.FOASTOperatorDef @override def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -683,13 +677,13 @@ def field_operator( *, backend: next_backend.Backend | eve.NothingType | None, grid_type: common.GridType | None, -) -> FieldOperator[foast.FieldOperator]: ... +) -> FieldOperator: ... @typing.overload def field_operator( *, backend: next_backend.Backend | eve.NothingType | None, grid_type: common.GridType | None -) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: ... +) -> Callable[[types.FunctionType], FieldOperator]: ... def field_operator( @@ -698,10 +692,7 @@ def field_operator( backend: next_backend.Backend | eve.NothingType | None = eve.NOTHING, grid_type: common.GridType | None = None, **compilation_options: Unpack[options.CompilationOptionsArgs], -) -> ( - FieldOperator[foast.FieldOperator] - | Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]] -): +) -> FieldOperator | Callable[[types.FunctionType], FieldOperator]: """ Generate an implementation of the field operator from a Python function object. @@ -718,7 +709,7 @@ def field_operator( ... ... """ - def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]: + def field_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, typing.cast( @@ -740,7 +731,7 @@ def scan_operator( init: core_defs.Scalar, backend: next_backend.Backend | eve.NothingType | None, grid_type: common.GridType | None, -) -> FieldOperator[foast.ScanOperator]: ... +) -> FieldOperator: ... 
@typing.overload @@ -751,7 +742,7 @@ def scan_operator( init: core_defs.Scalar, backend: next_backend.Backend | eve.NothingType | None, grid_type: common.GridType | None, -) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... +) -> Callable[[types.FunctionType], FieldOperator]: ... def scan_operator( @@ -762,10 +753,7 @@ def scan_operator( init: core_defs.Scalar = 0.0, backend: next_backend.Backend | None | eve.NothingType = eve.NOTHING, grid_type: common.GridType | None = None, -) -> ( - FieldOperator[foast.ScanOperator] - | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] -): +) -> FieldOperator | Callable[[types.FunctionType], FieldOperator]: """ Generate an implementation of the scan operator from a Python function object. diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 41d00ec233..fa5bc4889f 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py import eve from gt4py.eve import ( @@ -235,3 +235,6 @@ class ScanOperator(LocatedNode, SymbolTableTrait): type: Union[ts_ffront.ScanOperatorType, ts.DeferredType] = ts.DeferredType( constraint=ts_ffront.ScanOperatorType ) + + +OperatorNode: TypeAlias = FieldOperator | ScanOperator diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index d74fb5dce8..3825072cb7 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -24,7 +24,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes import utils as foast_utils -from gt4py.next.ffront.stages import AOT_FOP, FOP +from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, FOASTOperatorDef from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import constant_folding @@ -32,7 +32,7 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt -def foast_to_gtir(inp: ffront_stages.FoastOperatorDefinition) -> itir.FunctionDefinition: +def foast_to_gtir(inp: ffront_stages.FOASTOperatorDef) -> itir.FunctionDefinition: """ Lower a FOAST field operator node to GTIR. 
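Reviewer note, not part of the patch: the renamed per-stage functions still compose directly; a hedged sketch mirroring the doctests in `func_to_foast.py` and the signatures in this file.

```python
# Hedged sketch, not patch content: chaining the renamed stages by hand.
import gt4py.next as gtx
from gt4py.next.ffront import foast_to_gtir, func_to_foast, stages as ffront_stages

IDim = gtx.Dimension("IDim")


def dsl_operator(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]:
    return a


# DSLFieldOperatorDef -> FOASTOperatorDef -> itir.FunctionDefinition
dsl_def = ffront_stages.DSLFieldOperatorDef(definition=dsl_operator)
foast_def = func_to_foast.func_to_foast(dsl_def)
gtir_function = foast_to_gtir.foast_to_gtir(foast_def)
```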
@@ -41,7 +41,9 @@ def foast_to_gtir(inp: ffront_stages.FoastOperatorDefinition) -> itir.FunctionDe return FieldOperatorLowering.apply(inp.foast_node) -def foast_to_gtir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.FunctionDefinition]: +def foast_to_gtir_factory( + cached: bool = True, +) -> workflow.Workflow[FOASTOperatorDef, itir.FunctionDefinition]: """Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step.""" wf = foast_to_gtir if cached: @@ -51,7 +53,7 @@ def foast_to_gtir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Fu def adapted_foast_to_gtir_factory( **kwargs: Any, -) -> workflow.Workflow[AOT_FOP, itir.FunctionDefinition]: +) -> workflow.Workflow[ConcreteFOASTOperatorDef, itir.FunctionDefinition]: """Wrap the `foast_to_gtir` workflow step into an adapter to fit into backend transform workflows.""" return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs)) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 5e03e37b8d..05b080b70b 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -18,7 +18,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction -from gt4py.next.ffront.stages import AOT_FOP, AOT_PRG +from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -33,13 +33,14 @@ class ItirShim: lowering has access to the relevant information. """ - definition: AOT_FOP - foast_to_itir: workflow.Workflow[AOT_FOP, itir.FunctionDefinition] + definition: ConcreteFOASTOperatorDef + foast_to_itir: workflow.Workflow[ConcreteFOASTOperatorDef, itir.FunctionDefinition] def __gt_closure_vars__(self) -> Optional[dict[str, Any]]: return self.definition.data.closure_vars def __gt_type__(self) -> ts.CallableType: + assert isinstance(self.definition.data.foast_node.type, ts.CallableType) return self.definition.data.foast_node.type def __gt_itir__(self) -> itir.FunctionDefinition: @@ -52,7 +53,7 @@ def __gt_gtir__(self) -> itir.FunctionDefinition: @dataclasses.dataclass(frozen=True) -class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): +class OperatorToProgram(workflow.Workflow[ConcreteFOASTOperatorDef, ConcretePASTProgramDef]): """ Generate a PAST program definition from a FOAST operator definition. @@ -82,7 +83,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... ) >>> copy_program = op_to_prog( - ... toolchain.CompilableProgram(copy.foast_stage, compile_time_args) + ... toolchain.ConcreteArtifact(copy.foast_stage, compile_time_args) ... 
) >>> print(copy_program.data.past_node.id) @@ -91,9 +92,9 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): >>> assert copy_program.data.closure_vars["copy"].definition.data is copy.foast_stage """ - foast_to_itir: workflow.Workflow[AOT_FOP, itir.FunctionDefinition] + foast_to_itir: workflow.Workflow[ConcreteFOASTOperatorDef, itir.FunctionDefinition] - def __call__(self, inp: AOT_FOP) -> AOT_PRG: + def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: # TODO(tehrengruber): implement mechanism to deduce default values # of arg and kwarg types # TODO(tehrengruber): check foast operator has no out argument that clashes @@ -103,6 +104,7 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG: type_ = inp.data.foast_node.type loc = inp.data.foast_node.location + assert isinstance(inp.data.foast_node.type, ts.CallableType) partial_program_type = ffront_type_info.type_in_program_context(inp.data.foast_node.type) assert isinstance(partial_program_type, ts_ffront.ProgramType) args_names = [ @@ -110,6 +112,7 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG: *partial_program_type.definition.pos_or_kw_args.keys(), *partial_program_type.definition.kw_only_args.keys(), ] + assert isinstance(type_, ts.CallableType) assert arg_types[-1] == type_info.return_type( type_, with_args=list(arg_types), with_kwargs=kwarg_types ) @@ -161,14 +164,15 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG: location=loc, ) untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply( - untyped_past_node, fieldop_itir_closure_vars + untyped_past_node, + fieldop_itir_closure_vars, # type: ignore[arg-type] ) past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node) - return toolchain.CompilableProgram( - data=ffront_stages.PastProgramDefinition( + return toolchain.ConcreteArtifact( + data=ffront_stages.PASTProgramDef( past_node=past_node, - closure_vars=fieldop_itir_closure_vars, + closure_vars=fieldop_itir_closure_vars, # type: ignore[arg-type] grid_type=inp.data.grid_type, ), args=inp.args, @@ -176,11 +180,13 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG: def operator_to_program_factory( - foast_to_itir_step: Optional[workflow.Workflow[AOT_FOP, itir.FunctionDefinition]] = None, + foast_to_itir_step: Optional[ + workflow.Workflow[ConcreteFOASTOperatorDef, itir.FunctionDefinition] + ] = None, cached: bool = True, -) -> workflow.Workflow[AOT_FOP, AOT_PRG]: +) -> workflow.Workflow[ConcreteFOASTOperatorDef, ConcretePASTProgramDef]: """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" - wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( + wf: workflow.Workflow[ConcreteFOASTOperatorDef, ConcretePASTProgramDef] = OperatorToProgram( foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 8b4e68463c..ced0ff3905 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -35,12 +35,17 @@ from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.ffront.stages import AOT_DSL_FOP, AOT_FOP, DSL_FOP, FOP +from gt4py.next.ffront.stages import ( + ConcreteDSLFieldOperatorDef, + ConcreteFOASTOperatorDef, + DSLFieldOperatorDef, + FOASTOperatorDef, +) from gt4py.next.otf import toolchain, 
workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -def func_to_foast(inp: DSL_FOP) -> FOP: +def func_to_foast(inp: DSLFieldOperatorDef) -> FOASTOperatorDef: """ Turn a DSL field operator definition into a FOAST operator definition, adding metadata. @@ -53,7 +58,7 @@ def func_to_foast(inp: DSL_FOP) -> FOP: >>> def dsl_operator(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: ... return a * const - >>> dsl_operator_def = gtx.ffront.stages.FieldOperatorDefinition(definition=dsl_operator) + >>> dsl_operator_def = gtx.ffront.stages.DSLFieldOperatorDef(definition=dsl_operator) >>> foast_definition = func_to_foast(dsl_operator_def) >>> print(foast_definition.foast_node.id) @@ -79,7 +84,7 @@ def func_to_foast(inp: DSL_FOP) -> FOP: **operator_attribute_nodes, ) foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node) - return ffront_stages.FoastOperatorDefinition( + return ffront_stages.FOASTOperatorDef( foast_node=foast_node, closure_vars=closure_vars, grid_type=inp.grid_type, @@ -88,7 +93,9 @@ def func_to_foast(inp: DSL_FOP) -> FOP: ) -def func_to_foast_factory(cached: bool = True) -> workflow.Workflow[DSL_FOP, FOP]: +def func_to_foast_factory( + cached: bool = True, +) -> workflow.Workflow[DSLFieldOperatorDef, FOASTOperatorDef]: """Wrap `func_to_foast` in a chainable and optionally cached workflow step.""" wf = workflow.make_step(func_to_foast) if cached: @@ -96,7 +103,9 @@ def func_to_foast_factory(cached: bool = True) -> workflow.Workflow[DSL_FOP, FOP return wf -def adapted_func_to_foast_factory(**kwargs: Any) -> workflow.Workflow[AOT_DSL_FOP, AOT_FOP]: +def adapted_func_to_foast_factory( + **kwargs: Any, +) -> workflow.Workflow[ConcreteDSLFieldOperatorDef, ConcreteFOASTOperatorDef]: """Wrap the `func_to_foast step in an adapter to fit into transform toolchains.`""" return toolchain.DataOnlyAdapter(func_to_foast_factory(**kwargs)) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 23829db8b5..ebc0de31b3 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -25,12 +25,17 @@ from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.ffront.stages import AOT_DSL_PRG, AOT_PRG, DSL_PRG, PAST_PRG +from gt4py.next.ffront.stages import ( + ConcreteDSLProgramDef, + ConcretePASTProgramDef, + DSLProgramDef, + PASTProgramDef, +) from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_specifications as ts, type_translation -def func_to_past(inp: DSL_PRG) -> PAST_PRG: +def func_to_past(inp: DSLProgramDef) -> PASTProgramDef: """ Turn a DSL program definition into a PAST Program definition, adding metadata. @@ -46,7 +51,7 @@ def func_to_past(inp: DSL_PRG) -> PAST_PRG: >>> def dsl_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]): ... 
copy(a, out=out) - >>> dsl_definition = gtx.ffront.stages.ProgramDefinition(definition=dsl_program) + >>> dsl_definition = gtx.ffront.stages.DSLProgramDef(definition=dsl_program) >>> past_definition = func_to_past(dsl_definition) >>> print(past_definition.past_node.id) @@ -57,7 +62,7 @@ def func_to_past(inp: DSL_PRG) -> PAST_PRG: source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) annotations = typing.get_type_hints(inp.definition) - return ffront_stages.PastProgramDefinition( + return ffront_stages.PASTProgramDef( past_node=ProgramParser.apply(source_def, closure_vars, annotations), closure_vars=closure_vars, grid_type=inp.grid_type, @@ -65,7 +70,7 @@ def func_to_past(inp: DSL_PRG) -> PAST_PRG: ) -def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PAST_PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSLProgramDef, PASTProgramDef]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. @@ -79,7 +84,9 @@ def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PAST return wf -def adapted_func_to_past_factory(**kwargs: Any) -> workflow.Workflow[AOT_DSL_PRG, AOT_PRG]: +def adapted_func_to_past_factory( + **kwargs: Any, +) -> workflow.Workflow[ConcreteDSLProgramDef, ConcretePASTProgramDef]: """ Wrap an adapter around the DSL definition -> PAST definition step to fit into transform toolchains. """ diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index d53b169062..6d9fb9123b 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -9,14 +9,14 @@ from typing import Any from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils -from gt4py.next.ffront.stages import AOT_PRG, PAST_PRG +from gt4py.next.ffront.stages import ConcretePASTProgramDef, PASTProgramDef from gt4py.next.otf import toolchain, workflow @workflow.make_step def lint_misnamed_functions( - inp: ffront_stages.PastProgramDefinition, -) -> ffront_stages.PastProgramDefinition: + inp: ffront_stages.PASTProgramDef, +) -> ffront_stages.PASTProgramDef: function_closure_vars = transform_utils._filter_closure_vars_by_type( inp.closure_vars, gtcallable.GTCallable ) @@ -34,8 +34,8 @@ def lint_misnamed_functions( @workflow.make_step def lint_undefined_symbols( - inp: ffront_stages.PastProgramDefinition, -) -> ffront_stages.PastProgramDefinition: + inp: ffront_stages.PASTProgramDef, +) -> ffront_stages.PASTProgramDef: undefined_symbols = [ symbol.id for symbol in inp.past_node.closure_vars if symbol.id not in inp.closure_vars ] @@ -48,12 +48,14 @@ def lint_undefined_symbols( def linter_factory( cached: bool = True, adapter: bool = True -) -> workflow.Workflow[PAST_PRG, PAST_PRG]: +) -> workflow.Workflow[PASTProgramDef, PASTProgramDef]: wf = lint_misnamed_functions.chain(lint_undefined_symbols) if cached: wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) return wf -def adapted_linter_factory(**kwargs: Any) -> workflow.Workflow[AOT_PRG, AOT_PRG]: +def adapted_linter_factory( + **kwargs: Any, +) -> workflow.Workflow[ConcretePASTProgramDef, ConcretePASTProgramDef]: return toolchain.DataOnlyAdapter(linter_factory(**kwargs)) diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 2c9d3d2770..ce794fd9dc 100644 --- 
a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Iterator, Sequence, TypeAlias +from typing import Any, Iterator, Sequence from gt4py.next import common, errors from gt4py.next.ffront import ( @@ -18,16 +18,13 @@ from gt4py.next.type_system import type_info, type_specifications as ts -AOT_PRG: TypeAlias = toolchain.CompilableProgram[ - ffront_stages.PastProgramDefinition, arguments.CompileTimeArgs -] - - -def transform_program_args(inp: AOT_PRG) -> AOT_PRG: +def transform_program_args( + inp: ffront_stages.ConcretePASTProgramDef, +) -> ffront_stages.ConcretePASTProgramDef: rewritten_args, rewritten_kwargs = _process_args( past_node=inp.data.past_node, args=inp.args.args, kwargs=inp.args.kwargs ) - return toolchain.CompilableProgram( + return toolchain.ConcreteArtifact( data=inp.data, args=arguments.CompileTimeArgs( args=rewritten_args, @@ -39,7 +36,9 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: ) -def transform_program_args_factory(cached: bool = True) -> workflow.Workflow[AOT_PRG, AOT_PRG]: +def transform_program_args_factory( + cached: bool = True, +) -> workflow.Workflow[ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef]: wf = transform_program_args if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4b76589ae3..4ac013b4fa 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -25,17 +25,17 @@ type_info as ffront_ti, type_specifications as ts_ffront, ) -from gt4py.next.ffront.stages import AOT_PRG +from gt4py.next.ffront.stages import ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import remap_symbols -from gt4py.next.otf import arguments, stages, workflow +from gt4py.next.otf import arguments, definitions, workflow from gt4py.next.type_system import type_info, type_specifications as ts # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: +def past_to_gtir(inp: ConcretePASTProgramDef) -> definitions.CompilableProgramDef: """ Lower a PAST program definition to Iterator IR. @@ -63,7 +63,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ... ) >>> itir_copy = past_to_gtir( - ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) + ... toolchain.ConcreteArtifact(copy_program.past_stage, compile_time_args) ... 
) >>> print(itir_copy.data.id) @@ -132,12 +132,12 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: if config.DEBUG or inp.data.debug: devtools.debug(itir_program) - return stages.CompilableProgram(data=itir_program, args=compile_time_args) + return definitions.CompilableProgramDef(data=itir_program, args=compile_time_args) def past_to_gtir_factory( cached: bool = True, -) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: +) -> workflow.Workflow[ConcretePASTProgramDef, definitions.CompilableProgramDef]: wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index ce2e8eeda6..d6dbddd7c0 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -6,15 +6,29 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +""" +Definitions of the stages of the GT4Py frontend. + +Classes in this module contain different forms of field operator and program +definitions, which are used as input or output of the different stages of +the frontend. + +All classes containing a definition of a GT4Py computation in any form use the +`Def` suffix. Definitions containing actual Python functions whose source code +should be interpreted as GT4Py embedded domain-specific language have `DSL` in +their name. Definitions that hold an AST in one of the internal GT4Py +dialects contain `AST`. +""" + from __future__ import annotations -import collections +import collections.abc import dataclasses import functools import hashlib import types import typing -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Optional, TypeVar import xxhash @@ -24,56 +38,59 @@ from gt4py.next.otf import arguments, toolchain -OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) - - @dataclasses.dataclass(frozen=True) -class FieldOperatorDefinition(Generic[OperatorNodeT]): +class DSLFieldOperatorDef: definition: types.FunctionType - grid_type: Optional[common.GridType] = None - node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains + node_class: type[foast.OperatorNode] = foast.FieldOperator attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + grid_type: Optional[common.GridType] = None debug: bool = False -DSL_FOP: typing.TypeAlias = FieldOperatorDefinition -AOT_DSL_FOP: typing.TypeAlias = toolchain.CompilableProgram[DSL_FOP, arguments.CompileTimeArgs] +ConcreteDSLFieldOperatorDef: typing.TypeAlias = toolchain.ConcreteArtifact[ + DSLFieldOperatorDef, arguments.CompileTimeArgs +] @dataclasses.dataclass(frozen=True) -class FoastOperatorDefinition(Generic[OperatorNodeT]): - foast_node: OperatorNodeT +class FOASTOperatorDef: + foast_node: foast.OperatorNode closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None attributes: dict[str, Any] = dataclasses.field(default_factory=dict) debug: bool = False -FOP: typing.TypeAlias = FoastOperatorDefinition -AOT_FOP: typing.TypeAlias = toolchain.CompilableProgram[FOP, arguments.CompileTimeArgs] +ConcreteFOASTOperatorDef: typing.TypeAlias = toolchain.ConcreteArtifact[ + FOASTOperatorDef, arguments.CompileTimeArgs +] @dataclasses.dataclass(frozen=True) -class ProgramDefinition: +class DSLProgramDef: definition: types.FunctionType grid_type: Optional[common.GridType] = None debug: bool = 
False -DSL_PRG: typing.TypeAlias = ProgramDefinition -AOT_DSL_PRG: typing.TypeAlias = toolchain.CompilableProgram[DSL_PRG, arguments.CompileTimeArgs] +ConcreteDSLProgramDef: typing.TypeAlias = toolchain.ConcreteArtifact[ + DSLProgramDef, arguments.CompileTimeArgs +] @dataclasses.dataclass(frozen=True) -class PastProgramDefinition: +class PASTProgramDef: past_node: past.Program closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None debug: bool = False -PAST_PRG: typing.TypeAlias = PastProgramDefinition -AOT_PRG: typing.TypeAlias = toolchain.CompilableProgram[PAST_PRG, arguments.CompileTimeArgs] +ConcretePASTProgramDef: typing.TypeAlias = toolchain.ConcreteArtifact[ + PASTProgramDef, arguments.CompileTimeArgs +] + +DSLDefinitionT = TypeVar("DSLDefinitionT", DSLFieldOperatorDef, DSLProgramDef) def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: @@ -98,11 +115,11 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object]) -@add_content_to_fingerprint.register(FieldOperatorDefinition) -@add_content_to_fingerprint.register(FoastOperatorDefinition) -@add_content_to_fingerprint.register(ProgramDefinition) -@add_content_to_fingerprint.register(PastProgramDefinition) -@add_content_to_fingerprint.register(toolchain.CompilableProgram) +@add_content_to_fingerprint.register(DSLFieldOperatorDef) +@add_content_to_fingerprint.register(FOASTOperatorDef) +@add_content_to_fingerprint.register(DSLProgramDef) +@add_content_to_fingerprint.register(PASTProgramDef) +@add_content_to_fingerprint.register(toolchain.ConcreteArtifact) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) def add_stage_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: add_content_to_fingerprint(obj.__class__, hasher) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 84c83ad34e..c8399ea2ce 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -29,13 +29,9 @@ final, ) from gt4py.next import common, errors, named_collections, utils -from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -DATA_T = TypeVar("DATA_T") - - def _make_dict_expr(exprs: dict[str, str]) -> str: items = str.join(",", (f"'{k}': {v}" for k, v in exprs.items())) return f"{{{items}}}" @@ -168,20 +164,6 @@ def empty(cls) -> Self: return cls(tuple(), {}, {}, None, {}) -def jit_to_aot_args( - inp: JITArgs, -) -> CompileTimeArgs: - return CompileTimeArgs.from_concrete(*inp.args, **inp.kwargs) - - -def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ - toolchain.CompilableProgram[DATA_T, JITArgs], - toolchain.CompilableProgram[DATA_T, CompileTimeArgs], -]: - """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" - return toolchain.ArgsOnlyAdapter(jit_to_aot_args) - - # This is not really accurate, just an approximation NeedsValueExtraction: TypeAlias = MaybeNestedInTuple[named_collections.CustomNamedCollection] diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 041868b00e..15f4b1866c 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -305,5 +305,5 @@ def create_bindings( @workflow.make_step def bind_source( inp: stages.ProgramSource[SrcL, languages.LanguageWithHeaderFilesSettings], -) -> 
stages.CompilableSource[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python]: - return stages.CompilableSource(program_source=inp, binding_source=create_bindings(inp)) +) -> stages.CompilableProject[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python]: + return stages.CompilableProject(program_source=inp, binding_source=create_bindings(inp)) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index a565198007..c5868fa3e7 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -73,7 +73,7 @@ class CMakeFactory( def __call__( self, - source: stages.CompilableSource[ + source: stages.CompilableProject[ languages.CPP | languages.CUDA | languages.HIP, languages.LanguageWithHeaderFilesSettings, languages.Python, diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index afff250e46..77c5ebc94b 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -47,7 +47,7 @@ class CompiledbFactory( def __call__( self, - source: stages.CompilableSource[ + source: stages.CompilableProject[ SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python ], cache_lifetime: config.BuildCacheLifetime, @@ -274,7 +274,7 @@ def _cc_get_compiledb( cache_lifetime: config.BuildCacheLifetime, ) -> pathlib.Path: cache_path = cache.get_cache_folder( - stages.CompilableSource(prototype_program_source, None), cache_lifetime + stages.CompilableProject(prototype_program_source, None), cache_lifetime ) # In a multi-threaded environment, multiple threads may try to create the compiledb at the same time @@ -311,7 +311,7 @@ def _cc_create_compiledb( cmake_build_type=build_type, cmake_extra_flags=cmake_flags, )( - stages.CompilableSource( + stages.CompilableProject( prototype_program_source, stages.BindingSource(source_code="", library_deps=()) ), cache_lifetime, diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 43ceb71fc3..b9d06a1e26 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -50,7 +50,7 @@ def _cache_folder_name(source: stages.ProgramSource) -> str: def get_cache_folder( - compilable_source: stages.CompilableSource, lifetime: config.BuildCacheLifetime + compilable_source: stages.CompilableProject, lifetime: config.BuildCacheLifetime ) -> pathlib.Path: """ Construct the path to where the build system project artifact of a compilable source should be cached. 
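Reviewer note, not part of the patch: a hedged sketch of what the renamed `CompilableProject` (previously `CompilableSource`) bundles and how the build cache consumes it; `my_program_source` and `my_binding_source` are hypothetical placeholders, and the `SESSION` member of `BuildCacheLifetime` is an assumption.

```python
# Hedged sketch, not patch content.
from gt4py.next import config
from gt4py.next.otf import stages
from gt4py.next.otf.compilation import cache

# A CompilableProject pairs generated program source with optional language bindings
# before a build-system step (CMake, compiledb, DaCe) turns it into a CompiledProgram.
project = stages.CompilableProject(
    program_source=my_program_source,  # hypothetical stages.ProgramSource
    binding_source=my_binding_source,  # hypothetical stages.BindingSource, may be None
)

# The build cache keys off the whole project, exactly as before the rename.
build_dir = cache.get_cache_folder(project, config.BuildCacheLifetime.SESSION)  # SESSION assumed
```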
diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e03fa84e50..ba28584242 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -16,9 +16,9 @@ from gt4py._core import locking from gt4py.next import config -from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf import definitions, languages, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer -from gt4py.next.otf.step_types import LS, SrcL, TgtL +from gt4py.next.otf.definitions import LS, SrcL, TgtL SourceLanguageType = TypeVar("SourceLanguageType", bound=languages.NanobindSrcL) @@ -37,7 +37,7 @@ def module_exists(data: build_data.BuildData, src_dir: pathlib.Path) -> bool: class BuildSystemProjectGenerator(Protocol[SrcL, LS, TgtL]): def __call__( self, - source: stages.CompilableSource[SrcL, LS, TgtL], + source: stages.CompilableProject[SrcL, LS, TgtL], cache_lifetime: config.BuildCacheLifetime, ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... @@ -45,14 +45,14 @@ def __call__( @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ - stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], + stages.CompilableProject[SourceLanguageType, LanguageSettingsType, languages.Python], stages.CompiledProgram, ], workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], + stages.CompilableProject[SourceLanguageType, LanguageSettingsType, languages.Python], stages.CompiledProgram, ], - step_types.CompilationStep[SourceLanguageType, LanguageSettingsType, languages.Python], + definitions.CompilationStep[SourceLanguageType, LanguageSettingsType, languages.Python], ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" @@ -64,7 +64,7 @@ class Compiler( def __call__( self, - inp: stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], + inp: stages.CompilableProject[SourceLanguageType, LanguageSettingsType, languages.Python], ) -> stages.CompiledProgram: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 5e603b52ed..bbb10c4610 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -14,7 +14,7 @@ import itertools import warnings from collections.abc import Callable, Hashable, Sequence -from typing import Any, TypeAlias, TypeVar +from typing import Any, Generic, TypeAlias, TypeVar from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils @@ -206,7 +206,7 @@ def _validate_argument_descriptors( @dataclasses.dataclass -class CompiledProgramsPool: +class CompiledProgramsPool(Generic[ffront_stages.DSLDefinitionT]): """ A pool of compiled programs for a given program and backend. @@ -222,7 +222,7 @@ class CompiledProgramsPool: """ backend: gtx_backend.Backend - definition_stage: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition + definition_stage: ffront_stages.DSLDefinitionT # Note: This type can be incomplete, i.e. contain DeferredType, whenever the operator is a # scan operator. In the future it could also be the type of a generic program. 
program_type: ts_ffront.ProgramType diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/definitions.py similarity index 58% rename from src/gt4py/next/otf/step_types.py rename to src/gt4py/next/otf/definitions.py index f2964362ed..255c9d38a8 100644 --- a/src/gt4py/next/otf/step_types.py +++ b/src/gt4py/next/otf/definitions.py @@ -8,9 +8,11 @@ from __future__ import annotations -from typing import Protocol, TypeVar +from typing import Protocol, TypeAlias, TypeVar -from gt4py.next.otf import languages, stages, workflow +from gt4py.next.ffront import stages as ffront_stages +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import arguments, languages, stages, toolchain, workflow SrcL = TypeVar("SrcL", bound=languages.LanguageTag) @@ -21,8 +23,22 @@ LS_co = TypeVar("LS_co", bound=languages.LanguageSettings, covariant=True) +IRDefinitionT = TypeVar( + "IRDefinitionT", + ffront_stages.DSLFieldOperatorDef, + ffront_stages.DSLProgramDef, + ffront_stages.FOASTOperatorDef, + ffront_stages.PASTProgramDef, + itir.Program, +) +ArgsDefinitionT = TypeVar("ArgsDefinitionT", arguments.JITArgs, arguments.CompileTimeArgs) + +ConcreteProgramDef: TypeAlias = toolchain.ConcreteArtifact[IRDefinitionT, ArgsDefinitionT] +CompilableProgramDef: TypeAlias = ConcreteProgramDef[itir.Program, arguments.CompileTimeArgs] + + class TranslationStep( - workflow.ReplaceEnabledWorkflowMixin[stages.CompilableProgram, stages.ProgramSource[SrcL, LS]], + workflow.ReplaceEnabledWorkflowMixin[CompilableProgramDef, stages.ProgramSource[SrcL, LS]], Protocol[SrcL, LS], ): """Translate a GT4Py program to source code (ProgramCall -> ProgramSource).""" @@ -40,15 +56,15 @@ class BindingStep(Protocol[SrcL, LS, TgtL]): def __call__( self, program_source: stages.ProgramSource[SrcL, LS] - ) -> stages.CompilableSource[SrcL, LS, TgtL]: ... + ) -> stages.CompilableProject[SrcL, LS, TgtL]: ... class CompilationStep( - workflow.Workflow[stages.CompilableSource[SrcL, LS, TgtL], stages.CompiledProgram], + workflow.Workflow[stages.CompilableProject[SrcL, LS, TgtL], stages.CompiledProgram], Protocol[SrcL, LS, TgtL], ): """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" def __call__( - self, source: stages.CompilableSource[SrcL, LS, TgtL] + self, source: stages.CompilableProject[SrcL, LS, TgtL] ) -> stages.CompiledProgram: ... 
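Reviewer note, not part of the patch: a hedged sketch of how the new aliases in `otf/definitions.py` relate, using only names introduced in this diff.

```python
# Hedged sketch, not patch content.
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments, definitions, toolchain


def as_compilable(
    program: itir.Program, args: arguments.CompileTimeArgs
) -> definitions.CompilableProgramDef:
    # CompilableProgramDef is ConcreteProgramDef[itir.Program, CompileTimeArgs],
    # i.e. a ConcreteArtifact pairing the lowered program with its compile-time arguments.
    return toolchain.ConcreteArtifact(data=program, args=args)
```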
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 6c14f05c6e..67b97077cd 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -10,14 +10,14 @@ import dataclasses -from gt4py.next.otf import stages, step_types, workflow +from gt4py.next.otf import definitions, stages, workflow @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow(workflow.NamedStepSequence): """The typical compiled backend steps composed into a workflow.""" - translation: step_types.TranslationStep - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] - compilation: workflow.Workflow[stages.CompilableSource, stages.CompiledProgram] + translation: definitions.TranslationStep + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] + compilation: workflow.Workflow[stages.CompilableProject, stages.CompiledProgram] decoration: workflow.Workflow[stages.CompiledProgram, stages.CompiledProgram] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index f7a5b60ba9..6fd710624d 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,12 +9,12 @@ from __future__ import annotations import dataclasses -from typing import Any, Generic, Optional, Protocol, TypeAlias, TypeVar +from typing import Any, Generic, Optional, Protocol, TypeVar from gt4py.eve import utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.otf import arguments, languages, toolchain +from gt4py.next.otf import definitions, languages from gt4py.next.otf.binding import interface @@ -28,32 +28,29 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] - - -def compilation_hash(otf_closure: CompilableProgram) -> int: +def compilation_hash(program_def: definitions.CompilableProgramDef) -> int: """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = otf_closure.args.offset_provider + offset_provider = program_def.args.offset_provider return hash( ( - otf_closure.data, + program_def.data, # As the frontend types contain lists they are not hashable. As a workaround we just # use content_hash here. - utils.content_hash(tuple(arg for arg in otf_closure.args.args)), + utils.content_hash(tuple(arg for arg in program_def.args.args)), common.hash_offset_provider_items_by_id(offset_provider) if offset_provider else None, - otf_closure.args.column_axis, + program_def.args.column_axis, ) ) -def fingerprint_compilable_program(inp: CompilableProgram) -> str: +def fingerprint_compilable_program(program_def: definitions.CompilableProgramDef) -> str: """ Generates a unique hash string for a stencil source program representing the program, sorted offset_provider, and column_axis. 
""" - program: itir.Program = inp.data - offset_provider: common.OffsetProvider = inp.args.offset_provider - column_axis: Optional[common.Dimension] = inp.args.column_axis + program: itir.Program = program_def.data + offset_provider: common.OffsetProvider = program_def.args.offset_provider + column_axis: Optional[common.Dimension] = program_def.args.column_axis program_hash = utils.content_hash( ( @@ -107,7 +104,7 @@ class BindingSource(Generic[SrcL, TgtL]): # TODO(ricoh): reconsider name in view of future backends producing standalone compilable ProgramSource code @dataclasses.dataclass(frozen=True) -class CompilableSource(Generic[SrcL, SettingT, TgtL]): +class CompilableProject(Generic[SrcL, SettingT, TgtL]): """ Encapsulate all the source code required for OTF compilation. diff --git a/src/gt4py/next/otf/toolchain.py b/src/gt4py/next/otf/toolchain.py index 4b7bd0b7ef..0c816759ff 100644 --- a/src/gt4py/next/otf/toolchain.py +++ b/src/gt4py/next/otf/toolchain.py @@ -15,52 +15,52 @@ from gt4py.next.otf import workflow -PrgT = typing.TypeVar("PrgT") -ArgT = typing.TypeVar("ArgT") -StartT = typing.TypeVar("StartT") -EndT = typing.TypeVar("EndT") +S = typing.TypeVar("S") +T = typing.TypeVar("T") +DefT = typing.TypeVar("DefT") +ArgsT = typing.TypeVar("ArgsT") @dataclasses.dataclass -class CompilableProgram(Generic[PrgT, ArgT]): - data: PrgT - args: ArgT +class ConcreteArtifact(Generic[DefT, ArgsT]): + data: DefT + args: ArgsT @dataclasses.dataclass(frozen=True) class DataOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, - workflow.Workflow[CompilableProgram[StartT, ArgT], CompilableProgram[EndT, ArgT]], - Generic[ArgT, StartT, EndT], + workflow.Workflow[ConcreteArtifact[S, ArgsT], ConcreteArtifact[T, ArgsT]], + Generic[ArgsT, S, T], ): - step: workflow.Workflow[StartT, EndT] + step: workflow.Workflow[S, T] - def __call__(self, inp: CompilableProgram[StartT, ArgT]) -> CompilableProgram[EndT, ArgT]: - return CompilableProgram(data=self.step(inp.data), args=inp.args) + def __call__(self, inp: ConcreteArtifact[S, ArgsT]) -> ConcreteArtifact[T, ArgsT]: + return ConcreteArtifact(data=self.step(inp.data), args=inp.args) @dataclasses.dataclass(frozen=True) class ArgsOnlyAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, - workflow.Workflow[CompilableProgram[PrgT, StartT], CompilableProgram[PrgT, EndT]], - Generic[PrgT, StartT, EndT], + workflow.Workflow[ConcreteArtifact[DefT, S], ConcreteArtifact[DefT, T]], + Generic[DefT, S, T], ): - step: workflow.Workflow[StartT, EndT] + step: workflow.Workflow[S, T] - def __call__(self, inp: CompilableProgram[PrgT, StartT]) -> CompilableProgram[PrgT, EndT]: - return CompilableProgram(data=inp.data, args=self.step(inp.args)) + def __call__(self, inp: ConcreteArtifact[DefT, S]) -> ConcreteArtifact[DefT, T]: + return ConcreteArtifact(data=inp.data, args=self.step(inp.args)) @dataclasses.dataclass(frozen=True) class StripArgsAdapter( workflow.ChainableWorkflowMixin, workflow.ReplaceEnabledWorkflowMixin, - workflow.Workflow[CompilableProgram[StartT, ArgT], EndT], - Generic[ArgT, StartT, EndT], + workflow.Workflow[ConcreteArtifact[S, ArgsT], T], + Generic[ArgsT, S, T], ): - step: workflow.Workflow[StartT, EndT] + step: workflow.Workflow[S, T] - def __call__(self, inp: CompilableProgram[StartT, ArgT]) -> EndT: + def __call__(self, inp: ConcreteArtifact[S, ArgsT]) -> T: return self.step(inp.data) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py 
b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 0c76757d70..ac5a42f8d4 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -21,7 +21,7 @@ from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import pass_manager -from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf import definitions, languages, stages, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering @@ -39,11 +39,11 @@ def get_param_description(name: str, type_: Any) -> interface.Parameter: @dataclasses.dataclass(frozen=True) class GTFNTranslationStep( workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableProgram, + definitions.CompilableProgramDef, stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ], workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, + definitions.CompilableProgramDef, stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ], ): @@ -206,7 +206,7 @@ def generate_stencil_source( return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( - self, inp: stages.CompilableProgram + self, inp: definitions.CompilableProgramDef ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" program: itir.Program = inp.data @@ -317,8 +317,8 @@ class Meta: model = GTFNTranslationStep -translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStepFactory() # type: ignore[assignment] # factory-boy typing not precise enough +translate_program_cpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory() # type: ignore[assignment] # factory-boy typing not precise enough -translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStepFactory( # type: ignore[assignment] # factory-boy typing not precise enough +translate_program_gpu: Final[definitions.TranslationStep] = GTFNTranslationStepFactory( # type: ignore[assignment] # factory-boy typing not precise enough device_type=core_defs.DeviceType.CUDA ) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index bcb11953cf..f8c8fd84a3 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -44,7 +44,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: # TODO(ricoh): connectivity tables required here for now. 
gtir_stage = typing.cast(gtx_backend.Transforms, self.backend.transforms).past_to_itir( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( data=self.past_stage, args=arguments.CompileTimeArgs( args=tuple(p.type for p in self.past_stage.past_node.params), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py index fbdc8229ec..0fdd51959e 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/bindings.py @@ -288,13 +288,13 @@ def _create_sdfg_bindings( def bind_sdfg( inp: stages.ProgramSource[languages.SDFG, languages.LanguageSettings], bind_func_name: str, -) -> stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python]: +) -> stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python]: """ Method to be used as workflow stage for generation of SDFG bindings. Refer to `_create_sdfg_bindings` documentation. """ - return stages.CompilableSource( + return stages.CompilableProject( program_source=inp, binding_source=_create_sdfg_bindings(inp, bind_func_name), ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index d5e42ef181..6ebaab04b0 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -19,7 +19,7 @@ from gt4py._core import definitions as core_defs, locking from gt4py.next import common, config -from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf import definitions, languages, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon @@ -117,14 +117,14 @@ def __call__(self, **kwargs: Any) -> None: @dataclasses.dataclass(frozen=True) class DaCeCompiler( workflow.ChainableWorkflowMixin[ - stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], CompiledDaceProgram, ], workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], CompiledDaceProgram, ], - step_types.CompilationStep[languages.SDFG, languages.LanguageSettings, languages.Python], + definitions.CompilationStep[languages.SDFG, languages.LanguageSettings, languages.Python], ): """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" @@ -135,7 +135,7 @@ class DaCeCompiler( def __call__( self, - inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + inp: stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], ) -> CompiledDaceProgram: with gtx_wfdcommon.dace_context( device_type=self.device_type, diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 2b94578756..ad771d2d26 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -18,7 +18,7 
@@ from gt4py.next import common, config from gt4py.next.instrumentation import metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms -from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf import definitions, languages, stages, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace import ( @@ -356,9 +356,10 @@ def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> None: @dataclasses.dataclass(frozen=True) class DaCeTranslator( workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] + definitions.CompilableProgramDef, + stages.ProgramSource[languages.SDFG, languages.LanguageSettings], ], - step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], + definitions.TranslationStep[languages.SDFG, languages.LanguageSettings], ): device_type: core_defs.DeviceType auto_optimize: bool @@ -440,7 +441,7 @@ def _generate_sdfg_without_configuring_dace( return sdfg def __call__( - self, inp: stages.CompilableProgram + self, inp: definitions.CompilableProgramDef ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" program: itir.Program = inp.data diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 038f2959b6..5fd5b693a7 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -136,7 +136,7 @@ class Params: translation = factory.LazyAttribute(lambda o: o.bare_translation) - bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] = ( nanobind.bind_source ) compilation = factory.SubFactory( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 0d6eb23fda..2e8a90be9c 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -23,7 +23,7 @@ from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms -from gt4py.next.otf import stages, workflow +from gt4py.next.otf import definitions, stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -208,13 +208,13 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]): +class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.CompiledProgram]): debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` - def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: + def __call__(self, inp: definitions.CompilableProgramDef) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug fencil = fencil_generator( diff --git a/src/gt4py/next/typing.py b/src/gt4py/next/typing.py index 7cbb3608fd..696b32e477 
100644 --- a/src/gt4py/next/typing.py +++ b/src/gt4py/next/typing.py @@ -17,6 +17,7 @@ _ONLY_FOR_TYPING: Final[str] = "only for typing" # TODO(havogt): alternatively we could introduce Protocols +GTEntryPoint: TypeAlias = Annotated[decorator.GTEntryPoint, _ONLY_FOR_TYPING] Program: TypeAlias = Annotated[decorator.Program, _ONLY_FOR_TYPING] FieldOperator: TypeAlias = Annotated[decorator.FieldOperator, _ONLY_FOR_TYPING] Backend: TypeAlias = Annotated[backend.Backend, _ONLY_FOR_TYPING] @@ -24,6 +25,7 @@ allocators.FieldBufferAllocationUtil, _ONLY_FOR_TYPING ] + __all__ = [ "Backend", "FieldBufferAllocationUtil", diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 46b40085d5..c3c098a057 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -108,7 +108,7 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[next_backen return decorator.FieldOperatorFromFoast( definition_stage=None, - foast_stage=ffront_stages.FoastOperatorDefinition( + foast_stage=ffront_stages.FOASTOperatorDef( foast_node=typed_foast_node, closure_vars=closure_vars, grid_type=None, diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index c1503f3e7c..4ca940db83 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -94,15 +94,15 @@ def test_fingerprint_stage_field_op_def(fieldop, samecode_fieldop, different_fie def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): foast = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.CompilableProgram(fieldop.definition_stage, arguments.CompileTimeArgs.empty()) + toolchain.ConcreteArtifact(fieldop.definition_stage, arguments.CompileTimeArgs.empty()) ).data samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( samecode_fieldop.definition_stage, arguments.CompileTimeArgs.empty() ) ).data different = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( different_fieldop.definition_stage, arguments.CompileTimeArgs.empty() ) ).data @@ -122,15 +122,15 @@ def test_fingerprint_stage_program_def(program, samecode_program, different_prog def test_fingerprint_stage_past_def(program, samecode_program, different_program): past = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.CompilableProgram(program.definition_stage, arguments.CompileTimeArgs.empty()) + toolchain.ConcreteArtifact(program.definition_stage, arguments.CompileTimeArgs.empty()) ) samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( samecode_program.definition_stage, arguments.CompileTimeArgs.empty() ) ) different = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( - toolchain.CompilableProgram( + toolchain.ConcreteArtifact( different_program.definition_stage, arguments.CompileTimeArgs.empty() ) ) diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 97c848bea9..8399455c4e 100644 --- 
a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -90,7 +90,7 @@ def program_source_example(): @pytest.fixture def compilable_source_example(program_source_example): - return stages.CompilableSource( + return stages.CompilableProject( program_source=program_source_example, binding_source=nanobind.create_bindings(program_source_example), ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 51645fba02..1bedd22191 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -82,7 +82,7 @@ def _verify_program_has_expected_true_value(program: itir.Program): def test_inlining_of_scalars_works(testee_prog): - input_pair = toolchain.CompilableProgram( + input_pair = toolchain.ConcreteArtifact( data=testee_prog.definition_stage, args=arguments.CompileTimeArgs( args=list(testee_prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), @@ -108,7 +108,7 @@ def test_inlining_of_scalar_works_integration(testee_prog): hijacked_program = None - def pirate(program: toolchain.CompilableProgram): + def pirate(program: toolchain.ConcreteArtifact): # Replaces the gtfn otf_workflow: and steals the compilable program, # then returns a dummy "CompiledProgram" that does nothing. nonlocal hijacked_program diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 7f759ef504..a55578d801 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -15,7 +15,7 @@ import gt4py.next as gtx from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import arguments, languages, stages +from gt4py.next.otf import arguments, languages, stages, definitions from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation @@ -75,7 +75,7 @@ def program_example(): def test_codegen(program_example): fencil, parameters = program_example module = gtfn_module.translate_program_cpu( - stages.CompilableProgram( + definitions.CompilableProgramDef( data=fencil, args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) @@ -87,7 +87,7 @@ def test_codegen(program_example): def test_hash_and_diskcache(program_example, tmp_path): fencil, parameters = program_example - compilable_program = stages.CompilableProgram( + compilable_program = definitions.CompilableProgramDef( data=fencil, args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) @@ -129,7 +129,7 @@ def test_hash_and_diskcache(program_example, tmp_path): def test_gtfn_file_cache(program_example): fencil, parameters = program_example - compilable_program = stages.CompilableProgram( + compilable_program = definitions.CompilableProgramDef( data=fencil, args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index f17e0ae57a..5923ead6c6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -225,7 +225,7 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provid def mocked_compile_call( self, - inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + inp: stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], binding_source_ref: str, ): assert len(inp.library_deps) == 0 @@ -242,7 +242,7 @@ def mocked_compile_call( def mocked_compile_call_cartesian( self, - inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + inp: stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], use_metrics: bool, use_zero_origin: bool, ): @@ -254,7 +254,7 @@ def mocked_compile_call_cartesian( def mocked_compile_call_unstructured( self, - inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + inp: stages.CompilableProject[languages.SDFG, languages.LanguageSettings, languages.Python], use_metrics: bool, use_zero_origin: bool, ):