diff --git a/python/shark_turbine/dynamo/executor.py b/python/shark_turbine/dynamo/executor.py
index 939ac1449..417946845 100644
--- a/python/shark_turbine/dynamo/executor.py
+++ b/python/shark_turbine/dynamo/executor.py
@@ -5,12 +5,16 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import functools
+import os
 from typing import List, Optional, Sequence, Union
-
+from dataclasses import dataclass
 from iree.runtime import (
     asdevicearray,
     create_hal_module,
+    HalBuffer,
     HalBufferView,
+    HalFence,
+    HalElementType,
     DeviceArray,
     get_driver,
     VmContext,
@@ -21,11 +25,12 @@
     VmVariantList,
 )
 
+import torch
 from torch import (
     from_numpy as torch_from_numpy,
 )
 
-from .device import DeviceState
+from .device import Device, DeviceState
 
 
 @functools.lru_cache(maxsize=None)
@@ -33,6 +38,24 @@ def get_vm_instance() -> VmInstance:
     return VmInstance()
 
 
+_ELEMENT_TYPE_TO_DTYPE = {
+    HalElementType.FLOAT_16: torch.float16,
+    HalElementType.BFLOAT_16: torch.bfloat16,
+    HalElementType.FLOAT_32: torch.float32,
+    HalElementType.FLOAT_64: torch.float64,
+    HalElementType.UINT_8: torch.uint8,
+    HalElementType.SINT_8: torch.int8,
+    HalElementType.SINT_16: torch.int16,
+    HalElementType.SINT_32: torch.int32,
+    HalElementType.SINT_64: torch.int64,
+    HalElementType.BOOL_8: torch.bool,
+    # OPAQUE_8 cannot be disambiguated into qint8 vs. quint8 from the element type alone.
+    HalElementType.OPAQUE_8: torch.quint8,
+    HalElementType.COMPLEX_64: torch.complex64,
+    HalElementType.COMPLEX_128: torch.complex128,
+}
+
+
 class SpecializedExecutable:
     """A concrete executable that has been specialized in some way."""
 
@@ -99,3 +122,115 @@ def _returns_to_user(self, ret_list: VmVariantList):
             user_returns[i] = torch_from_numpy(host_array)
 
         return user_returns
+
+
+@dataclass
+class EagerExecResult:
+    buffer: HalBuffer
+    size: int
+    dtype: torch.dtype
+    signal: Optional[HalFence] = None
+
+
+def _element_type_to_dtype(element_type) -> torch.dtype:
+    try:
+        return _ELEMENT_TYPE_TO_DTYPE[element_type]
+    except KeyError:
+        raise ValueError(f"Unable to map {element_type} to torch dtype.")
+
+
+class EagerSpecializedExecutable:
+    """A concrete executable, specialized for eager execution with async fences."""
+
+    __slots__ = [
+        "device_state",
+        "entry_function",
+        "user_module",
+        "vm_context",
+    ]
+
+    def __init__(
+        self,
+        user_module: VmModule,
+        device_state: DeviceState,
+        entry_name: str = "main",
+    ):
+        self.user_module = user_module
+        self.vm_context = VmContext(
+            device_state.instance,
+            (
+                create_hal_module(device_state.instance, device_state.device),
+                user_module,
+            ),
+        )
+        self.device_state = device_state
+        self.entry_function = self.user_module.lookup_function(entry_name)
+
+    def __call__(self, *inputs):
+        arg_list = VmVariantList(len(inputs))
+        ret_list = VmVariantList(
+            1
+        )  # TODO: Get the number of results from the descriptor.
+
+        # Initialize wait and signal fences for async execution.
+        device = inputs[0]._storage.device
+        wait_fence, signal_fence = self._initialize_fences(device, inputs, arg_list)
+
+        # Move inputs to the device and add to arguments.
+        self._inputs_to_device(inputs, arg_list, wait_fence, signal_fence)
+
+        # Invoke.
+        self.vm_context.invoke(self.entry_function, arg_list, ret_list)
+        return self._returns_to_user(ret_list, signal_fence)
+
+    def _inputs_to_device(
+        self,
+        inputs: list,
+        arg_list: VmVariantList,
+        wait_fence: HalFence = None,
+        signal_fence: HalFence = None,
+    ):
+        # TODO: We are assuming the worst case here, which is that we have unknown Torch
+        # tensors that we send to the CPU and make contiguous.
+        # Ideally, we would have fast paths for our own backends and interop.
+        for input in inputs:
+            arg_list.push_ref(input.buffer_view)
+            wait_fence.extend(input._storage.ready_fence)
+
+        # Append fences into list.
+        arg_list.push_ref(wait_fence)
+        arg_list.push_ref(signal_fence)
+
+    def _returns_to_user(self, ret_list: VmVariantList, signal: HalFence = None):
+        # TODO: Rather than handing back raw buffer records, we should return a custom
+        # Tensor implementation which represents our device data and has
+        # synchronization hooks for accessing it.
+        device = self.device_state.device
+        num_returns = len(ret_list)
+        user_returns = [None] * num_returns
+        for i in range(num_returns):
+            device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
+            dtype = _element_type_to_dtype(device_buffer_view.element_type)
+            size = torch.Size(device_buffer_view.shape)
+            device_buffer = device_buffer_view.get_buffer()
+            user_returns[i] = EagerExecResult(device_buffer, size, dtype, signal)
+        return user_returns
+
+    def _initialize_fences(self, device: Device, inputs: list, arg_list: VmVariantList):
+        fence_capacity = device._fence_capacity
+        tx_semaphore = device._tx_timeline
+        current_tx_timepoint = device._tx_timepoint
+
+        # Create wait semaphore and fence at the current timepoint.
+        wait_semaphores = (tx_semaphore, current_tx_timepoint)
+        wait_fence = HalFence(fence_capacity)
+        wait_fence.insert(*wait_semaphores)
+
+        # Create signal semaphore and fence at the next timepoint on the same timeline.
+        device._tx_timepoint += 1
+        signals_semaphore = (tx_semaphore, current_tx_timepoint + 1)
+        signal_fence = HalFence(fence_capacity)
+        signal_fence.insert(*signals_semaphore)
+
+        # Return the fences; the caller appends them to arg_list for async execution.
+        return wait_fence, signal_fence
diff --git a/python/shark_turbine/dynamo/tensor.py b/python/shark_turbine/dynamo/tensor.py
index 859b379cb..809eae088 100644
--- a/python/shark_turbine/dynamo/tensor.py
+++ b/python/shark_turbine/dynamo/tensor.py
@@ -12,15 +12,21 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import functools
+import atexit
 from array import array
 
 import numpy as np
+from types import BuiltinFunctionType
 
 import torch
+import torch._dynamo as dynamo
 from torch.overrides import TorchFunctionMode
 
 from .device import (
     Device,
+    DeviceState,
 )
+from .executor import EagerSpecializedExecutable
 
 from ..support import (
     ApiSequencingError,
@@ -33,8 +39,19 @@
     HalCommandBuffer,
     HalElementType,
     HalFence,
+    VmModule,
 )
 
+from iree.compiler.api import Session, Output
+from iree.compiler.passmanager import PassManager
+
+from .importer import FxImporter
+
+DEFAULT_COMPILER_FLAGS = (
+    # Enable asynchronous calling convention.
+    "--iree-execution-model=async-external",
+    "--iree-input-type=torch",
+)
 
 ###############################################################################
 # Factories and device enablement
@@ -49,6 +66,8 @@ class TurbineMode(TorchFunctionMode):
     """
 
     IMPLEMENTATIONS = {}
+    CACHED_IMPLEMENTATIONS = {}
+    COMPUTE_METHODS = set((torch.add, torch.sub, torch.mul, torch.abs))
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
         def super_fn(*args, **kwargs):
@@ -58,6 +77,8 @@ def super_fn(*args, **kwargs):
             return func(*args, **kwargs)
 
         if func in self.IMPLEMENTATIONS:
+            if func in self.COMPUTE_METHODS:
+                args += (func,)
             return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {})
 
         # This is just a no-op for all the non-factory functions:
@@ -67,6 +88,13 @@ def super_fn(*args, **kwargs):
 def enable():
     """Enables PyTorch tensor device= support for Turbine permanently."""
     TurbineMode().__enter__()
+    Device("local-task").set()
+    atexit.register(disable)
+
+
+def disable():
+    Device.current().clear()
+    TurbineMode().__exit__(None, None, None)
 
 
 # Convenient wrapper to register functions
@@ -80,6 +108,18 @@ def _inner_fn(impl):
     return _inner_fn
 
 
+# Convenient wrapper to register compute functions
+def compute_factory(func):
+    """Decorator to register a compute function that is compiled and run eagerly."""
+
+    def _inner_fn(impl):
+        TurbineMode.IMPLEMENTATIONS[func] = impl
+        TurbineMode.COMPUTE_METHODS.add(func)
+        return impl
+
+    return _inner_fn
+
+
 def device_factory(func):
     """Decorator to invoke the user provided factory for our devices.
 
@@ -178,6 +218,17 @@ def __del__(self):
 class DeviceTensor(torch.Tensor):
     """A Tensor accessing memory on a Turbine device."""
 
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs={}):
+        # Forward every aten op that reaches __torch_dispatch__ to compute_method,
+        # appending the op itself as the trailing argument so it can be traced,
+        # compiled, and executed on the Turbine device.
+        args += (func,)
+        return compute_method(func, *args, **kwargs)
+
+    def _to_meta_tensor(self):
+        return torch.empty(self.shape, dtype=self.dtype)
+
     @staticmethod
     def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
         # Using a meta tensor as the wrapped gives us shape and dtype
@@ -198,6 +249,18 @@ def __init__(self, size, dtype, raw_data=None, requires_grad=False):
                 f"raw_data= not implemented for DeviceTensor ({raw_data.__class__})"
             )
 
+    @staticmethod
+    def from_torch(input_tensor: torch.Tensor):
+        if isinstance(input_tensor, torch.Tensor):
+            dev_tensor = DeviceTensor._async_create_empty(
+                input_tensor.size(), Device("local-task"), input_tensor.dtype
+            )
+            dev_tensor._async_copy_from_host(input_tensor.numpy())
+            return dev_tensor
+        else:
+            if input_tensor is not None:
+                raise ValueError("Expected input to be of type torch.Tensor.")
+
     @property
     def buffer_view(self) -> HalBufferView:
         if self._bv is None:
@@ -211,6 +274,10 @@ def buffer_view(self) -> HalBufferView:
     def cpu(self):
         return self.to("cpu")
 
+    @property
+    def device(self):
+        return self._storage.device
+
     def __repr__(self):
         hal_device = self._storage.device.hal_device
         try:
@@ -239,6 +306,20 @@ def _async_create_empty(
         storage.ready_fence.insert(*alloca_complete_semaphore)
         return DeviceTensor(size, dtype, raw_data=storage)
 
+    @staticmethod
+    def _from_buffer(
+        buffer: HalBuffer,
+        size: Sequence[int],
+        dtype: torch.dtype,
+        device: Device,
+        signal: HalFence,
+    ) -> "DeviceTensor":
+        """Creates a tensor backed by an existing device buffer, gated on its signal fence."""
+        storage = Storage(device, buffer)
+        if signal is not None:
+            storage.ready_fence = signal
+        return DeviceTensor(size, dtype, raw_data=storage)
+
     def _async_fill_py_value(self, value):
         """Fills a value in all elements of the tensor.
 
@@ -289,7 +370,7 @@ def _calculate_c_contig_size(size: Sequence[int], dtype: torch.dtype) -> int:
 # And some factory functions
 # By hand
 @raw_factory(torch.Tensor.to)
-def to(super_fn, self, device):
+def to(super_fn, self, device, dtype=None, non_blocking=None):
     # Note that we only implement a subset of .to() here
     turbine_device = _parse_device(device)
     if turbine_device:
@@ -316,6 +397,19 @@ def to(super_fn, self, device):
     return super_fn(self, device)
 
 
+@raw_factory(torch._C._nn._parse_to)
+def _parse_to(super_fn, *args, **kwargs):
+    if "turbine" in args:
+        # TODO: Parse through args and kwargs for correct params.
+        device = "turbine"
+        dtype = None
+        non_blocking = False
+        convert_to_format = None
+        return device, dtype, non_blocking, convert_to_format
+    else:
+        return super_fn(*args, **kwargs)
+
+
 @device_factory(torch.empty)
 def _empty(*size, device: Device, dtype=torch.float32):
     # Turbine empty.
@@ -372,6 +466,125 @@ def _rand(*args, dtype=None):
     return t
 
 
+@functools.lru_cache(maxsize=None)
+def _get_device_state() -> DeviceState:
+    return DeviceState(driver="local-task")
+
+
+# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237
+# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py
+# TODO: Try to generalize for other devices.
+def compute_method(super_fn, *args, **kwargs):
+    # Compute factory fns reserve the last arg as src_op.
+    # We key off src_op rather than super_fn because super_fn
+    # is often wrapped by DisableTorchFunction.
+    init_py_args = args[:-1]
+    src_op = args[-1]
+
+    any_turbine_tensor = False
+    devices_set = set()
+    arg_shape_dtype_encode = []
+    py_args = []
+    for arg_idx, py_arg in enumerate(init_py_args):
+        ret_val = py_arg
+        if isinstance(py_arg, DeviceTensor):
+            any_turbine_tensor = True
+        if isinstance(py_arg, (int, float)):
+            ret_val = DeviceTensor.from_torch(torch.tensor(py_arg))
+        devices_set.add(ret_val.device)
+        arg_shape_dtype_encode.append(str(ret_val.shape) + str(ret_val.dtype))
+        py_args.append(ret_val)
+
+    # If no Turbine tensor is involved, fall back to the regular function.
+    if not any_turbine_tensor:
+        return super_fn(*py_args, **kwargs)
+
+    # Do not support interop between Turbine and other devices.
+    if len(devices_set) > 1:
+        raise ValueError("Turbine does not support mixed devices!")
+    cur_device = py_args[0].device
+    # Get a unique encoding to identify the computation/dispatch using the op name, input shapes, and dtypes.
+    if isinstance(src_op, torch._ops.OpOverload):
+        src_op_name = src_op.name()
+    elif isinstance(src_op, BuiltinFunctionType):
+        src_op_name = src_op.__name__
+    else:
+        raise ValueError("Expected src_op to be a torch op or a builtin function.")
+    compute_id_encode = src_op_name + "".join(arg_shape_dtype_encode)
+    compute_hash = hash(compute_id_encode)
+    if compute_hash in TurbineMode.CACHED_IMPLEMENTATIONS:
+        # TODO: Handle multiple outputs.
+        exec_res = TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[
+            0
+        ]
+        res_buf = DeviceTensor._from_buffer(
+            exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal
+        )
+        return res_buf
+
+    # Preprocess func and generate into FX.
+    flat_pytorch_args = [py_arg._to_meta_tensor() for py_arg in py_args]
+
+    # TODO: Replace all of the below with torch.compile. Currently that path
+    # tries to generate a DeviceTensor that is missing _storage, which
+    # causes an error.
+    def func_src_op(*args, **kwargs):
+        return src_op(*args, **kwargs)
+
+    exported_f = dynamo.export(
+        func_src_op,
+        aten_graph=True,
+        decomposition_table={},
+        constraints={},
+        assume_static_by_default=True,
+    )
+    gm, guards = exported_f(*flat_pytorch_args)
+
+    # Set up the MLIR compilation pipeline.
+    session = Session()
+    session.set_flags(*DEFAULT_COMPILER_FLAGS)
+    session.set_flags("--iree-hal-target-backends=llvm-cpu")
+    context = session.context
+
+    # Generate MLIR from FX.
+    importer = FxImporter(context=context)
+    module = importer.module
+    inv = session.invocation()
+    # TODO: Should capture diagnostics.
+    inv.enable_console_diagnostics()
+    inv.import_module(module.operation)
+    importer.import_graph_module(gm)
+
+    # Compile MLIR to vmfb.
+    inv.execute()
+    output = Output.open_membuffer()
+    inv.output_vm_bytecode(output)
+
+    # Map the VMFB into a runtime module.
+    device_state = _get_device_state()
+    vmfb_module = VmModule.wrap_buffer(
+        device_state.instance,
+        output.map_memory(),
+        destroy_callback=output.close,
+    )
+
+    # Load and execute the VMFB.
+    exec = EagerSpecializedExecutable(vmfb_module, device_state)
+    exec_results = exec(*py_args)
+    if len(exec_results) != 1:
+        raise ValueError("Only a single output is supported for now.")
+    exec_res = exec_results[0]
+
+    TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash] = exec
+
+    # Rewrap the result buffer into a DeviceTensor and return it.
+    # TODO: Handle multiple outputs.
+    dev_res = DeviceTensor._from_buffer(
+        exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal
+    )
+    return dev_res
+
+
 ###############################################################################
 # Conversions
 ###############################################################################
diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py
index 0a499c32d..d736c153f 100644
--- a/tests/dynamo/tensor_test.py
+++ b/tests/dynamo/tensor_test.py
@@ -79,6 +79,59 @@ def test_factory_rand(self):
         t1 = torch.rand(4, device="turbine", dtype=torch.float32)
         print(t1.cpu())
 
+    def test_binary_op(self):
+        t1 = 5.3 * torch.ones(2, 3).to(device="turbine")
+        t2 = 2.3 * torch.ones(2, 3).to(device="turbine")
+        t3 = t1 * t2
+        np.testing.assert_allclose(
+            t3.cpu(), [[12.19, 12.19, 12.19], [12.19, 12.19, 12.19]]
+        )
+
+    def test_unary_op(self):
+        t1 = -5.3 * torch.ones(2, 3).to(device="turbine")
+        t2 = torch.abs(t1)
+        np.testing.assert_allclose(t2.cpu(), [[5.3, 5.3, 5.3], [5.3, 5.3, 5.3]])
+
+    def test_nn_linear(self):
+        m = torch.nn.Linear(20, 30)
+        input = torch.randn(128, 20)
+        ref_output = m(input)
+        m.to("turbine")
+        input = input.to("turbine")
+        turbine_output = m(input)
+        np.testing.assert_allclose(
+            turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6
+        )
+
+    def test_nn_MLP(self):
+        class MLP(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer0 = torch.nn.Linear(64, 32, bias=True)
+                self.layer1 = torch.nn.Linear(32, 16, bias=True)
+                self.layer2 = torch.nn.Linear(16, 7, bias=True)
+                self.layer3 = torch.nn.Linear(7, 7, bias=True)
+
+            def forward(self, x: torch.Tensor):
+                x = self.layer0(x)
+                x = torch.sigmoid(x)
+                x = self.layer1(x)
+                x = torch.sigmoid(x)
+                x = self.layer2(x)
+                x = torch.sigmoid(x)
+                x = self.layer3(x)
+                return x
+
+        m = MLP()
+        input = torch.randn(16, 64)
+        ref_output = m(input)
+        m.to("turbine")
+        input = input.to("turbine")
+        turbine_output = m(input)
+        np.testing.assert_allclose(
+            turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6
+        )
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
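
Usage sketch: a minimal example of how the eager path added above is expected to be exercised, mirroring the new tests in tests/dynamo/tensor_test.py. It assumes the module is importable as shark_turbine.dynamo.tensor (as in this diff) and that enable() runs before any "turbine" tensors are created.

    # Minimal eager-mode sketch; values mirror test_binary_op above.
    import torch
    from shark_turbine.dynamo import tensor as turbine_tensor

    turbine_tensor.enable()  # installs TurbineMode and sets the default "local-task" device

    t1 = 5.3 * torch.ones(2, 3).to(device="turbine")  # DeviceTensor on the Turbine device
    t2 = 2.3 * torch.ones(2, 3).to(device="turbine")
    t3 = t1 * t2       # traced, compiled, and cached via compute_method on first use
    print(t3.cpu())    # expected: all elements ~12.19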