From 54598bc61cbc8e80c8a9595ed5fbd13f24eae469 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Mon, 16 Oct 2023 23:02:32 -0700 Subject: [PATCH 1/4] Add eager mode support for simple elemwise and linear layer. --- python/shark_turbine/aot/passes/functorch.py | 1 + python/shark_turbine/dynamo/__init__.py | 1 + python/shark_turbine/dynamo/executor.py | 65 +++++++ python/shark_turbine/dynamo/tensor.py | 174 ++++++++++++++++++- tests/dynamo/tensor_test.py | 48 +++++ 5 files changed, 286 insertions(+), 3 deletions(-) diff --git a/python/shark_turbine/aot/passes/functorch.py b/python/shark_turbine/aot/passes/functorch.py index 52f3feb66..7fbf067ce 100644 --- a/python/shark_turbine/aot/passes/functorch.py +++ b/python/shark_turbine/aot/passes/functorch.py @@ -43,6 +43,7 @@ def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: functionalized_callable = _functionalize_callabale(gm) # TODO: There is more of a dance needed if the user has entered with a fake_mode. + # import pdb; pdb.set_trace() with proxy_tensor.maybe_disable_fake_tensor_mode(): new_gm = proxy_tensor.make_fx( functionalized_callable, diff --git a/python/shark_turbine/dynamo/__init__.py b/python/shark_turbine/dynamo/__init__.py index aa6d60d96..eff76341d 100644 --- a/python/shark_turbine/dynamo/__init__.py +++ b/python/shark_turbine/dynamo/__init__.py @@ -8,6 +8,7 @@ from .device import Device from .tensor import ( enable, + disable, TurbineMode, DeviceTensor, ) diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py index 939ac1449..83f6474ec 100644 --- a/python/shark_turbine/dynamo/executor.py +++ b/python/shark_turbine/dynamo/executor.py @@ -99,3 +99,68 @@ def _returns_to_user(self, ret_list: VmVariantList): user_returns[i] = torch_from_numpy(host_array) return user_returns + + + +class EagerSpecializedExecutable: + """A concrete executable that has been specialized in some way.""" + + __slots__ = [ + "device_state", + "entry_function", + "user_module", + "vm_context", + ] + + def __init__( + self, + user_module: VmModule, + device_state: DeviceState, + entry_name: str = "main", + ): + self.user_module = user_module + self.vm_context = VmContext( + device_state.instance, + ( + create_hal_module(device_state.instance, device_state.device), + user_module, + ), + ) + self.device_state = device_state + self.entry_function = self.user_module.lookup_function(entry_name) + + def __call__(self, *inputs): + arg_list = VmVariantList(len(inputs)) + ret_list = VmVariantList( + 1 + ) # TODO: Get the number of results from the descriptor. + + # Move inputs to the device and add to arguments. + self._inputs_to_device(inputs, arg_list) + # TODO: Append semaphores for async execution. + + # Invoke. + self.vm_context.invoke(self.entry_function, arg_list, ret_list) + return self._returns_to_user(ret_list) + + def _inputs_to_device(self, inputs: list, arg_list: VmVariantList): + # TODO: We are assuming the worst case here which is that we have unknown Torch + # tensors that we send to the CPU and make continguous. Ideally, we would have + # fast paths for our own backends and interop. + for input in inputs: + arg_list.push_ref(input.buffer_view) + + def _returns_to_user(self, ret_list: VmVariantList): + # TODO: This is also not good that we are moving back to the CPU like this. + # We should be returning a custom Tensor implementation which represents + # our device data and has synchronization hooks for accessing it. 
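+        # For now, each return ref is cast to a HalBufferView, wrapped in a DeviceArray,
+        # copied to the host, and handed back to the caller as a plain torch tensor.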
+ device = self.device_state.device + num_returns = len(ret_list) + user_returns = [None] * num_returns + for i in range(num_returns): + device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) + device_array = DeviceArray(device, device_buffer_view) + host_array = device_array.to_host() + user_returns[i] = torch_from_numpy(host_array) + + return user_returns diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py index 859b379cb..5828ed2e1 100644 --- a/python/shark_turbine/dynamo/tensor.py +++ b/python/shark_turbine/dynamo/tensor.py @@ -12,14 +12,22 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple +import functools +import sys +import atexit from array import array import numpy as np import torch +import torch._dynamo as dynamo from torch.overrides import TorchFunctionMode from .device import ( Device, + DeviceState, +) +from .executor import ( + EagerSpecializedExecutable, ) from ..support import ( @@ -33,6 +41,20 @@ HalCommandBuffer, HalElementType, HalFence, + VmModule, +) + +from iree.compiler.api import Session, Output +from iree.compiler.passmanager import PassManager + +from .importer import FxImporter +from .passes import turbine_cpu_pass_pipeline + +DEFAULT_COMPILER_FLAGS = ( + # Enable asynchronous calling convention. + # TODO: Enable async execution mode. + # "--iree-execution-model=async-external", + "--iree-input-type=tm_tensor", ) @@ -49,15 +71,18 @@ class TurbineMode(TorchFunctionMode): """ IMPLEMENTATIONS = {} + COMPUTE_METHODS = set() + CACHED_IMPLEMENTATIONS = {} - def __torch_function__(self, func, types, args=(), kwargs=None): + def __torch_function__(self, func, types, args=(), kwargs={}): def super_fn(*args, **kwargs): # Disable torch_function by hand because we don't want the wrapping behavior of # the super() impl with torch._C.DisableTorchFunction(): return func(*args, **kwargs) - if func in self.IMPLEMENTATIONS: + if func in self.COMPUTE_METHODS: + args += (func,) return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {}) # This is just a no-op for all the non-factory functions: @@ -67,7 +92,12 @@ def super_fn(*args, **kwargs): def enable(): """Enables PyTorch tensor device= support for Turbine permanently.""" TurbineMode().__enter__() + Device("local-task").set() + atexit.register(disable) +def disable(): + Device.current().clear() + TurbineMode().__exit__(None, None, None) # Convenient wrapper to register functions def raw_factory(func): @@ -79,6 +109,17 @@ def _inner_fn(impl): return _inner_fn +# Convenient wrapper to register functions +def compute_factory(func): + """Decorator to register an unconditional factory function.""" + + def _inner_fn(impl): + TurbineMode.IMPLEMENTATIONS[func] = impl + TurbineMode.COMPUTE_METHODS.add(func) + return impl + + return _inner_fn + def device_factory(func): """Decorator to invoke the user provided factory for our devices. @@ -178,6 +219,16 @@ def __del__(self): class DeviceTensor(torch.Tensor): """A Tensor accessing memory on a Turbine device.""" + def __torch_dispatch__(cls, func, types, args=(), kwargs={}): + # Now, we check the function to determine how to handle it. If it's + # aten.add, then we call aten.sub. 
Otherwise, we pass through to + # the original function + args += (func,) + return compute_method(func, *args, **kwargs) + + def _to_meta_tensor(self): + return torch.empty(self.shape, dtype=self.dtype) + @staticmethod def __new__(cls, size, dtype, raw_data=None, requires_grad=False): # Using a meta tensor as the wrapped gives us shape and dtype @@ -197,6 +248,15 @@ def __init__(self, size, dtype, raw_data=None, requires_grad=False): raise NotImplementedError( f"raw_data= not implemented for DeviceTensor ({raw_data.__class__})" ) + @staticmethod + def from_torch(input_tensor : torch.Tensor): + if isinstance(input_tensor, torch.Tensor): + dev_tensor = DeviceTensor._async_create_empty(input_tensor.size(), Device("local-task"), input_tensor.dtype) + dev_tensor._async_copy_from_host(input_tensor.numpy()) + return dev_tensor + else: + if input_tensor is not None: + raise ValueError("Expected input to be of type torch.Tensor.") @property def buffer_view(self) -> HalBufferView: @@ -211,6 +271,10 @@ def buffer_view(self) -> HalBufferView: def cpu(self): return self.to("cpu") + @property + def device(self): + return "turbine" + def __repr__(self): hal_device = self._storage.device.hal_device try: @@ -289,7 +353,7 @@ def _calculate_c_contig_size(size: Sequence[int], dtype: torch.dtype) -> int: # And some factory functions # By hand @raw_factory(torch.Tensor.to) -def to(super_fn, self, device): +def to(super_fn, self, device, dtype=None, non_blocking=None): # Note that we only implement a subset of .to() here turbine_device = _parse_device(device) if turbine_device: @@ -315,6 +379,17 @@ def to(super_fn, self, device): else: return super_fn(self, device) +@raw_factory(torch._C._nn._parse_to) +def _parse_to(super_fn, *args, **kwargs): + if "turbine" in args: + #TODO: Parse through args and kwargs for correct params. + device = "turbine" + dtype = None + non_blocking = False + convert_to_format = None + return device, dtype, non_blocking, convert_to_format + else: + return super_fn(self, device) @device_factory(torch.empty) def _empty(*size, device: Device, dtype=torch.float32): @@ -372,6 +447,99 @@ def _rand(*args, dtype=None): return t +@functools.lru_cache(maxsize=None) +def _get_device_state() -> DeviceState: + return DeviceState(driver="local-task") + +# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 +# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py +# TODO: Try to generalize for other devices. +@compute_factory(torch.add) +@compute_factory(torch.sub) +@compute_factory(torch.mul) +@compute_factory(torch.abs) +def compute_method(super_fn, *args, **kwargs): + # Compute factory fns reserve the last arg as src_op + # Requires src_op rather than super_fn, because super_fn + # is often wrapped by DisableTorchFunction. + py_args = args[:-1] + src_op = args[-1] + + # Check if turbine and if all devices are the same. + py_args =[DeviceTensor.from_torch(torch.tensor(py_arg)) if isinstance(py_arg, (int, float)) else py_arg for py_arg in py_args] + devices = {tensor.device for tensor in py_args} + if "turbine" not in devices: + super_fn(*py_args, **kwargs) + if len(devices) > 1: + raise ValueError("Turbine do not support mixed device!") + # Get a unique encoding to identify computation/dispatch using opCode, input shapes, and dtypes. 
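+    # The key combines the op with each argument's shape and dtype, so e.g. repeated
+    # torch.add calls on (2, 3) float32 tensors can reuse a single compiled dispatch.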
+ compute_id_encode = str(src_op) + "".join([str(py_arg.shape) + str(py_arg.dtype) for py_arg in py_args]) + compute_hash = hash(compute_id_encode) + if compute_hash in TurbineMode.CACHED_IMPLEMENTATIONS: + # TODO: Handle multiple output. + return DeviceTensor.from_torch(TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[0]) + + # Preprocess func and generate into FX. + flat_pytorch_args = [py_arg._to_meta_tensor() for py_arg in py_args] + # TODO: Replace all the below with torch.compile, although currently seems like + # the problem lies in it will try to generate DeviceTensor, but it would be missing + # _storage and causes error. + def func_src_op(*args, **kwargs): + return src_op(*args, **kwargs) + exported_f = dynamo.export( + func_src_op, + aten_graph=True, + decomposition_table={}, + constraints={}, + assume_static_by_default=True, + ) + gm, guards = exported_f(*flat_pytorch_args) + + # Setup mlir compilation pipeline. + session = Session() + session.set_flags(*DEFAULT_COMPILER_FLAGS) + session.set_flags("--iree-hal-target-backends=llvm-cpu") + context = session.context + + # Generate MLIR from FX. + importer = FxImporter(context=context) + module = importer.module + inv = session.invocation() + # TODO: Should capture diagnostics. + inv.enable_console_diagnostics() + inv.import_module(module.operation) + importer.import_graph_module(gm) + with context: + pm = PassManager.parse("builtin.module(torch-to-iree)") + pm.run(module.operation) + + # Compile MLIR to vmfb. + inv.execute() + output = Output.open_membuffer() + inv.output_vm_bytecode(output) + + # Map VMFB to buffer. + device_state = _get_device_state() + vmfb_module = VmModule.copy_buffer( + device_state.instance, + output.map_memory(), + ) + output.close() + + # Load and execute VMFB file. + exec = EagerSpecializedExecutable(vmfb_module, device_state) + res_host = exec(*py_args) + + TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash] = exec + + # Rewrap torch tensor into DeviceTensor and return. + # TODO: Handle multiple output. + # TODO: Refactor to not need to create new buffers every time once https://github.com/openxla/iree/pull/14997 lands. 
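+    # Copy the host-side result into a freshly allocated device buffer so the caller
+    # receives a DeviceTensor rather than a plain torch tensor.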
+ dev_res = DeviceTensor._async_create_empty(res_host[0].size(), Device("local-task"), res_host[0].dtype) + dev_res._async_copy_from_host(res_host[0].numpy()) + return dev_res + + ############################################################################### # Conversions ############################################################################### diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index 0a499c32d..625d322e7 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -79,6 +79,54 @@ def test_factory_rand(self): t1 = torch.rand(4, device="turbine", dtype=torch.float32) print(t1.cpu()) + def test_binary_op(self): + t1 = 5.3 * torch.ones(2, 3).to(device="turbine") + t2 = 2.3 * torch.ones(2, 3).to(device="turbine") + t3 = t1 * t2 + np.testing.assert_allclose(t3.cpu(), [[12.19, 12.19, 12.19], [12.19, 12.19, 12.19]]) + + def test_unary_op(self): + t1 = -5.3 * torch.ones(2, 3).to(device="turbine") + t2 = torch.abs(t1) + np.testing.assert_allclose(t2.cpu(), [[5.3, 5.3, 5.3], [5.3, 5.3, 5.3]]) + + def test_nn_linear(self): + m = torch.nn.Linear(20, 30) + input = torch.randn(128, 20) + ref_output = m(input) + m.to("turbine") + input = input.to("turbine") + turbine_output = m(input) + np.testing.assert_allclose(turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6) + + def test_nn_MLP(self): + class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer0 = torch.nn.Linear(64, 32, bias=True) + self.layer1 = torch.nn.Linear(32, 16, bias=True) + self.layer2 = torch.nn.Linear(16, 7, bias=True) + self.layer3 = torch.nn.Linear(7, 7, bias=True) + + def forward(self, x: torch.Tensor): + x = self.layer0(x) + x = torch.sigmoid(x) + x = self.layer1(x) + x = torch.sigmoid(x) + x = self.layer2(x) + x = torch.sigmoid(x) + x = self.layer3(x) + return x + + m = MLP() + input = torch.randn(16, 64) + ref_output = m(input) + m.to("turbine") + input = input.to("turbine") + turbine_output = m(input) + import pdb; pdb.set_trace() + np.testing.assert_allclose(turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) From 0d17c92fa54ef489edfbb9a8440b9eb6dfb890b9 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Sun, 22 Oct 2023 11:31:31 -0700 Subject: [PATCH 2/4] Refactor device passing, and preserve device buffer --- python/shark_turbine/dynamo/device.py | 1 + python/shark_turbine/dynamo/executor.py | 27 +++++-- python/shark_turbine/dynamo/tensor.py | 43 +++++++--- tests/dynamo/tensor_example.py | 103 ++++++++++++++++++++++++ tests/dynamo/tensor_test.py | 1 - 5 files changed, 157 insertions(+), 18 deletions(-) create mode 100644 tests/dynamo/tensor_example.py diff --git a/python/shark_turbine/dynamo/device.py b/python/shark_turbine/dynamo/device.py index 07181e6a3..39947d633 100644 --- a/python/shark_turbine/dynamo/device.py +++ b/python/shark_turbine/dynamo/device.py @@ -7,6 +7,7 @@ from functools import lru_cache from typing import List, Optional, Sequence, Union from threading import local, Lock +import torch from iree.runtime import ( asdevicearray, diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py index 83f6474ec..c80ab0ea3 100644 --- a/python/shark_turbine/dynamo/executor.py +++ b/python/shark_turbine/dynamo/executor.py @@ -6,11 +6,12 @@ import functools from typing import List, Optional, Sequence, Union - +from collections import namedtuple from iree.runtime import ( asdevicearray, create_hal_module, 
HalBufferView, + HalElementType, DeviceArray, get_driver, VmContext, @@ -21,6 +22,7 @@ VmVariantList, ) +import torch from torch import ( from_numpy as torch_from_numpy, ) @@ -32,6 +34,19 @@ def get_vm_instance() -> VmInstance: return VmInstance() +import numpy as np + +NUMPY_STR_TO_TORCH_DTYPE = { + np.dtypes.UInt8DType : torch.uint8, + np.dtypes.Int8DType : torch.int8, + np.dtypes.Int16DType : torch.int16, + np.dtypes.Int32DType : torch.int32, + np.dtypes.Int64DType : torch.int64, + np.dtypes.Float16DType : torch.float16, + np.dtypes.Float32DType : torch.float32, + np.dtypes.Float64DType : torch.float64, + np.dtypes.BoolDType : torch.bool, +} class SpecializedExecutable: """A concrete executable that has been specialized in some way.""" @@ -101,6 +116,7 @@ def _returns_to_user(self, ret_list: VmVariantList): return user_returns +EagerExecResult = namedtuple('EagerExecResult', ['buffer', 'size', 'dtype']) class EagerSpecializedExecutable: """A concrete executable that has been specialized in some way.""" @@ -159,8 +175,9 @@ def _returns_to_user(self, ret_list: VmVariantList): user_returns = [None] * num_returns for i in range(num_returns): device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) - device_array = DeviceArray(device, device_buffer_view) - host_array = device_array.to_host() - user_returns[i] = torch_from_numpy(host_array) - + npy_dtype = HalElementType.map_to_dtype(device_buffer_view.element_type) + size = torch.Size(device_buffer_view.shape) + dtype = NUMPY_STR_TO_TORCH_DTYPE[type(npy_dtype)] + device_buffer = device_buffer_view.get_buffer() + user_returns[i] = EagerExecResult(device_buffer, size, dtype) return user_returns diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py index 5828ed2e1..996c5f001 100644 --- a/python/shark_turbine/dynamo/tensor.py +++ b/python/shark_turbine/dynamo/tensor.py @@ -273,7 +273,7 @@ def cpu(self): @property def device(self): - return "turbine" + return self._storage.device def __repr__(self): hal_device = self._storage.device.hal_device @@ -303,6 +303,14 @@ def _async_create_empty( storage.ready_fence.insert(*alloca_complete_semaphore) return DeviceTensor(size, dtype, raw_data=storage) + @staticmethod + def _from_buffer( + buffer: HalBuffer, size: Sequence[int], dtype: torch.dtype, device: Device + ) -> "DeviceTensor": + """Creates an uninitialized tensor with a given size and dtype.""" + storage = Storage(device, buffer) + return DeviceTensor(size, dtype, raw_data=storage) + def _async_fill_py_value(self, value): """Fills a value in all elements of the tensor. @@ -465,19 +473,32 @@ def compute_method(super_fn, *args, **kwargs): py_args = args[:-1] src_op = args[-1] - # Check if turbine and if all devices are the same. - py_args =[DeviceTensor.from_torch(torch.tensor(py_arg)) if isinstance(py_arg, (int, float)) else py_arg for py_arg in py_args] - devices = {tensor.device for tensor in py_args} - if "turbine" not in devices: + any_turbine_tensor = False + devices_set = set() + arg_shape_dtype_encode = [] + for arg_idx, py_arg in enumerate(py_args): + if isinstance(py_arg, DeviceTensor): + any_turbine_tensor = True + if isinstance(py_arg, (int, float)): + py_arg[arg_idx] = DeviceTensor.from_torch(torch.tensor(py_arg)) + devices_set.add(py_args[arg_idx].device) + arg_shape_dtype_encode.append(str(py_arg.shape) + str(py_arg.dtype)) + + # Check if turbine device exist. If doesn't run regular fn. 
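+    # No Turbine tensor among the arguments: defer to the regular PyTorch implementation.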
+ if not any_turbine_tensor: super_fn(*py_args, **kwargs) - if len(devices) > 1: + + # Do not support interop between Turbine and other devices. + if len(devices_set) > 1: raise ValueError("Turbine do not support mixed device!") + cur_device = py_args[0].device # Get a unique encoding to identify computation/dispatch using opCode, input shapes, and dtypes. - compute_id_encode = str(src_op) + "".join([str(py_arg.shape) + str(py_arg.dtype) for py_arg in py_args]) + compute_id_encode = src_op.name() + "".join(arg_shape_dtype_encode) compute_hash = hash(compute_id_encode) if compute_hash in TurbineMode.CACHED_IMPLEMENTATIONS: # TODO: Handle multiple output. - return DeviceTensor.from_torch(TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[0]) + exec_res = TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[0] + return DeviceTensor._from_buffer(exec_res.buffer, exec_res.size, exec_res.dtype, cur_device) # Preprocess func and generate into FX. flat_pytorch_args = [py_arg._to_meta_tensor() for py_arg in py_args] @@ -528,15 +549,13 @@ def func_src_op(*args, **kwargs): # Load and execute VMFB file. exec = EagerSpecializedExecutable(vmfb_module, device_state) - res_host = exec(*py_args) + exec_res = exec(*py_args)[0] TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash] = exec # Rewrap torch tensor into DeviceTensor and return. # TODO: Handle multiple output. - # TODO: Refactor to not need to create new buffers every time once https://github.com/openxla/iree/pull/14997 lands. - dev_res = DeviceTensor._async_create_empty(res_host[0].size(), Device("local-task"), res_host[0].dtype) - dev_res._async_copy_from_host(res_host[0].numpy()) + dev_res = DeviceTensor._from_buffer(exec_res.buffer, exec_res.size, exec_res.dtype, cur_device) return dev_res diff --git a/tests/dynamo/tensor_example.py b/tests/dynamo/tensor_example.py new file mode 100644 index 000000000..ab46a6820 --- /dev/null +++ b/tests/dynamo/tensor_example.py @@ -0,0 +1,103 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import time +import unittest + +import numpy as np +import torch +from viztracer import VizTracer + + +logging.basicConfig(level=logging.DEBUG) +# Public API imports. 
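+# enable() enters TurbineMode and sets the default "local-task" device; disable()
+# (also registered via atexit) tears both down again.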
+from shark_turbine.dynamo import TurbineMode, enable, disable + +enable() + +def unary(): + t1 = -5*torch.ones(2, 3) + t1 = t1.to(device="turbine") + t2 = torch.abs(t1) + print(t2.cpu()) + return t2 + +def binary(): + t1 = 5*torch.ones(2, 3) + t1 = t1.to(device="turbine") + # mm = torch.matmul(t1, t2) + for _ in range(10): + t1 = t1 + 4 + print(t1.cpu()) + +def matmul(): + t1 = (5*torch.ones(2, 3)).to(device="turbine") + t2 = (3*torch.ones(3, 2)).to(device="turbine") + t3 = torch.matmul(t1, t2) + print(t3.cpu()) + return t3 + +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer0 = torch.nn.Linear(64, 32, bias=True) + self.layer1 = torch.nn.Linear(32, 16, bias=True) + self.layer2 = torch.nn.Linear(16, 7, bias=True) + self.layer3 = torch.nn.Linear(7, 7, bias=True) + + def forward(self, x: torch.Tensor): + x = self.layer0(x) + x = torch.sigmoid(x) + x = self.layer1(x) + x = torch.sigmoid(x) + x = self.layer2(x) + x = torch.sigmoid(x) + x = self.layer3(x) + return x + +def MLP_run(): + m = MLP() + input = torch.randn(16, 64) + iter = 100 + start = time.time() + with torch.no_grad(): + for i in range(iter): + ref_out = m(input) + end = time.time() + print(f"Regular speed: {iter/(end-start)} it / sec") + print(ref_out) + m.to("turbine") + input = input.to("turbine") + turbine_output = m(input) + start = time.time() + # tracer = VizTracer() + with torch.no_grad(): + # tracer.start() + for i in range(iter): + turbine_output = m(input) + # tracer.stop() + end = time.time() + # tracer.save("turbine_run.json") + print(f"Turbine speed: {iter/(end-start)} it / sec") + print(turbine_output.cpu()) + return + +def linear(): + m = torch.nn.Linear(20, 30) + input = torch.randn(128, 20) + m.to("turbine") + d_input = input.to("turbine") + iter = 10 + start = time.time() + for i in range(iter): + output = m(d_input) + end = time.time() + print(f"{10/(end-start)} it / sec") + print(output.cpu()) + +if __name__ == "__main__": + MLP_run() diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index 625d322e7..b3906d4de 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -124,7 +124,6 @@ def forward(self, x: torch.Tensor): m.to("turbine") input = input.to("turbine") turbine_output = m(input) - import pdb; pdb.set_trace() np.testing.assert_allclose(turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6) From 3d9bf0072a4c51245ebf5f486887929f560836c9 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Sun, 22 Oct 2023 18:01:04 -0700 Subject: [PATCH 3/4] Add Async Execution support + Fix refactor to pass test. --- python/shark_turbine/aot/passes/functorch.py | 1 - python/shark_turbine/dynamo/__init__.py | 1 - python/shark_turbine/dynamo/device.py | 1 - python/shark_turbine/dynamo/executor.py | 101 +++++++++++++----- python/shark_turbine/dynamo/tensor.py | 87 +++++++++++----- tests/dynamo/tensor_example.py | 103 ------------------- tests/dynamo/tensor_test.py | 12 ++- 7 files changed, 147 insertions(+), 159 deletions(-) delete mode 100644 tests/dynamo/tensor_example.py diff --git a/python/shark_turbine/aot/passes/functorch.py b/python/shark_turbine/aot/passes/functorch.py index 7fbf067ce..52f3feb66 100644 --- a/python/shark_turbine/aot/passes/functorch.py +++ b/python/shark_turbine/aot/passes/functorch.py @@ -43,7 +43,6 @@ def functorch_functionalize(gm: GraphModule, *args) -> GraphModule: functionalized_callable = _functionalize_callabale(gm) # TODO: There is more of a dance needed if the user has entered with a fake_mode. 
- # import pdb; pdb.set_trace() with proxy_tensor.maybe_disable_fake_tensor_mode(): new_gm = proxy_tensor.make_fx( functionalized_callable, diff --git a/python/shark_turbine/dynamo/__init__.py b/python/shark_turbine/dynamo/__init__.py index eff76341d..aa6d60d96 100644 --- a/python/shark_turbine/dynamo/__init__.py +++ b/python/shark_turbine/dynamo/__init__.py @@ -8,7 +8,6 @@ from .device import Device from .tensor import ( enable, - disable, TurbineMode, DeviceTensor, ) diff --git a/python/shark_turbine/dynamo/device.py b/python/shark_turbine/dynamo/device.py index 39947d633..07181e6a3 100644 --- a/python/shark_turbine/dynamo/device.py +++ b/python/shark_turbine/dynamo/device.py @@ -7,7 +7,6 @@ from functools import lru_cache from typing import List, Optional, Sequence, Union from threading import local, Lock -import torch from iree.runtime import ( asdevicearray, diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py index c80ab0ea3..92473fb49 100644 --- a/python/shark_turbine/dynamo/executor.py +++ b/python/shark_turbine/dynamo/executor.py @@ -5,12 +5,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import functools +import os from typing import List, Optional, Sequence, Union -from collections import namedtuple +from dataclasses import dataclass from iree.runtime import ( asdevicearray, create_hal_module, + HalBuffer, HalBufferView, + HalFence, HalElementType, DeviceArray, get_driver, @@ -27,27 +30,34 @@ from_numpy as torch_from_numpy, ) -from .device import DeviceState +from .device import Device, DeviceState + +turbine_exec_model = os.getenv("TURBINE_EXEC_MODEL", "default") @functools.lru_cache(maxsize=None) def get_vm_instance() -> VmInstance: return VmInstance() -import numpy as np - -NUMPY_STR_TO_TORCH_DTYPE = { - np.dtypes.UInt8DType : torch.uint8, - np.dtypes.Int8DType : torch.int8, - np.dtypes.Int16DType : torch.int16, - np.dtypes.Int32DType : torch.int32, - np.dtypes.Int64DType : torch.int64, - np.dtypes.Float16DType : torch.float16, - np.dtypes.Float32DType : torch.float32, - np.dtypes.Float64DType : torch.float64, - np.dtypes.BoolDType : torch.bool, + +_ELEMENT_TYPE_TO_DTYPE = { + HalElementType.FLOAT_16: torch.float16, + HalElementType.BFLOAT_16: torch.bfloat16, + HalElementType.FLOAT_32: torch.float32, + HalElementType.FLOAT_64: torch.float64, + HalElementType.UINT_8: torch.uint8, + HalElementType.SINT_8: torch.int8, + HalElementType.SINT_16: torch.int16, + HalElementType.SINT_32: torch.int32, + HalElementType.SINT_64: torch.int64, + HalElementType.BOOL_8: torch.bool, + HalElementType.OPAQUE_8: torch.qint8, + HalElementType.OPAQUE_8: torch.quint8, + HalElementType.COMPLEX_64: torch.complex64, + HalElementType.COMPLEX_128: torch.complex128, } + class SpecializedExecutable: """A concrete executable that has been specialized in some way.""" @@ -116,7 +126,20 @@ def _returns_to_user(self, ret_list: VmVariantList): return user_returns -EagerExecResult = namedtuple('EagerExecResult', ['buffer', 'size', 'dtype']) +@dataclass +class EagerExecResult: + buffer: HalBuffer + size: int + dtype: torch.dtype + signal: HalFence = None + + +def _element_type_to_dtype(element_type) -> torch.dtype: + try: + return _ELEMENT_TYPE_TO_DTYPE[element_type] + except KeyError: + raise ValueError(f"Unable to map {element_type} to torch dtype.") + class EagerSpecializedExecutable: """A concrete executable that has been specialized in some way.""" @@ -151,22 +174,36 @@ def __call__(self, *inputs): 1 ) # TODO: Get the number of results from the 
descriptor. + # Initialize wait and signal fence if not async mode. + device = inputs[0]._storage.device + wait_fence, signal_fence = self._initialize_fences(device, inputs, arg_list) + # Move inputs to the device and add to arguments. - self._inputs_to_device(inputs, arg_list) - # TODO: Append semaphores for async execution. + self._inputs_to_device(inputs, arg_list, wait_fence, signal_fence) # Invoke. self.vm_context.invoke(self.entry_function, arg_list, ret_list) - return self._returns_to_user(ret_list) + return self._returns_to_user(ret_list, signal_fence) - def _inputs_to_device(self, inputs: list, arg_list: VmVariantList): + def _inputs_to_device( + self, + inputs: list, + arg_list: VmVariantList, + wait_fence: HalFence = None, + signal_fence: HalFence = None, + ): # TODO: We are assuming the worst case here which is that we have unknown Torch # tensors that we send to the CPU and make continguous. Ideally, we would have # fast paths for our own backends and interop. for input in inputs: arg_list.push_ref(input.buffer_view) + wait_fence.extend(input._storage.ready_fence) - def _returns_to_user(self, ret_list: VmVariantList): + # Append fences into list. + arg_list.push_ref(wait_fence) + arg_list.push_ref(signal_fence) + + def _returns_to_user(self, ret_list: VmVariantList, signal: HalFence = None): # TODO: This is also not good that we are moving back to the CPU like this. # We should be returning a custom Tensor implementation which represents # our device data and has synchronization hooks for accessing it. @@ -175,9 +212,27 @@ def _returns_to_user(self, ret_list: VmVariantList): user_returns = [None] * num_returns for i in range(num_returns): device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) - npy_dtype = HalElementType.map_to_dtype(device_buffer_view.element_type) + dtype = _element_type_to_dtype(device_buffer_view.element_type) size = torch.Size(device_buffer_view.shape) - dtype = NUMPY_STR_TO_TORCH_DTYPE[type(npy_dtype)] device_buffer = device_buffer_view.get_buffer() - user_returns[i] = EagerExecResult(device_buffer, size, dtype) + user_returns[i] = EagerExecResult(device_buffer, size, dtype, signal) return user_returns + + def _initialize_fences(self, device: Device, inputs: list, arg_list: VmVariantList): + fence_capacity = device._fence_capacity + tx_semaphore = device._tx_timeline + current_tx_timepoint = device._tx_timepoint + + # Create wait semaphore and fence. + wait_semaphores = (tx_semaphore, current_tx_timepoint) + wait_fence = HalFence(fence_capacity) + wait_fence.insert(*wait_semaphores) + + # Create signal semaphore and fence. + device._tx_timepoint += 1 + signals_semaphore = (tx_semaphore, current_tx_timepoint + 1) + signal_fence = HalFence(fence_capacity) + signal_fence.insert(*signals_semaphore) + + # Add fences into arg_list for async exec. 
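+        # The wait fence is keyed to the device's current timeline value and the signal
+        # fence to the next one, so this dispatch waits on prior work and advances the
+        # timeline for whichever consumer reads the result.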
+ return wait_fence, signal_fence diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py index 996c5f001..a4ee3777d 100644 --- a/python/shark_turbine/dynamo/tensor.py +++ b/python/shark_turbine/dynamo/tensor.py @@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import functools -import sys import atexit from array import array import numpy as np +from types import BuiltinFunctionType import torch import torch._dynamo as dynamo @@ -26,9 +26,7 @@ Device, DeviceState, ) -from .executor import ( - EagerSpecializedExecutable, -) +from .executor import EagerSpecializedExecutable from ..support import ( ApiSequencingError, @@ -48,16 +46,13 @@ from iree.compiler.passmanager import PassManager from .importer import FxImporter -from .passes import turbine_cpu_pass_pipeline DEFAULT_COMPILER_FLAGS = ( # Enable asynchronous calling convention. - # TODO: Enable async execution mode. - # "--iree-execution-model=async-external", + "--iree-execution-model=async-external", "--iree-input-type=tm_tensor", ) - ############################################################################### # Factories and device enablement ############################################################################### @@ -80,6 +75,7 @@ def super_fn(*args, **kwargs): # the super() impl with torch._C.DisableTorchFunction(): return func(*args, **kwargs) + if func in self.IMPLEMENTATIONS: if func in self.COMPUTE_METHODS: args += (func,) @@ -95,10 +91,12 @@ def enable(): Device("local-task").set() atexit.register(disable) + def disable(): Device.current().clear() TurbineMode().__exit__(None, None, None) + # Convenient wrapper to register functions def raw_factory(func): """Decorator to register an unconditional factory function.""" @@ -109,6 +107,7 @@ def _inner_fn(impl): return _inner_fn + # Convenient wrapper to register functions def compute_factory(func): """Decorator to register an unconditional factory function.""" @@ -219,6 +218,7 @@ def __del__(self): class DeviceTensor(torch.Tensor): """A Tensor accessing memory on a Turbine device.""" + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs={}): # Now, we check the function to determine how to handle it. If it's # aten.add, then we call aten.sub. 
Otherwise, we pass through to @@ -248,10 +248,13 @@ def __init__(self, size, dtype, raw_data=None, requires_grad=False): raise NotImplementedError( f"raw_data= not implemented for DeviceTensor ({raw_data.__class__})" ) + @staticmethod - def from_torch(input_tensor : torch.Tensor): + def from_torch(input_tensor: torch.Tensor): if isinstance(input_tensor, torch.Tensor): - dev_tensor = DeviceTensor._async_create_empty(input_tensor.size(), Device("local-task"), input_tensor.dtype) + dev_tensor = DeviceTensor._async_create_empty( + input_tensor.size(), Device("local-task"), input_tensor.dtype + ) dev_tensor._async_copy_from_host(input_tensor.numpy()) return dev_tensor else: @@ -305,10 +308,16 @@ def _async_create_empty( @staticmethod def _from_buffer( - buffer: HalBuffer, size: Sequence[int], dtype: torch.dtype, device: Device + buffer: HalBuffer, + size: Sequence[int], + dtype: torch.dtype, + device: Device, + signal: HalFence, ) -> "DeviceTensor": """Creates an uninitialized tensor with a given size and dtype.""" storage = Storage(device, buffer) + if signal is not None: + storage.ready_fence = signal return DeviceTensor(size, dtype, raw_data=storage) def _async_fill_py_value(self, value): @@ -387,18 +396,20 @@ def to(super_fn, self, device, dtype=None, non_blocking=None): else: return super_fn(self, device) + @raw_factory(torch._C._nn._parse_to) def _parse_to(super_fn, *args, **kwargs): if "turbine" in args: - #TODO: Parse through args and kwargs for correct params. - device = "turbine" - dtype = None - non_blocking = False - convert_to_format = None - return device, dtype, non_blocking, convert_to_format + # TODO: Parse through args and kwargs for correct params. + device = "turbine" + dtype = None + non_blocking = False + convert_to_format = None + return device, dtype, non_blocking, convert_to_format else: return super_fn(self, device) + @device_factory(torch.empty) def _empty(*size, device: Device, dtype=torch.float32): # Turbine empty. @@ -459,6 +470,7 @@ def _rand(*args, dtype=None): def _get_device_state() -> DeviceState: return DeviceState(driver="local-task") + # Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 # and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py # TODO: Try to generalize for other devices. @@ -470,19 +482,22 @@ def compute_method(super_fn, *args, **kwargs): # Compute factory fns reserve the last arg as src_op # Requires src_op rather than super_fn, because super_fn # is often wrapped by DisableTorchFunction. - py_args = args[:-1] + init_py_args = args[:-1] src_op = args[-1] any_turbine_tensor = False devices_set = set() arg_shape_dtype_encode = [] - for arg_idx, py_arg in enumerate(py_args): + py_args = [] + for arg_idx, py_arg in enumerate(init_py_args): + ret_val = py_arg if isinstance(py_arg, DeviceTensor): any_turbine_tensor = True if isinstance(py_arg, (int, float)): - py_arg[arg_idx] = DeviceTensor.from_torch(torch.tensor(py_arg)) - devices_set.add(py_args[arg_idx].device) - arg_shape_dtype_encode.append(str(py_arg.shape) + str(py_arg.dtype)) + ret_val = DeviceTensor.from_torch(torch.tensor(py_arg)) + devices_set.add(ret_val.device) + arg_shape_dtype_encode.append(str(ret_val.shape) + str(ret_val.dtype)) + py_args.append(ret_val) # Check if turbine device exist. If doesn't run regular fn. 
if not any_turbine_tensor: @@ -493,20 +508,33 @@ def compute_method(super_fn, *args, **kwargs): raise ValueError("Turbine do not support mixed device!") cur_device = py_args[0].device # Get a unique encoding to identify computation/dispatch using opCode, input shapes, and dtypes. - compute_id_encode = src_op.name() + "".join(arg_shape_dtype_encode) + if isinstance(src_op, torch._ops.OpOverload): + src_op_name = src_op.name() + elif isinstance(src_op, BuiltinFunctionType): + src_op_name = src_op.__name__ + else: + raise ValueError("Expected srcOp to be torchOp or builtinFn.") + compute_id_encode = src_op_name + "".join(arg_shape_dtype_encode) compute_hash = hash(compute_id_encode) if compute_hash in TurbineMode.CACHED_IMPLEMENTATIONS: # TODO: Handle multiple output. - exec_res = TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[0] - return DeviceTensor._from_buffer(exec_res.buffer, exec_res.size, exec_res.dtype, cur_device) + exec_res = TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[ + 0 + ] + res_buf = DeviceTensor._from_buffer( + exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal + ) + return res_buf # Preprocess func and generate into FX. flat_pytorch_args = [py_arg._to_meta_tensor() for py_arg in py_args] + # TODO: Replace all the below with torch.compile, although currently seems like # the problem lies in it will try to generate DeviceTensor, but it would be missing # _storage and causes error. def func_src_op(*args, **kwargs): return src_op(*args, **kwargs) + exported_f = dynamo.export( func_src_op, aten_graph=True, @@ -549,13 +577,18 @@ def func_src_op(*args, **kwargs): # Load and execute VMFB file. exec = EagerSpecializedExecutable(vmfb_module, device_state) - exec_res = exec(*py_args)[0] + exec_results = exec(*py_args) + if len(exec_results) != 1: + raise ValueError("Currently only support one output for now.") + exec_res = exec_results[0] TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash] = exec # Rewrap torch tensor into DeviceTensor and return. # TODO: Handle multiple output. - dev_res = DeviceTensor._from_buffer(exec_res.buffer, exec_res.size, exec_res.dtype, cur_device) + dev_res = DeviceTensor._from_buffer( + exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal + ) return dev_res diff --git a/tests/dynamo/tensor_example.py b/tests/dynamo/tensor_example.py deleted file mode 100644 index ab46a6820..000000000 --- a/tests/dynamo/tensor_example.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import time -import unittest - -import numpy as np -import torch -from viztracer import VizTracer - - -logging.basicConfig(level=logging.DEBUG) -# Public API imports. 
-from shark_turbine.dynamo import TurbineMode, enable, disable - -enable() - -def unary(): - t1 = -5*torch.ones(2, 3) - t1 = t1.to(device="turbine") - t2 = torch.abs(t1) - print(t2.cpu()) - return t2 - -def binary(): - t1 = 5*torch.ones(2, 3) - t1 = t1.to(device="turbine") - # mm = torch.matmul(t1, t2) - for _ in range(10): - t1 = t1 + 4 - print(t1.cpu()) - -def matmul(): - t1 = (5*torch.ones(2, 3)).to(device="turbine") - t2 = (3*torch.ones(3, 2)).to(device="turbine") - t3 = torch.matmul(t1, t2) - print(t3.cpu()) - return t3 - -class MLP(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer0 = torch.nn.Linear(64, 32, bias=True) - self.layer1 = torch.nn.Linear(32, 16, bias=True) - self.layer2 = torch.nn.Linear(16, 7, bias=True) - self.layer3 = torch.nn.Linear(7, 7, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - -def MLP_run(): - m = MLP() - input = torch.randn(16, 64) - iter = 100 - start = time.time() - with torch.no_grad(): - for i in range(iter): - ref_out = m(input) - end = time.time() - print(f"Regular speed: {iter/(end-start)} it / sec") - print(ref_out) - m.to("turbine") - input = input.to("turbine") - turbine_output = m(input) - start = time.time() - # tracer = VizTracer() - with torch.no_grad(): - # tracer.start() - for i in range(iter): - turbine_output = m(input) - # tracer.stop() - end = time.time() - # tracer.save("turbine_run.json") - print(f"Turbine speed: {iter/(end-start)} it / sec") - print(turbine_output.cpu()) - return - -def linear(): - m = torch.nn.Linear(20, 30) - input = torch.randn(128, 20) - m.to("turbine") - d_input = input.to("turbine") - iter = 10 - start = time.time() - for i in range(iter): - output = m(d_input) - end = time.time() - print(f"{10/(end-start)} it / sec") - print(output.cpu()) - -if __name__ == "__main__": - MLP_run() diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index b3906d4de..d736c153f 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -83,7 +83,9 @@ def test_binary_op(self): t1 = 5.3 * torch.ones(2, 3).to(device="turbine") t2 = 2.3 * torch.ones(2, 3).to(device="turbine") t3 = t1 * t2 - np.testing.assert_allclose(t3.cpu(), [[12.19, 12.19, 12.19], [12.19, 12.19, 12.19]]) + np.testing.assert_allclose( + t3.cpu(), [[12.19, 12.19, 12.19], [12.19, 12.19, 12.19]] + ) def test_unary_op(self): t1 = -5.3 * torch.ones(2, 3).to(device="turbine") @@ -97,7 +99,9 @@ def test_nn_linear(self): m.to("turbine") input = input.to("turbine") turbine_output = m(input) - np.testing.assert_allclose(turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6) + np.testing.assert_allclose( + turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6 + ) def test_nn_MLP(self): class MLP(torch.nn.Module): @@ -124,7 +128,9 @@ def forward(self, x: torch.Tensor): m.to("turbine") input = input.to("turbine") turbine_output = m(input) - np.testing.assert_allclose(turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6) + np.testing.assert_allclose( + turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6 + ) if __name__ == "__main__": From 5fd4c476200ef893775b4d1b8792579f41a61fe8 Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Fri, 27 Oct 2023 13:53:43 -0700 Subject: [PATCH 4/4] Refactoring and cleanups(tm input, compute_factory, kwargs, wrap_buffer) --- python/shark_turbine/dynamo/executor.py | 4 +--- 
python/shark_turbine/dynamo/tensor.py | 17 +++++------------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py index 92473fb49..417946845 100644 --- a/python/shark_turbine/dynamo/executor.py +++ b/python/shark_turbine/dynamo/executor.py @@ -32,8 +32,6 @@ from .device import Device, DeviceState -turbine_exec_model = os.getenv("TURBINE_EXEC_MODEL", "default") - @functools.lru_cache(maxsize=None) def get_vm_instance() -> VmInstance: @@ -131,7 +129,7 @@ class EagerExecResult: buffer: HalBuffer size: int dtype: torch.dtype - signal: HalFence = None + signal: Optional[HalFence] = None def _element_type_to_dtype(element_type) -> torch.dtype: diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py index a4ee3777d..809eae088 100644 --- a/python/shark_turbine/dynamo/tensor.py +++ b/python/shark_turbine/dynamo/tensor.py @@ -50,7 +50,7 @@ DEFAULT_COMPILER_FLAGS = ( # Enable asynchronous calling convention. "--iree-execution-model=async-external", - "--iree-input-type=tm_tensor", + "--iree-input-type=torch", ) ############################################################################### @@ -66,10 +66,10 @@ class TurbineMode(TorchFunctionMode): """ IMPLEMENTATIONS = {} - COMPUTE_METHODS = set() CACHED_IMPLEMENTATIONS = {} + COMPUTE_METHODS = set((torch.add, torch.sub, torch.mul, torch.abs)) - def __torch_function__(self, func, types, args=(), kwargs={}): + def __torch_function__(self, func, types, args=(), kwargs=None): def super_fn(*args, **kwargs): # Disable torch_function by hand because we don't want the wrapping behavior of # the super() impl @@ -474,10 +474,6 @@ def _get_device_state() -> DeviceState: # Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 # and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py # TODO: Try to generalize for other devices. -@compute_factory(torch.add) -@compute_factory(torch.sub) -@compute_factory(torch.mul) -@compute_factory(torch.abs) def compute_method(super_fn, *args, **kwargs): # Compute factory fns reserve the last arg as src_op # Requires src_op rather than super_fn, because super_fn @@ -558,9 +554,6 @@ def func_src_op(*args, **kwargs): inv.enable_console_diagnostics() inv.import_module(module.operation) importer.import_graph_module(gm) - with context: - pm = PassManager.parse("builtin.module(torch-to-iree)") - pm.run(module.operation) # Compile MLIR to vmfb. inv.execute() @@ -569,11 +562,11 @@ def func_src_op(*args, **kwargs): # Map VMFB to buffer. device_state = _get_device_state() - vmfb_module = VmModule.copy_buffer( + vmfb_module = VmModule.wrap_buffer( device_state.instance, output.map_memory(), + destroy_callback=output.close, ) - output.close() # Load and execute VMFB file. exec = EagerSpecializedExecutable(vmfb_module, device_state)
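
For reference, the eager path added by this series is exercised roughly as in the
tests above. A minimal usage sketch (assuming shark_turbine is installed and the
"turbine" device string these patches register):

    import torch
    from shark_turbine.dynamo import enable

    enable()                                      # install TurbineMode, default local-task device
    t1 = (5.3 * torch.ones(2, 3)).to("turbine")   # host tensor copied into a DeviceTensor
    t2 = (2.3 * torch.ones(2, 3)).to("turbine")
    t3 = t1 * t2                                  # compiled via IREE, cached by op/shape/dtype
    print(t3.cpu())                               # synchronize and copy back to a CPU tensor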