Commit

Run black.
stellaraccident committed Nov 15, 2023
1 parent d32c8b6 commit 4117974
Showing 11 changed files with 70 additions and 47 deletions.
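This is a formatting-only commit: it runs the black auto-formatter across the tree, so none of the hunks below change behavior. The dominant pattern is black reflowing statements that exceed its default 88-column line limit, as in this pair taken from the first hunk:

    # Before: the call runs past black's 88-column default.
    elements_attr = DenseResourceElementsAttr.get_from_buffer(contents, "from_py", tensor_type)

    # After: black splits the arguments onto an indented continuation line.
    elements_attr = DenseResourceElementsAttr.get_from_buffer(
        contents, "from_py", tensor_type
    )

The remaining changes normalize string quotes to double quotes, put two blank lines around top-level definitions, and strip trailing whitespace. Removed lines below are marked with a leading "-", added lines with "+".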
4 changes: 3 additions & 1 deletion python/shark_turbine/aot/support/ir_utils.py
@@ -265,7 +265,9 @@ def create_tensor_global(
array = np.array(detached_tensor)
# We know that a Numpy array is a ReadableBuffer so ignore type error.
contents = memoryview(array) # type: ignore
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(contents, "from_py", tensor_type)
+        elements_attr = DenseResourceElementsAttr.get_from_buffer(
+            contents, "from_py", tensor_type
+        )
ir_attrs["initial_value"] = elements_attr

global_op = Operation.create("util.global", attributes=ir_attrs)
10 changes: 2 additions & 8 deletions python/shark_turbine/aot/support/procedural/globals.py
@@ -141,10 +141,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractTensor):
global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
attrs=self._attrs,
@@ -163,10 +160,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractScalar):
global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
attrs=self._attrs,
26 changes: 14 additions & 12 deletions python/shark_turbine/dynamo/importer.py
@@ -377,9 +377,9 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
val = node.meta.get("val")
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
-            # Quantized tensor meta data is not preserved in our lowering, 
+            # Quantized tensor meta data is not preserved in our lowering,
# so throw error instead of silently doing wrong thing.
-            if (tensor_meta.is_quantized):
+            if tensor_meta.is_quantized:
raise NotImplementedError(
f"Quantized tensor meta data is not supported."
)
@@ -700,7 +700,7 @@ def _import_torch_op_overload(
)

# Support unregistered torch ops using torch.operator.
-        # torch.operator is used to represent ops from registry 
+        # torch.operator is used to represent ops from registry
# which haven't been generated by torch_ods_gen.py.
if not self._c.is_registered_operation(mlir_op_name):
operation = Operation.create(
@@ -860,7 +860,7 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
return self._import_list_argument(loc, arg, expected_jit_type)

# The LITERAL_CONVERTER_MAP maps each arg to its respective constant
-        # of the expected jit IR type (types like torch.dtype will form a chain of 
+        # of the expected jit IR type (types like torch.dtype will form a chain of
# maps to get to constant of expected_jit_type).
cvt = LITERAL_CONVERTER_MAP.lookup(type(arg))
if cvt is None:
@@ -916,13 +916,13 @@ def _make_constant_op(


def create_mlir_tensor_type(tensor: torch.Tensor) -> MlirType:
-  try:
-    dtype = tensor.dtype
-    element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
-    tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
-    return tensor_type
-  except KeyError:
-    raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+    try:
+        dtype = tensor.dtype
+        element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+        tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
+        return tensor_type
+    except KeyError:
+        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")


def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: MlirType) -> Operation:
@@ -939,7 +939,9 @@ def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: MlirType) -> Op
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
bytes = memoryview(np_tensor)
tensor_type = create_mlir_tensor_type(tensor)
-    elements_attr = DenseResourceElementsAttr.get_from_buffer(bytes, "from_py", tensor_type)
+    elements_attr = DenseResourceElementsAttr.get_from_buffer(
+        bytes, "from_py", tensor_type
+    )
return Operation.create(
name="torch.vtensor.literal",
results=[vtensor_type],
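As context for the two reflowed get_from_buffer calls in this file: the importer embeds host tensor contents into the module as a dense resource blob keyed "from_py". A minimal runnable sketch of that path, assuming the iree.compiler.ir bindings this package builds against, and hard-coding float32 where the importer would consult TORCH_DTYPE_TO_MLIR_TYPE:

    import numpy as np
    import torch
    from iree.compiler.ir import (
        Context,
        DenseResourceElementsAttr,
        F32Type,
        RankedTensorType,
    )

    def make_f32_elements_attr(tensor: torch.Tensor) -> DenseResourceElementsAttr:
        # Round-trip through NumPy for a contiguous, correctly typed buffer.
        np_tensor = np.array(tensor.tolist()).astype(np.float32)
        contents = memoryview(np_tensor)
        tensor_type = RankedTensorType.get(tuple(tensor.size()), F32Type.get())
        # "from_py" names the dense resource embedded in the emitted module.
        return DenseResourceElementsAttr.get_from_buffer(contents, "from_py", tensor_type)

    with Context():
        attr = make_f32_elements_attr(torch.tensor([1.0, 2.0, 3.0]))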
18 changes: 12 additions & 6 deletions python/shark_turbine/transforms/general/rename_parameters.py
@@ -38,7 +38,9 @@ def __init__(
root_op: Operation,
*,
rename_map: Optional[Dict[MaybeScopedName, MaybeScopedName]] = None,
-        rename_callback: Callable[[Optional[str], str], Optional[ScopedName]] = lambda scope, name: None
+        rename_callback: Callable[
+            [Optional[str], str], Optional[ScopedName]
+        ] = lambda scope, name: None,
):
super().__init__(root_op)
self.context = root_op.context
@@ -82,22 +84,26 @@ def norm_map_result(result: MaybeScopedName) -> ScopedName:
return orig_scope, result[0]
else:
return result[0], result[1]

def make_attr(scoped_name: ScopedName) -> Attribute:
if scoped_name[0] is None:
name = StringAttr.get(scoped_name[1])
-                return Attribute.parse(f"#stream.parameter.named<{name}> : {parameter_attr.type}")
+                return Attribute.parse(
+                    f"#stream.parameter.named<{name}> : {parameter_attr.type}"
+                )
else:
scope = StringAttr.get(scoped_name[0])
name = StringAttr.get(scoped_name[1])
-                return Attribute.parse(f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}")
+                return Attribute.parse(
+                    f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}"
+                )

# Check the rename map.
# Check with a fully-qualified name.
result = self.rename_map.get((orig_scope, name))
if result is not None:
return make_attr(norm_map_result(result))
-        # Check with just the 
+        # Check with just the
result = self.rename_map.get(name)
if result is not None:
return make_attr(norm_map_result(result))
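For reference, the f-strings in make_attr above build stream-dialect named-parameter attributes. StringAttr interpolates with its surrounding quotes, so the unscoped and scoped forms come out as follows (illustrative name, scope, and tensor type, not taken from a real module):

    #stream.parameter.named<"weight.0"> : tensor<4x4xf32>
    #stream.parameter.named<"model"::"weight.0"> : tensor<4x4xf32>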
4 changes: 1 addition & 3 deletions python/turbine_models/custom_models/remap_gguf.py
@@ -412,9 +412,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
continue
gguf_tensor_name = TENSOR_NAMES[tensor]
if MODEL_ARCH_NAMES[arch] in tensor_dict:
-                self.mapping[
-                    tensor_dict[MODEL_ARCH_NAMES[arch]]
-                ] = gguf_tensor_name
+                self.mapping[tensor_dict[MODEL_ARCH_NAMES[arch]]] = gguf_tensor_name
for bid in range(n_blocks):
for tensor, tensor_dict in self.block_mappings_cfg.items():
if tensor not in MODEL_TENSORS[arch]:
2 changes: 1 addition & 1 deletion tests/dynamo/importer_basic_test.py
@@ -179,7 +179,7 @@ def foo():
f16 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float16)
f32 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float32)
return f16, f32

opt_foo = torch.compile(foo, backend=create_backend())
opt_foo()

8 changes: 6 additions & 2 deletions tests/dynamo/importer_dynamic_test.py
@@ -97,13 +97,16 @@ def forward(self, inp):
g = x / 32
return {"result": g}

+
class DynamicShapeStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a):
dynamic_shape = [a.size(0), a.size(1), a.size(2)]
-        x = torch.ops.aten.empty_strided(dynamic_shape, stride=[12, 4, 1]) # Default stride = [12, 4, 1]
+        x = torch.ops.aten.empty_strided(
+            dynamic_shape, stride=[12, 4, 1]
+        )  # Default stride = [12, 4, 1]
y = x.copy_(a)
return y

@@ -173,7 +176,7 @@ def testDynamicShapeStrided(self):
"""
model = DynamicShapeStridedModule()
# inp_example = torch.rand(5, 7, 9)
-        inp_example = torch.randn(2, 3, 4) # input for default stride
+        inp_example = torch.randn(2, 3, 4)  # input for default stride
f = dynamo.export(
model.forward,
aten_graph=True,
@@ -186,6 +189,7 @@ def testDynamicShapeStrided(self):
g, guards = f(a=inp_example)
g = import_compiler(g, [inp_example])

+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
9 changes: 7 additions & 2 deletions tests/dynamo/mninst_test.py
@@ -16,7 +16,9 @@

import torch._dynamo.config

-torch._dynamo.config.dynamic_shapes = False # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+torch._dynamo.config.dynamic_shapes = (
+    False  # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+)


class MNISTDataLoader:
@@ -44,7 +46,10 @@ def get_train_loader(self):

def get_test_loader(self):
return DataLoader(
-            dataset=self.mnist_testset, batch_size=self.batch_size, shuffle=False, drop_last=True,
+            dataset=self.mnist_testset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            drop_last=True,
)


2 changes: 2 additions & 0 deletions tests/examples/aot_mlp_test.py
@@ -17,13 +17,15 @@ def _run(local_path: str):
path = REPO_DIR / local_path
subprocess.check_call([sys.executable, str(path)])

+
class AOTMLPTest(unittest.TestCase):
def testMLPExportSimple(self):
_run("examples/aot_mlp/mlp_export_simple.py")

def testMLPExportSimple(self):
_run("examples/aot_mlp/mlp_export_dynamic.py")

+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
16 changes: 9 additions & 7 deletions tests/models/hf_dict.py
@@ -1,15 +1,17 @@
import torch

-class ModelData():
+
+class ModelData:
def __init__(self, input_shape, torch_dtype: torch.dtype, xfail: bool = True):
self.xfail = xfail
self.input_shape = input_shape
self.dtype = torch_dtype

+
model_dict = {
-    'distilgpt2': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'gpt2': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'gpt2-medium': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'bert-base-uncased': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-    'bert-large-uncased': ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
-}
+    "distilgpt2": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "gpt2": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "gpt2-medium": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "bert-base-uncased": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+    "bert-large-uncased": ModelData(input_shape=(1, 1), torch_dtype=torch.int64),
+}
18 changes: 13 additions & 5 deletions tests/models/transformer_builder_tests.py
@@ -3,21 +3,29 @@
from hf_dict import model_dict
import torch

+
class TestHFTransformerBuilder(unittest.TestCase):
pass

+
def create_test(model_name, model_data):
def test(self):
-        example_input = torch.ones(*model_data.input_shape, dtype=model_data.torch_dtype)
+        example_input = torch.ones(
+            *model_data.input_shape, dtype=model_data.torch_dtype
+        )
builder = HFTransformerBuilder(example_input, model_name, auto_tokenizer=None)
compiled_module = builder.get_compiled_module()
self.assertIsNotNone(compiled_module)

return test

+
for model_name, model_data in model_dict.items():
test_method = create_test(model_name, model_data)
-    test_method = unittest.expectedFailure(test_method) if model_data.xfail else test_method
-    setattr(TestHFTransformerBuilder, f'test_{model_name}', test_method)
+    test_method = (
+        unittest.expectedFailure(test_method) if model_data.xfail else test_method
+    )
+    setattr(TestHFTransformerBuilder, f"test_{model_name}", test_method)

-if __name__ == '__main__':
-    unittest.main()
+if __name__ == "__main__":
+    unittest.main()
