Skip to content
10 changes: 5 additions & 5 deletions docs/user/next/advanced/HackTheToolchain.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import typing

from gt4py import next as gtx
from gt4py.next.otf import toolchain, workflow
from gt4py.next.ffront import stages as ff_stages
from gt4py.next.ffront import field_operator_ast as foast, stages as ff_stages
from gt4py import eve
```

Expand All @@ -22,8 +22,8 @@ cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace(
## Skip Steps / Change Order

```python
DUMMY_FOP = toolchain.CompilableProgram(
data=ff_stages.FieldOperatorDefinition(definition=None), args=None
DUMMY_FOP = toolchain.ConcreteArtifact(
data=ff_stages.DSLFieldOperatorDef(definition=None), args=None
)
```

Expand Down Expand Up @@ -57,9 +57,9 @@ class Cpp2BindingsGen: ...

class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory):
translation: workflow.Workflow[
gtx.otf.stages.CompilableProgram, gtx.otf.stages.ProgramSource
gtx.otf.definitions.CompilableProgramDef, gtx.otf.stages.ProgramSource
] = MyCodeGen()
bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = (
bindings: workflow.Workflow[gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableProject] = (
Cpp2BindingsGen()
)

Expand Down
110 changes: 59 additions & 51 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

import dataclasses
import typing
from typing import Generic

from gt4py._core import definitions as core_defs
Expand All @@ -21,30 +20,34 @@
func_to_past,
past_process_args,
past_to_itir,
stages as ffront_stages,
)
from gt4py.next.ffront.past_passes import linters as past_linters
from gt4py.next.ffront.stages import (
AOT_DSL_FOP,
AOT_DSL_PRG,
AOT_FOP,
AOT_PRG,
DSL_FOP,
DSL_PRG,
FOP,
PAST_PRG,
)
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments, stages, toolchain, workflow
from gt4py.next.otf import arguments, definitions, stages, toolchain, workflow


def jit_to_aot_args(
inp: arguments.JITArgs,
) -> arguments.CompileTimeArgs:
return arguments.CompileTimeArgs.from_concrete(*inp.args, **inp.kwargs)

IRDefinitionForm: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PAST_PRG | itir.Program
CompilableDefinition: typing.TypeAlias = toolchain.CompilableProgram[
IRDefinitionForm, arguments.JITArgs | arguments.CompileTimeArgs
]

def adapted_jit_to_aot_args_factory() -> workflow.Workflow[
definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.JITArgs],
definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.CompileTimeArgs],
]:
"""Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows."""
return toolchain.ArgsOnlyAdapter(jit_to_aot_args)


@dataclasses.dataclass(frozen=True)
class Transforms(workflow.MultiWorkflow[CompilableDefinition, stages.CompilableProgram]):
class Transforms(
workflow.MultiWorkflow[
definitions.ConcreteProgramDef[definitions.IRDefinitionT, definitions.ArgsDefinitionT],
definitions.CompilableProgramDef,
]
):
"""
Modular workflow for transformations with access to intermediates.

Expand All @@ -60,44 +63,44 @@ class Transforms(workflow.MultiWorkflow[CompilableDefinition, stages.CompilableP
"""

aotify_args: workflow.Workflow[
toolchain.CompilableProgram[IRDefinitionForm, arguments.JITArgs],
toolchain.CompilableProgram[IRDefinitionForm, arguments.CompileTimeArgs],
] = dataclasses.field(default_factory=arguments.adapted_jit_to_aot_args_factory)
definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.JITArgs],
definitions.ConcreteProgramDef[definitions.IRDefinitionT, arguments.CompileTimeArgs],
] = dataclasses.field(default_factory=adapted_jit_to_aot_args_factory)

func_to_foast: workflow.Workflow[AOT_DSL_FOP, AOT_FOP] = dataclasses.field(
default_factory=func_to_foast.adapted_func_to_foast_factory
)
func_to_foast: workflow.Workflow[
ffront_stages.ConcreteDSLFieldOperatorDef, ffront_stages.ConcreteFOASTOperatorDef
] = dataclasses.field(default_factory=func_to_foast.adapted_func_to_foast_factory)

func_to_past: workflow.Workflow[AOT_DSL_PRG, AOT_PRG] = dataclasses.field(
default_factory=func_to_past.adapted_func_to_past_factory
)
func_to_past: workflow.Workflow[
ffront_stages.ConcreteDSLProgramDef, ffront_stages.ConcretePASTProgramDef
] = dataclasses.field(default_factory=func_to_past.adapted_func_to_past_factory)

foast_to_itir: workflow.Workflow[AOT_FOP, itir.FunctionDefinition] = dataclasses.field(
default_factory=foast_to_gtir.adapted_foast_to_gtir_factory
)
foast_to_itir: workflow.Workflow[
ffront_stages.ConcreteFOASTOperatorDef, itir.FunctionDefinition
] = dataclasses.field(default_factory=foast_to_gtir.adapted_foast_to_gtir_factory)

field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field(
default_factory=foast_to_past.operator_to_program_factory
)
field_view_op_to_prog: workflow.Workflow[
ffront_stages.ConcreteFOASTOperatorDef, ffront_stages.ConcretePASTProgramDef
] = dataclasses.field(default_factory=foast_to_past.operator_to_program_factory)

past_lint: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field(
default_factory=past_linters.adapted_linter_factory
)
past_lint: workflow.Workflow[
ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef
] = dataclasses.field(default_factory=past_linters.adapted_linter_factory)

field_view_prog_args_transform: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field(
default_factory=past_process_args.transform_program_args_factory
)
field_view_prog_args_transform: workflow.Workflow[
ffront_stages.ConcretePASTProgramDef, ffront_stages.ConcretePASTProgramDef
] = dataclasses.field(default_factory=past_process_args.transform_program_args_factory)

past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field(
default_factory=past_to_itir.past_to_gtir_factory
)
past_to_itir: workflow.Workflow[
ffront_stages.ConcretePASTProgramDef, definitions.CompilableProgramDef
] = dataclasses.field(default_factory=past_to_itir.past_to_gtir_factory)

def step_order(self, inp: CompilableDefinition) -> list[str]:
def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]:
steps: list[str] = []
if isinstance(inp.args, arguments.JITArgs):
steps.append("aotify_args")
match inp.data:
case DSL_FOP():
case ffront_stages.DSLFieldOperatorDef():
steps.extend(
[
"func_to_foast",
Expand All @@ -107,7 +110,7 @@ def step_order(self, inp: CompilableDefinition) -> list[str]:
"past_to_itir",
]
)
case FOP():
case ffront_stages.FOASTOperatorDef():
steps.extend(
[
"field_view_op_to_prog",
Expand All @@ -116,11 +119,16 @@ def step_order(self, inp: CompilableDefinition) -> list[str]:
"past_to_itir",
]
)
case DSL_PRG():
case ffront_stages.DSLProgramDef():
steps.extend(
["func_to_past", "past_lint", "field_view_prog_args_transform", "past_to_itir"]
[
"func_to_past",
"past_lint",
"field_view_prog_args_transform",
"past_to_itir",
]
)
case PAST_PRG():
case ffront_stages.PASTProgramDef():
steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"])
case itir.Program():
pass
Expand All @@ -139,15 +147,15 @@ def step_order(self, inp: CompilableDefinition) -> list[str]:
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
name: str
executor: workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompiledProgram]
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
transforms: workflow.Workflow[CompilableDefinition, stages.CompilableProgram]
transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef]

def compile(
self, program: IRDefinitionForm, compile_time_args: arguments.CompileTimeArgs
self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs
) -> stages.CompiledProgram:
return self.executor(
self.transforms(toolchain.CompilableProgram(data=program, args=compile_time_args))
self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args))
)

@property
Expand Down
Loading