Skip to content

Commit

Permalink
Support bf16 weight and dynamic shape (#242)
Browse files Browse the repository at this point in the history
* Support bf16 weight and dynamic shape

* apply pyink formatter
  • Loading branch information
yuanzhedong authored Nov 12, 2024
1 parent 7d59fcd commit ffa84f7
Showing 1 changed file with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,14 @@ def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
total_size *= dim

if size_limit < 0 or size_limit >= total_size:
return json.dumps(tensor.cpu().detach().numpy().tolist())
return json.dumps(
tensor.cpu().detach().to(torch.float32).numpy().tolist()
)

return json.dumps(
(tensor.cpu().detach().numpy().flatten())[:size_limit].tolist()
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[
:size_limit
].tolist()
)

def add_node_attrs(self, fx_node: torch.fx.node.Node, node: GraphNode):
Expand Down Expand Up @@ -204,7 +208,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, torch.Tensor):
dtype = str(out_vals.dtype)
shape = json.dumps(out_vals.shape)
shape = json.dumps(
list(
map(
lambda x: int(x) if str(x).isdigit() else str(x),
out_vals.shape,
)
)
)
metadata = MetadataItem(
id='0', attrs=[KeyValue(key='tensor_shape', value=dtype + shape)]
)
Expand Down

0 comments on commit ffa84f7

Please sign in to comment.