Skip to content

Commit

Permalink
update for loop indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Aug 7, 2024
1 parent d0da328 commit 35209fb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions models/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
"state_arg0_dim", max=MAX_STEP_SEQ - 1
)
dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
for dim_number in range(1, len(state_arg)):
current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim}}
for state_arg_idx in range(2, len(state_arg) + 1):
current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim}}
dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
token, *state_update = self.forward(
x, *state_arg, dynamic_shapes=dynamic_shapes_forw
Expand Down Expand Up @@ -352,8 +352,8 @@ def run_cached_initialize(
"arg0_1": {1: x_dim},
"arg1_1": {1: state_arg0_dim1},
}
for dim_number in range(1, len(state_arg)):
current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim1}}
for state_arg_idx in range(2, len(state_arg) + 1):
current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim1}}
dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
token, *state = self.cached_initialize(
x, *state_arg, dynamic_shapes=dynamic_shapes_forw
Expand Down

0 comments on commit 35209fb

Please sign in to comment.