Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start adding the library of custom ops. #296

Merged
merged 6 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/shark_turbine/aot/support/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ def create_tensor_global(
array = np.array(detached_tensor)
# We know that a Numpy array is a ReadableBuffer so ignore type error.
contents = memoryview(array) # type: ignore
shape_desc = "_".join([str(d) for d in t.shape])
blob_name = f"torch_tensor_{shape_desc}_{str(t.dtype)}"
elements_attr = DenseResourceElementsAttr.get_from_buffer(
contents, "from_py", tensor_type
contents, blob_name, tensor_type
)
ir_attrs["initial_value"] = elements_attr

Expand Down
23 changes: 16 additions & 7 deletions python/shark_turbine/dynamo/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, context: Context):
self.torch_type_to_native
)

def torch_type_to_native(self, torch_type: IrType) -> IrType:
def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType:
"""Converts a presumed torch type to a corresponding native type.

This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp.
Expand All @@ -56,6 +56,8 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
!torch.float -> f64
!torch.bool -> i1
!torch.vtensor -> tensor

If `signless=False`, then integer types will retain their signs.
"""
# We don't presently have API support for introspecting torch type,
# and even if we did, it is likely that this is more efficient.
Expand All @@ -66,7 +68,11 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
if name == "bool":
return IntegerType.get_signless(1)
if name == "int":
return IntegerType.get_signless(64)
return (
IntegerType.get_signless(64)
if signless
else IntegerType.get_signed(64)
)
elif name == "float":
return F64Type.get()
elif name == "vtensor":
Expand All @@ -75,22 +81,25 @@ def torch_type_to_native(self, torch_type: IrType) -> IrType:
dim_list_str, dtype_str = tm.groups()
dim_list = parse_tensor_dim_list(dim_list_str)
dtype = self.convert_torch_element_type_to_native(
IrType.parse(dtype_str)
IrType.parse(dtype_str), signless=signless
)
# TODO: Eliminate RankedTensorType dependence on Location.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/145
with Location.unknown():
return RankedTensorType.get(dim_list, dtype)
raise TypeError(f"Unsupported torch type conversion for {torch_type}")

def convert_torch_element_type_to_native(self, torch_type: IrType) -> IrType:
def convert_torch_element_type_to_native(
self, torch_type: IrType, signless: bool = True
) -> IrType:
# Torch uses the builtin type hierarchy of IntegerType and FloatType
# to represent dtypes. These are mostly the same, but it always uses
# signed IntegerTypes which we must convert to signless for the native
# type system.
if IntegerType.isinstance(torch_type):
signed_int_type = IntegerType(torch_type)
return IntegerType.get_signless(signed_int_type.width)
if signless:
if IntegerType.isinstance(torch_type):
signed_int_type = IntegerType(torch_type)
return IntegerType.get_signless(signed_int_type.width)
return torch_type

def materialize_native_to_torch(
Expand Down
7 changes: 7 additions & 0 deletions python/shark_turbine/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright 2023 Advanced Micro Devices, Inc
stellaraccident marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import iree
69 changes: 69 additions & 0 deletions python/shark_turbine/ops/iree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023 Advanced Micro Devices, Inc
stellaraccident marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..support.ir_imports import (
RankedTensorType,
StringAttr,
Value,
flow_d,
tensor_d,
)

from ..runtime.op_reg import (
CustomOp,
KernelBuilder,
KernelSelection,
def_library,
)

__all__ = [
"trace",
]

IREE_LIBRARY = def_library("iree")
stellaraccident marked this conversation as resolved.
Show resolved Hide resolved


################################################################################
# trace_tensor / trace_tensors
################################################################################


def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
dynamic_dims = []
for t in ts:
rtt = RankedTensorType(t.type)
for i in range(rtt.rank):
if rtt.is_dynamic_dim(i):
dynamic_dims.append(tensor_d.dim(t, kb.constant_index(i)))
flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)


@CustomOp.register(library=IREE_LIBRARY)
class trace_tensor(CustomOp):
stellaraccident marked this conversation as resolved.
Show resolved Hide resolved
signature = "trace_tensor(str trace_key, Tensor tensor) -> ()"

def select(self, ksel: KernelSelection):
ksel.attr_str(0)
ksel.arg_tensor(1)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
_emit_tensor_trace(kb, ksel.arg_descs[0].v, [kb.arg_bindings[1]])
kb.yield_results()


@CustomOp.register(library=IREE_LIBRARY)
class trace_tensors(CustomOp):
signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()"

def select(self, ksel: KernelSelection):
ksel.attr_str(0)
ksel.arg_tensor_list(1)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
ts = kb.arg_bindings[1]
if len(ts) >= 1:
_emit_tensor_trace(kb, ksel.arg_descs[0].v, ts)
kb.yield_results()
2 changes: 1 addition & 1 deletion python/shark_turbine/runtime/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBuffe
memory_type=MemoryType.DEVICE_LOCAL,
allowed_usage=BufferUsage.DEFAULT,
device=hal_device,
buffer=t.numpy(),
buffer=t.detach().numpy(),
element_type=element_type,
)
return bv
Expand Down
Loading
Loading