diff --git a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py
index 571851dc..67534f40 100644
--- a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py
+++ b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py
@@ -15,13 +15,16 @@
 import json
 import types
-from typing import Dict

 import torch
 import torch.fx
 from torch.fx import _pytree as fx_pytree
-from .graph_builder import Graph, GraphNode, IncomingEdge, KeyValue, MetadataItem
+from .graph_builder import Graph
+from .graph_builder import GraphNode
+from .graph_builder import IncomingEdge
+from .graph_builder import KeyValue
+from .graph_builder import MetadataItem
 from .types import ModelExplorerGraphs

@@ -32,7 +35,9 @@ def __init__(self, ep: torch.export.ExportedProgram):
     self.gm = self.ep.graph_module
     self.inputs_map = self.get_inputs_map()

-  def _graph_module_flat_inputs(self, ep: torch.export.ExportedProgram, args, kwargs):
+  def legacy_graph_module_flat_inputs(
+      self, ep: torch.export.ExportedProgram, args, kwargs
+  ):
     """Transform args, kwargs of __call__ to args for graph_module.

     self.graph_module takes stuff from state dict as inputs.
@@ -61,8 +66,7 @@ def _graph_module_flat_inputs(self, ep: torch.export.ExportedProgram, args, kwar
     param_buffer_keys = (
         ep.graph_signature.parameters + ep.graph_signature.buffers
     )
-    param_buffer_values = tuple(
-        ep.state_dict[key] for key in param_buffer_keys)
+    param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)

     if hasattr(ep.graph_signature, 'lifted_tensor_constants'):
       ordered_tensor_constants = tuple(
@@ -83,9 +87,14 @@ def get_inputs_map(self):
       )
       return inputs_map

-    input_tensors = self._graph_module_flat_inputs(
-        self.ep, *self.ep.example_inputs
-    )
+    input_tensors = None
+    if hasattr(self.ep, '_graph_module_flat_inputs'):
+      input_tensors = self.ep._graph_module_flat_inputs(*self.ep.example_inputs)
+    else:
+      # Backward compatibility with torch 2.2.x
+      input_tensors = self.legacy_graph_module_flat_inputs(
+          self.ep, *self.ep.example_inputs
+      )
     for input_spec, tensor in zip(
         self.ep.graph_signature.input_specs, input_tensors
     ):
@@ -176,15 +185,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
     if out_vals is None:
       return

-    if isinstance(out_vals, tuple):
+    if isinstance(out_vals, (tuple, list)):
       for idx, val in enumerate(out_vals):
         metadata = MetadataItem(id=str(idx), attrs=[])
         if val is None:
           continue
         dtype = str(val.dtype)
         shape = json.dumps(val.shape)
-        metadata.attrs.append(
-            KeyValue(key='tensor_shape', value=dtype + shape))
+        metadata.attrs.append(KeyValue(key='tensor_shape', value=dtype + shape))
         node.outputsMetadata.append(metadata)
     else:
       dtype = str(out_vals.dtype)
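
For context, the core of this change is the `hasattr`-based dispatch: when the installed torch exposes `ExportedProgram._graph_module_flat_inputs`, the adapter calls it directly, and otherwise it falls back to its own `legacy_graph_module_flat_inputs` copy (the path kept for torch 2.2.x). Below is a minimal, standalone sketch of that pattern under stated assumptions; `NewStyleProgram`, `OldStyleProgram`, and `flatten_inputs` are hypothetical stand-ins used only for illustration and are not part of torch or model-explorer.

```python
# Sketch of the hasattr-based version fallback used in the diff.
# NewStyleProgram / OldStyleProgram / flatten_inputs are hypothetical names.


class NewStyleProgram:
  """Mimics a program object that ships the flattening helper as a method."""

  example_inputs = ((1, 2), {'scale': 3})

  def _graph_module_flat_inputs(self, args, kwargs):
    # Flatten positional and keyword example inputs into one tuple.
    return tuple(args) + tuple(kwargs.values())


class OldStyleProgram:
  """Mimics an older program object without the helper method."""

  example_inputs = ((1, 2), {'scale': 3})


def legacy_graph_module_flat_inputs(ep, args, kwargs):
  # Stand-in for the adapter's own copy of the flattening logic.
  return tuple(args) + tuple(kwargs.values())


def flatten_inputs(ep):
  # Prefer the method provided by the object itself; fall back otherwise.
  if hasattr(ep, '_graph_module_flat_inputs'):
    return ep._graph_module_flat_inputs(*ep.example_inputs)
  return legacy_graph_module_flat_inputs(ep, *ep.example_inputs)


print(flatten_inputs(NewStyleProgram()))  # (1, 2, 3)
print(flatten_inputs(OldStyleProgram()))  # (1, 2, 3)
```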