Skip to content

Commit

Permalink
update dim name
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Aug 7, 2024
1 parent 0f5418c commit 15ee5ea
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
HIDDEN_DIM,
NUM_LAYERS,
)
token0_dim = torch.export.Dim("token0_dim", max=MAX_STEP_SEQ - 1)
dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: token0_dim}}
state_arg0_dim = torch.export.Dim("state_arg0_dim", max=MAX_STEP_SEQ - 1)
dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
for dim_number in range(1, len(state_arg)):
current_dim_dict = {f"arg{dim_number + 1}_1": {1: token0_dim}}
current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim}}
dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
token, *state_update = self.forward(
x, *state_arg, dynamic_shapes=dynamic_shapes_forw
Expand Down

0 comments on commit 15ee5ea

Please sign in to comment.