diff --git a/python/shark_turbine/aot/support/ir_utils.py b/python/shark_turbine/aot/support/ir_utils.py
index 5e1daf3ed..6290297e5 100644
--- a/python/shark_turbine/aot/support/ir_utils.py
+++ b/python/shark_turbine/aot/support/ir_utils.py
@@ -265,7 +265,9 @@ def create_tensor_global(
                 array = np.array(detached_tensor)
                 # We know that a Numpy array is a ReadableBuffer so ignore type error.
                 contents = memoryview(array)  # type: ignore
-                elements_attr = DenseResourceElementsAttr.get_from_buffer(contents, "from_py", tensor_type)
+                elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                    contents, "from_py", tensor_type
+                )
                 ir_attrs["initial_value"] = elements_attr
 
             global_op = Operation.create("util.global", attributes=ir_attrs)
diff --git a/python/shark_turbine/aot/support/procedural/globals.py b/python/shark_turbine/aot/support/procedural/globals.py
index be882eb59..cb6781867 100644
--- a/python/shark_turbine/aot/support/procedural/globals.py
+++ b/python/shark_turbine/aot/support/procedural/globals.py
@@ -141,10 +141,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
                 continue
             elif isinstance(value, AbstractTensor):
                 global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
                     attrs=self._attrs,
@@ -163,10 +160,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
                 continue
             elif isinstance(value, AbstractScalar):
                 global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
                     attrs=self._attrs,
diff --git a/python/shark_turbine/dynamo/importer.py b/python/shark_turbine/dynamo/importer.py
index a621e18f6..998d2f850 100644
--- a/python/shark_turbine/dynamo/importer.py
+++ b/python/shark_turbine/dynamo/importer.py
@@ -377,9 +377,9 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
         val = node.meta.get("val")
         if tensor_meta is not None:
             assert isinstance(tensor_meta, TensorMetadata)
-            # Quantized tensor meta data is not preserved in our lowering, 
+            # Quantized tensor meta data is not preserved in our lowering,
             # so throw error instead of silently doing wrong thing.
-            if (tensor_meta.is_quantized):
+            if tensor_meta.is_quantized:
                 raise NotImplementedError(
                     f"Quantized tensor meta data is not supported."
                 )
@@ -700,7 +700,7 @@ def _import_torch_op_overload(
         )
 
         # Support unregistered torch ops using torch.operator.
-        # torch.operator is used to represent ops from registry 
+        # torch.operator is used to represent ops from registry
         # which haven't been generated by torch_ods_gen.py.
         if not self._c.is_registered_operation(mlir_op_name):
             operation = Operation.create(
@@ -860,7 +860,7 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
             return self._import_list_argument(loc, arg, expected_jit_type)
 
         # The LITERAL_CONVERTER_MAP maps each arg to its respective constant
-        # of the expected jit IR type (types like torch.dtype will form a chain of 
+        # of the expected jit IR type (types like torch.dtype will form a chain of
         # maps to get to constant of expected_jit_type).
         cvt = LITERAL_CONVERTER_MAP.lookup(type(arg))
         if cvt is None:
@@ -916,13 +916,13 @@ def _make_constant_op(
 
 
 def create_mlir_tensor_type(tensor: torch.Tensor) -> MlirType:
-  try:
-    dtype = tensor.dtype
-    element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
-    tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
-    return tensor_type
-  except KeyError:
-    raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+    try:
+        dtype = tensor.dtype
+        element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+        tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
+        return tensor_type
+    except KeyError:
+        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
 
 
 def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: MlirType) -> Operation:
@@ -939,7 +939,9 @@ def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: MlirType) -> Op
     np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
     bytes = memoryview(np_tensor)
     tensor_type = create_mlir_tensor_type(tensor)
