From 285ef9b55c2277cdaa7a993130402209a8d77564 Mon Sep 17 00:00:00 2001 From: roll-away <220250881@seu.edu.cn> Date: Mon, 29 Dec 2025 12:16:49 +0000 Subject: [PATCH 1/2] fix graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py --- .../dim_gen_passes/non_batch_call_function_arange_pass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py index 9c55aa518..ead3192f4 100644 --- a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py +++ b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py @@ -85,7 +85,11 @@ def create_new_node(node): self.node_target(), args=new_node_args, kwargs=node.kwargs ) - return new_node + safe_aranged_node = new_graph.call_function( + torch.remainder, args=(new_node, 512) + ) + + return safe_aranged_node for node in traced_module.graph.nodes: val_map[node] = create_new_node(node) From 4f29827c801e7a23756237f33190234f8affc9d2 Mon Sep 17 00:00:00 2001 From: roll-away <220250881@seu.edu.cn> Date: Tue, 30 Dec 2025 08:34:31 +0000 Subject: [PATCH 2/2] fix graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py --- .../non_batch_call_function_arange_pass.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py index ead3192f4..dfd81479b 100644 --- a/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py +++ b/graph_net/torch/dim_gen_passes/non_batch_call_function_arange_pass.py @@ -69,6 +69,24 @@ def get_new_node_arg(i, arg, len_args): return size_node + # def create_new_node(node): + # if not (self._node_need_rewrite(node) and last_node_axis is not None): + # # Copy other nodes to the new graph + # new_node = new_graph.node_copy(node, lambda x: val_map[x]) + # try_reset_last_node_axis(node=node, new_node=new_node) + # return new_node + + # new_node_args = tuple( + # get_new_node_arg(i, arg, len(node.args)) + # for i, arg in enumerate(node.args) + # ) + + # new_node = new_graph.call_function( + # self.node_target(), args=new_node_args, kwargs=node.kwargs + # ) + + # return new_node + def create_new_node(node): if not (self._node_need_rewrite(node) and last_node_axis is not None): # Copy other nodes to the new graph @@ -85,11 +103,21 @@ def create_new_node(node): self.node_target(), args=new_node_args, kwargs=node.kwargs ) - safe_aranged_node = new_graph.call_function( - torch.remainder, args=(new_node, 512) - ) - - return safe_aranged_node + static_limit = _get_static_limit(node) + if static_limit != float("inf"): + max_val = int(static_limit - 1) + new_node = new_graph.call_function( + torch.clamp, args=(new_node, 0, max_val) + ) + return new_node + + def _get_static_limit(node): + static_limit = float("inf") + for user in node.users: + if user.op == "call_function" and ("embedding" in str(user.target)): + indexed_dim_size = user.args[1].meta["tensor_meta"].shape[0] + static_limit = min(static_limit, indexed_dim_size) + return static_limit for node in traced_module.graph.nodes: val_map[node] = create_new_node(node)