
Commit 68df316

[custom ops] Begin the scaffolding for dispatch of PyTorch custom ops. (#270)
This makes it possible for us to directly define regular torch ops in terms of generated MLIR. The resulting ops will be specialized and cached per requirements in their definition and will be compiled for any device that Turbine supports when dispatched against tensors on that device. It is left to a follow-up to also wire this mechanism in on the AOT side so that compiling programs that contain our own custom ops transparently includes them with no further glue. The scaffolding for this is in place, but this patch is big enough without touching AOT. This allows users to say something like: ``` @CustomOp.register class identity(CustomOp): name = "test_identity" signature = "(Tensor self) -> Tensor" def select(self, ksel: KernelSelection): x = ksel.arg_tensor(0) ksel.return_tensor(x.t) def generate(self, ksel: KernelSelection, kb: KernelBuilder): # This just yields the IR value of kernel input as the output. # Effectively in eager mode, this is a `return` from the kernel # function. kb.yield_results(kb.arg_bindings[0]) t = torch.tensor([[1, 2, 3]], dtype=torch.int32) result = identity(t) print("CPU result:", result) torch.testing.assert_close(result, t) ``` There will be dedicated `CustomOp` subclasses for our various DSLs that can be used for such things (for more sugar'd use than just open coding IR).
1 parent: 5d9d08b · commit: 68df316

File tree: 28 files changed (+1490 / -319 lines)


.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ jobs:
 
       - name: Run tests
         run: |
-          pytest tests/
+          pytest -n 4 tests/
 
   black:
     strategy:
```

python/shark_turbine/aot/builtins/jittable.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -31,6 +31,18 @@
     FxImporter,
 )
 
+from ...support.ir_imports import (
+    FlatSymbolRefAttr,
+    FunctionType,
+    Operation,
+    StringAttr,
+    SymbolTable,
+    TypeAttr,
+    Value,
+    func_d,
+    util_d,
+)
+
 from ..passes import (
     functorch_functionalize,
 )
@@ -53,18 +65,6 @@
     MaterializedGlobal,
 )
 
-from ..support.ir_imports import (
-    FlatSymbolRefAttr,
-    FunctionType,
-    Operation,
-    StringAttr,
-    SymbolTable,
-    TypeAttr,
-    Value,
-    func_d,
-    util_d,
-)
-
 StringAttrOrStr = Union[StringAttr, str]
 
```

python/shark_turbine/aot/compiled_module.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -17,13 +17,7 @@
 
 from . import builtins
 
-from .support.procedural import (
-    GlobalsDef,
-    ProcedureTrace,
-    current_ir_trace,
-)
-
-from .support.ir_imports import (
+from ..support.ir_imports import (
     Context,
     Location,
     MLIRError,
@@ -33,6 +27,12 @@
     StringAttr,
 )
 
+from .support.procedural import (
+    GlobalsDef,
+    ProcedureTrace,
+    current_ir_trace,
+)
+
 from .support.ir_utils import (
     ModuleBuilder,
 )
```

python/shark_turbine/aot/exporter.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -19,16 +19,17 @@
     Output,
 )
 
+from ..support.ir_imports import (
+    Context,
+    Operation,
+)
+
 from .builtins import *
 from .compiled_module import (
     CompiledModule,
     CompiledModuleMeta,
     ExportProcDef,
 )
-from .support.ir_imports import (
-    Context,
-    Operation,
-)
 from .support.procedural import (
     AbstractTypedef,
 )
```

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 9 additions & 54 deletions
```diff
@@ -5,7 +5,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple
 
 from pathlib import Path
 import tempfile
@@ -15,7 +15,6 @@
 
 from ...importers.fx_importer import (
     ContextCache,
-    TORCH_DTYPE_TO_MLIR_TYPE_ASM,
 )
 
 from ...importers.utils import (
@@ -26,12 +25,9 @@
     NativeTypeConverter,
 )
 
-from .ir_imports import (
+from ...support.ir_imports import (
     Attribute,
-    Block,
-    BlockArgument,
     BF16Type,
-    ComplexType,
     DenseElementsAttr,
     DenseResourceElementsAttr,
     F16Type,
@@ -46,7 +42,6 @@
     IrType,
     Location,
     MLIRError,
-    OpResult,
     Operation,
     RankedTensorType,
     StringAttr,
@@ -59,61 +54,21 @@
     tensor_d,
 )
 
+from ...support.conversions import (
+    TORCH_DTYPE_TO_IREE_TYPE,
+)
+
 from .utils import (
     RefTracker,
     logger,
 )
 
-###############################################################################
-# Lookup tables
-###############################################################################
-
-# We need the inverse of the TORCH_DTYPE_TO_MLIR_TYPE_ASM table.
-MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()}
-
-# When emitting constants, we have to create native IREE types.
-TORCH_DTYPE_TO_IREE_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
-    torch.float16: lambda: F16Type.get(),
-    torch.bfloat16: lambda: BF16Type.get(),
-    torch.float32: lambda: F32Type.get(),
-    torch.float64: lambda: F64Type.get(),
-    torch.uint8: lambda: IntegerType.get_signless(8),
-    torch.int8: lambda: IntegerType.get_signless(8),
-    torch.int16: lambda: IntegerType.get_signless(16),
-    torch.int32: lambda: IntegerType.get_signless(32),
-    torch.int64: lambda: IntegerType.get_signless(64),
-    torch.bool: lambda: IntegerType.get_signless(1),
-    torch.qint8: lambda: IntegerType.get_signless(8),
-    torch.quint8: lambda: IntegerType.get_signless(8),
-    torch.complex32: lambda: ComplexType.get(F16Type.get()),
-    torch.complex64: lambda: ComplexType.get(F32Type.get()),
-    torch.complex128: lambda: ComplexType.get(F64Type.get()),
-}
-
-TORCH_DTYPE_TO_IREE_TYPE_ASM = {
-    torch.float16: "f16",
-    torch.bfloat16: "bf16",
-    torch.float32: "f32",
-    torch.float64: "f64",
-    torch.uint8: "i8",
-    torch.int8: "i8",
-    torch.int16: "i16",
-    torch.int32: "i32",
-    torch.int64: "i64",
-    torch.bool: "i1",
-    torch.qint8: "i8",
-    torch.quint8: "i8",
-    torch.complex32: "complex<f16>",
-    torch.complex64: "complex<f32>",
-    torch.complex128: "complex<f64>",
-}
-
 ###############################################################################
 # Configuration
 ###############################################################################
 
 # Maps a name to an altered name. If returns None, then the original
-# name is used (this lets Dict.get serve as a NameMapCallback).
+# name is used (this lets dict.get serve as a NameMapCallback).
 NameMapCallback = Callable[[str], Optional[str]]
 
 
@@ -420,7 +375,7 @@ def build_index_attribute(value: int) -> IntegerAttr:
 
 
 def build_index_value(
-    value: int, constant_cache: Optional[Dict[int, Value]] = None
+    value: int, constant_cache: Optional[dict[int, Value]] = None
 ) -> Value:
     if constant_cache is not None and value in constant_cache:
         return constant_cache[value]
@@ -431,7 +386,7 @@ def build_index_value(
 
 
 def build_tensor_dim_value(
-    t: Value, dim: int, constant_cache: Optional[Dict[int, Value]] = None
+    t: Value, dim: int, constant_cache: Optional[dict[int, Value]] = None
 ) -> Value:
     dim_value = build_index_value(dim, constant_cache=constant_cache)
     return tensor_d.DimOp(t, dim_value).result
```
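For orientation, a minimal sketch (not part of this diff) of using the relocated dtype table from its new home in `shark_turbine.support.conversions`; the context-manager usage is an assumption based on the MLIR Python bindings that `ir_imports` re-exports:

```python
# Hedged sketch: the table this diff moves now maps torch dtypes to zero-arg
# callables that construct the corresponding IREE element type.
import torch

from shark_turbine.support.conversions import TORCH_DTYPE_TO_IREE_TYPE
from shark_turbine.support.ir_imports import Context

with Context():
    # Constructing MLIR types needs an active context (assumption: the table
    # entries call constructors such as F32Type.get() / IntegerType.get_signless()).
    f32_type = TORCH_DTYPE_TO_IREE_TYPE[torch.float32]()
    i8_type = TORCH_DTYPE_TO_IREE_TYPE[torch.int8]()
    print(f32_type, i8_type)  # prints something like: f32 i8
```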

python/shark_turbine/aot/support/procedural/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@
 
 import torch
 
-from ..ir_imports import (
+from ....support.ir_imports import (
     F32Type,
     F64Type,
     IndexType,
```

python/shark_turbine/aot/support/procedural/globals.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@
 
 import torch
 
-from ..ir_imports import (
+from ....support.ir_imports import (
     IrType,
     Operation,
     Value,
```

python/shark_turbine/aot/support/procedural/iree_emitter.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@
 
 import torch
 
-from ..ir_imports import (
+from ....support.ir_imports import (
     IndexType,
     IntegerType,
     IrType,
@@ -23,8 +23,11 @@
     flow_d,
 )
 
-from ..ir_utils import (
+from ....support.conversions import (
     TORCH_DTYPE_TO_IREE_TYPE,
+)
+
+from ..ir_utils import (
     build_index_value,
 )
 
```
python/shark_turbine/aot/support/procedural/primitives.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@
     dynamic_dim,
 )
 
-from ..ir_imports import (
+from ....support.ir_imports import (
     F32Type,
     IrType,
     RankedTensorType,
```

python/shark_turbine/aot/support/procedural/tracer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@
     Sequence,
 )
 
-from ..ir_imports import (
+from ....support.ir_imports import (
     Location,
     StringAttr,
     Value,
```

python/shark_turbine/dynamo/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -5,7 +5,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 
-from .device import Device
 from .tensor import (
     enable,
     TurbineMode,
```

python/shark_turbine/dynamo/backends/cpu.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@
 import functools
 import sys
 
-from ..device import (
+from ...runtime.device import (
     DeviceState,
 )
 
```