-    elements_attr = DenseResourceElementsAttr.get_from_buffer(bytes, "from_py", tensor_type)
+    elements_attr = DenseResourceElementsAttr.get_from_buffer(
+        bytes, "from_py", tensor_type
+    )
     return Operation.create(
         name="torch.vtensor.literal",
         results=[vtensor_type],
diff --git a/python/shark_turbine/transforms/general/rename_parameters.py b/python/shark_turbine/transforms/general/rename_parameters.py
index 86d5c0f04..be263b92a 100644
--- a/python/shark_turbine/transforms/general/rename_parameters.py
+++ b/python/shark_turbine/transforms/general/rename_parameters.py
@@ -38,7 +38,9 @@ def __init__(
         root_op: Operation,
         *,
         rename_map: Optional[Dict[MaybeScopedName, MaybeScopedName]] = None,
-        rename_callback: Callable[[Optional[str], str], Optional[ScopedName]] = lambda scope, name: None
+        rename_callback: Callable[
+            [Optional[str], str], Optional[ScopedName]
+        ] = lambda scope, name: None,
     ):
         super().__init__(root_op)
         self.context = root_op.context
@@ -82,22 +84,26 @@ def norm_map_result(result: MaybeScopedName) -> ScopedName:
                 return orig_scope, result[0]
             else:
                 return result[0], result[1]
-        
+
         def make_attr(scoped_name: ScopedName) -> Attribute:
             if scoped_name[0] is None:
                 name = StringAttr.get(scoped_name[1])
-                return Attribute.parse(f"#stream.parameter.named<{name}> : {parameter_attr.type}")
+                return Attribute.parse(
+                    f"#stream.parameter.named<{name}> : {parameter_attr.type}"
+                )
             else:
                 scope = StringAttr.get(scoped_name[0])
                 name = StringAttr.get(scoped_name[1])
-                return Attribute.parse(f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}")
-        
+                return Attribute.parse(
+                    f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}"
+                )
+
         # Check the rename map.
         # Check with a fully-qualified name.
         result = self.rename_map.get((orig_scope, name))
         if result is not None:
             return make_attr(norm_map_result(result))
-        # Check with just the 
+        # Check with just the name.
         result = self.rename_map.get(name)
         if result is not None:
             return make_attr(norm_map_result(result))
diff --git a/python/turbine_models/custom_models/remap_gguf.py b/python/turbine_models/custom_models/remap_gguf.py
index 28fb85bd2..30cabcad7 100644
--- a/python/turbine_models/custom_models/remap_gguf.py
+++ b/python/turbine_models/custom_models/remap_gguf.py
@@ -412,9 +412,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
                 continue
             gguf_tensor_name = TENSOR_NAMES[tensor]
             if MODEL_ARCH_NAMES[arch] in tensor_dict:
-                self.mapping[
-                    tensor_dict[MODEL_ARCH_NAMES[arch]]
-                ] = gguf_tensor_name
+                self.mapping[tensor_dict[MODEL_ARCH_NAMES[arch]]] = gguf_tensor_name
         for bid in range(n_blocks):
             for tensor, tensor_dict in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
diff --git a/tests/dynamo/importer_basic_test.py b/tests/dynamo/importer_basic_test.py
index 25ea41212..6b57b95c0 100644
--- a/tests/dynamo/importer_basic_test.py
+++ b/tests/dynamo/importer_basic_test.py
@@ -179,7 +179,7 @@ def foo():
             f16 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float16)
             f32 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float32)
             return f16, f32
-        
+
         opt_foo = torch.compile(foo, backend=create_backend())
         opt_foo()
 
diff --git a/tests/dynamo/importer_dynamic_test.py b/tests/dynamo/importer_dynamic_test.py
index 622228864..c9fff7640 100644
--- a/tests/dynamo/importer_dynamic_test.py
+++ b/tests/dynamo/importer_dynamic_test.py
@@ -97,13 +97,16 @@ def forward(self, inp):
         g = x / 32
         return {"result": g}
 
+
 class DynamicShapeStridedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     def forward(self, a):
         dynamic_shape = [a.size(0), a.size(1), a.size(2)]
-        x = torch.ops.aten.empty_strided(dynamic_shape, stride=[12, 4, 1]) # Default stride = [12, 4, 1]
+        x = torch.ops.aten.empty_strided(
+            dynamic_shape, stride=[12, 4, 1]
+        )  # Default stride = [12, 4, 1]
         y = x.copy_(a)
         return y
 
