Dynamo fx completeness #162

Merged
merged 1 commit into from Nov 9, 2023
62 changes: 32 additions & 30 deletions python/shark_turbine/dynamo/importer.py
@@ -353,11 +353,14 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
             val = node.meta.get("val")
             if tensor_meta is not None:
                 assert isinstance(tensor_meta, TensorMetadata)
-                # TODO: We should probably only be doing this if "vanilla".
-                # Specifically, there are strides/qparams/etc on there that
-                # should be annotated somewhere.
-                # See: https://github.com/nod-ai/SHARK-Turbine/issues/139
-                return self.tensor_metadata_to_type(tensor_meta)
+                # Quantized tensor metadata is not preserved in our lowering,
+                # so raise an error instead of silently doing the wrong thing.
+                if tensor_meta.is_quantized:
+                    raise NotImplementedError(
+                        "Quantized tensor metadata is not supported."
+                    )
+                else:
+                    return self.tensor_metadata_to_type(tensor_meta)
             elif val is not None:
                 # Some nodes with symbolic inputs pass a 'val' attribute rather
                 # than tensor_meta.
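
The flag this new check keys on is ordinary PyTorch metadata. A minimal sketch of how it surfaces (assumes a local PyTorch install; torch.quantize_per_tensor is standard PyTorch, not part of this PR):

import torch

# Quantized tensors report is_quantized=True; fx's TensorMetadata records the
# same flag for a node's value, which is what the importer inspects here.
qt = torch.quantize_per_tensor(
    torch.randn(4), scale=0.1, zero_point=0, dtype=torch.qint8
)
assert qt.is_quantized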
@@ -630,23 +633,16 @@ def _import_torch_op_overload(
             op_overload = getattr(op_overload, op_attrs[i])
         schema = op_overload._schema

-        if not self._c.is_registered_operation(mlir_op_name):
-            # TODO: Implement a config setting to allow these to flow through.
-            # See: https://github.com/nod-ai/SHARK-Turbine/issues/141
-            raise NotImplementedError(
-                f"Unimplemented torch op in the IREE compiler: '{mlir_op_name}' "
-                f"(either implement this op/variant or configure the compiler to "
-                f"allow unknown operations and fallback to PyTorch)."
-            )
-
         return_count = len(schema.returns)
         if return_count == 1:
             # Unary return directly maps a single meta["val"] and cannot be subscripted.
             # If "tensor_meta" is None, this raises an unsupported placeholder node error.
             result_types = [self._cc.node_val_to_type(node)]
         elif return_count == 0:
-            # TODO: Implement (https://github.com/nod-ai/SHARK-Turbine/issues/142)
-            raise NotImplementedError("FIXME: Zero ATen results")
+            # Some torch ops have zero returns, which is supported by the ZeroResults
+            # op trait. The Python bindings for IR creation accept an empty
+            # result_types list for such ops, so pass one here.
+            result_types = []
         else:
             # Multi-return will unpack the meta["val"] and trigger our getitem subscripting
             # short-circuit above. Note that if we ever choose to also fully reify Python
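
For context on the zero-return branch, a standalone sketch of creating a zero-result operation through the MLIR Python bindings (assumes the iree-compiler package; "demo.zero_result" is a made-up op name, hence unregistered dialects are enabled):

from iree.compiler import ir

with ir.Context() as ctx, ir.Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        # Empty results/operands lists are accepted, mirroring result_types = [].
        op = ir.Operation.create("demo.zero_result", results=[], operands=[])
    assert len(op.results) == 0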
@@ -663,8 +659,6 @@ def _import_torch_op_overload(
         operands = []
         for i, parameter in enumerate(schema.arguments):
             if parameter.kwarg_only and parameter.name in node.kwargs:
-                # TODO: Nice error if KeyError.
-                # See: https://github.com/nod-ai/SHARK-Turbine/issues/143
                 operands.append(
                     self._import_argument(
                         loc, node.kwargs[parameter.name], parameter.type
@@ -681,12 +675,23 @@
                     )
                 )

-        operation = Operation.create(
-            mlir_op_name,
-            results=result_types,
-            operands=operands,
-            loc=loc,
-        )
+        # Support unregistered torch ops using torch.operator.
+        # torch.operator is used to represent ops from the registry
+        # that have not been generated by torch_ods_gen.py.
+        if not self._c.is_registered_operation(mlir_op_name):
+            operation = Operation.create(
+                "torch.operator",
+                results=result_types,
+                operands=operands,
+                loc=loc,
+            )
+        else:
+            operation = Operation.create(
+                mlir_op_name,
+                results=result_types,
+                operands=operands,
+                loc=loc,
+            )

         # Record value mapping.
         for i, value in enumerate(operation.results):
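
One caveat worth flagging on this hunk: in torch-mlir, torch.operator declares a required "name" StrAttr that carries the original op name, which the creation above does not attach. A sketch of what attaching it would look like (an assumption about a likely follow-up, not code from this PR; StringAttr is the standard MLIR Python binding type):

operation = Operation.create(
    "torch.operator",
    # Preserve the unregistered op's name (e.g. "torch.aten.foo"), as the
    # torch.operator op definition in torch-mlir expects.
    attributes={"name": StringAttr.get(mlir_op_name)},
    results=result_types,
    operands=operands,
    loc=loc,
)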
@@ -830,18 +835,15 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
         if isinstance(arg, list):
             return self._import_list_argument(loc, arg, expected_jit_type)

+        # LITERAL_CONVERTER_MAP maps each arg to its respective constant of the
+        # expected jit IR type (types like torch.dtype form a chain of maps to
+        # reach a constant of expected_jit_type).
         cvt = LITERAL_CONVERTER_MAP.lookup(type(arg))
         if cvt is None:
             raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg}")
         with loc:
             return cvt(arg, self, self._cc)

-# TODO: Support torch specific types which show up in function schemas.
-# These all require an expected_jit_type to convert.
-# torch.dtype, torch.device, torch.memory_format, torch.layout
-# list
-# See: https://github.com/nod-ai/SHARK-Turbine/issues/144
-

 class TypeSubclassMap:
     """Mapping of super-types to values.