
Commit 9eb3804

Update pytorch_exported_program_adater_impl.py
yijie-yang authored May 13, 2024
1 parent 64eb2a5 commit 9eb3804
Showing 1 changed file with 19 additions and 11 deletions.
pytorch_exported_program_adater_impl.py
@@ -15,13 +15,16 @@
 import json
 import types
 from typing import Dict
 
 import torch
 import torch.fx
 from torch.fx import _pytree as fx_pytree
 
-from .graph_builder import Graph, GraphNode, IncomingEdge, KeyValue, MetadataItem
+from .graph_builder import Graph
+from .graph_builder import GraphNode
+from .graph_builder import IncomingEdge
+from .graph_builder import KeyValue
+from .graph_builder import MetadataItem
 from .types import ModelExplorerGraphs


@@ -32,7 +35,9 @@ def __init__(self, ep: torch.export.ExportedProgram):
     self.gm = self.ep.graph_module
     self.inputs_map = self.get_inputs_map()
 
-  def _graph_module_flat_inputs(self, ep: torch.export.ExportedProgram, args, kwargs):
+  def legacy_graph_module_flat_inputs(
+      self, ep: torch.export.ExportedProgram, args, kwargs
+  ):
     """Transform args, kwargs of __call__ to args for graph_module.
 
     self.graph_module takes stuff from state dict as inputs.
@@ -61,8 +66,7 @@ def _graph_module_flat_inputs(self, ep: torch.export.ExportedProgram, args, kwargs):
       param_buffer_keys = (
          ep.graph_signature.parameters + ep.graph_signature.buffers
      )
-    param_buffer_values = tuple(
-        ep.state_dict[key] for key in param_buffer_keys)
+    param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
 
     if hasattr(ep.graph_signature, 'lifted_tensor_constants'):
       ordered_tensor_constants = tuple(
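
Aside: the renamed legacy helper keeps the torch 2.2.x behavior of prefixing the user-supplied arguments with parameter and buffer tensors pulled from the state dict. A minimal sketch of that assembly under the same graph_signature ordering (assemble_flat_args is a hypothetical name, not part of this commit):

import torch

def assemble_flat_args(ep: torch.export.ExportedProgram, user_args: tuple):
  """Sketch: prefix user args with parameter/buffer values from the state dict."""
  # graph_signature lists parameter keys first, then buffer keys.
  param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
  param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
  # graph_module expects these lifted inputs ahead of the original user args.
  return (*param_buffer_values, *user_args)
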
@@ -83,9 +87,14 @@ def get_inputs_map(self):
       )
       return inputs_map
 
-    input_tensors = self._graph_module_flat_inputs(
-        self.ep, *self.ep.example_inputs
-    )
+    input_tensors = None
+    if hasattr(self.ep, '_graph_module_flat_inputs'):
+      input_tensors = self.ep._graph_module_flat_inputs(*self.ep.example_inputs)
+    else:
+      # Backward compatibility with torch 2.2.x
+      input_tensors = self.legacy_graph_module_flat_inputs(
+          self.ep, *self.ep.example_inputs
+      )
     for input_spec, tensor in zip(
         self.ep.graph_signature.input_specs, input_tensors
     ):
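
The hasattr check above prefers torch's own private ExportedProgram._graph_module_flat_inputs when the installed release provides it, and only falls back to the vendored legacy copy on torch 2.2.x. A minimal sketch of that dispatch in isolation (flat_inputs and its legacy_fn parameter are illustrative, not part of this commit):

import torch

def flat_inputs(ep: torch.export.ExportedProgram, legacy_fn):
  """Sketch: resolve example inputs to flat graph_module args across torch versions."""
  # example_inputs is the (args, kwargs) pair captured at export time.
  args, kwargs = ep.example_inputs
  if hasattr(ep, '_graph_module_flat_inputs'):
    # Newer torch exposes the helper on ExportedProgram itself.
    return ep._graph_module_flat_inputs(args, kwargs)
  # torch 2.2.x: fall back to the vendored implementation.
  return legacy_fn(ep, args, kwargs)
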
@@ -176,15 +185,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
     if out_vals is None:
       return
 
-    if isinstance(out_vals, tuple):
+    if isinstance(out_vals, (tuple, list)):
       for idx, val in enumerate(out_vals):
         metadata = MetadataItem(id=str(idx), attrs=[])
         if val is None:
           continue
         dtype = str(val.dtype)
         shape = json.dumps(val.shape)
-        metadata.attrs.append(
-            KeyValue(key='tensor_shape', value=dtype + shape))
+        metadata.attrs.append(KeyValue(key='tensor_shape', value=dtype + shape))
         node.outputsMetadata.append(metadata)
     else:
       dtype = str(out_vals.dtype)
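
The widened isinstance check matters because an fx node's recorded output value (typically fx_node.meta['val']) can be a list as well as a tuple for multi-output ops. A quick, self-contained way to see the resulting tensor_shape strings, using plain tensors as a stand-in for node metadata:

import json

import torch

# Stand-in for fx_node.meta['val']; multi-output ops may record a list.
out_vals = [torch.zeros(2, 3), None, torch.ones(4, dtype=torch.int64)]

if isinstance(out_vals, (tuple, list)):
  for idx, val in enumerate(out_vals):
    if val is None:
      continue
    # Prints e.g. "0 torch.float32[2, 3]" -- dtype string plus the JSON shape.
    print(idx, str(val.dtype) + json.dumps(list(val.shape)))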