@@ -173,7 +176,7 @@ def testDynamicShapeStrided(self):
         """
         model = DynamicShapeStridedModule()
         # inp_example = torch.rand(5, 7, 9)
-        inp_example = torch.randn(2, 3, 4) # input for default stride
+        inp_example = torch.randn(2, 3, 4)  # input for default stride
         f = dynamo.export(
             model.forward,
             aten_graph=True,
@@ -186,6 +189,7 @@ def testDynamicShapeStrided(self):
         g, guards = f(a=inp_example)
         g = import_compiler(g, [inp_example])
 
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()
diff --git a/tests/dynamo/mninst_test.py b/tests/dynamo/mninst_test.py
index 65ad254cd..88742e2bd 100644
--- a/tests/dynamo/mninst_test.py
+++ b/tests/dynamo/mninst_test.py
@@ -16,7 +16,9 @@
 import torch._dynamo.config
 
 
-torch._dynamo.config.dynamic_shapes = False # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+torch._dynamo.config.dynamic_shapes = (
+    False  # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+)
 
 
 class MNISTDataLoader:
@@ -44,7 +46,10 @@ def get_train_loader(self):
 
     def get_test_loader(self):
         return DataLoader(
-            dataset=self.mnist_testset, batch_size=self.batch_size, shuffle=False, drop_last=True,
+            dataset=self.mnist_testset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            drop_last=True,
         )
 
 
diff --git a/tests/examples/aot_mlp_test.py b/tests/examples/aot_mlp_test.py
index 8b1380604..c4266a4af 100644
--- a/tests/examples/aot_mlp_test.py
+++ b/tests/examples/aot_mlp_test.py
@@ -17,6 +17,7 @@ def _run(local_path: str):
     path = REPO_DIR / local_path
     subprocess.check_call([sys.executable, str(path)])
 
+
 class AOTMLPTest(unittest.TestCase):
     def testMLPExportSimple(self):
         _run("examples/aot_mlp/mlp_export_simple.py")
@@ -24,6 +25,7 @@ def testMLPExportSimple(self):
     def testMLPExportSimple(self):
         _run("examples/aot_mlp/mlp_export_dynamic.py")
 
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()
diff --git a/tests/models/hf_dict.py b/tests/models/hf_dict.py
index 1caebc520..d0071e06e 100644
--- a/tests/models/hf_dict.py
+++ b/tests/models/hf_dict.py
@@ -1,15 +1,17 @@
 import torch
 
-class ModelData():
+
+class ModelData:
     def __init__(self, input_shape, torch_dtype: torch.dtype, xfail: bool = True):
         self.xfail = xfail
         self.input_shape = input_shape
         self.dtype = torch_dtype
 
+
 model_dict = {
-    'distilgpt2': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'gpt2': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'gpt2-medium': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'bert-base-uncased': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'bert-large-uncased': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-}
\ No newline at end of file
+    "distilgpt2": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "gpt2": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "gpt2-medium": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "bert-base-uncased": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "bert-large-uncased": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+}
diff --git a/tests/models/transformer_builder_tests.py b/tests/models/transformer_builder_tests.py
index 81165581f..0c45f0c7d 100644
--- a/tests/models/transformer_builder_tests.py
+++ b/tests/models/transformer_builder_tests.py
@@ -3,21 +3,29 @@
 from hf_dict import model_dict
 import torch
 
+
 class TestHFTransformerBuilder(unittest.TestCase):
     pass
 
+
 def create_test(model_name, model_data):
     def test(self):
-        example_input = torch.ones(*model_data.input_shape, dtype=model_data.torch_dtype)
+        example_input = torch.ones(
+            *model_data.input_shape, dtype=model_data.torch_dtype
+        )
         builder = HFTransformerBuilder(example_input, model_name, auto_tokenizer=None)
         compiled_module = builder.get_compiled_module()
         self.assertIsNotNone(compiled_module)
+
     return test
 
+
 for model_name, model_data in model_dict.items():
     test_method = create_test(model_name, model_data)
-    test_method = unittest.expectedFailure(test_method) if model_data.xfail else test_method
-    setattr(TestHFTransformerBuilder, f'test_{model_name}', test_method)
+    test_method = (
+        unittest.expectedFailure(test_method) if model_data.xfail else test_method
+    )
+    setattr(TestHFTransformerBuilder, f"test_{model_name}", test_method)
 
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main()