diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index 8261c3cd..c1be1690 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -267,10 +267,10 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 HIDDEN_DIM,
                 NUM_LAYERS,
             )
-            token0_dim = torch.export.Dim("token0_dim", max=MAX_STEP_SEQ - 1)
-            dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: token0_dim}}
+            state_arg0_dim = torch.export.Dim("state_arg0_dim", max=MAX_STEP_SEQ - 1)
+            dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
             for dim_number in range(1, len(state_arg)):
-                current_dim_dict = {f"arg{dim_number + 1}_1": {1: token0_dim}}
+                current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim}}
                 dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
             token, *state_update = self.forward(
                 x, *state_arg, dynamic_shapes=dynamic_shapes_forw
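
For context, below is a minimal sketch (not part of this PR) of the `torch.export` pattern the diff relies on: reusing a single `torch.export.Dim` object across several inputs so the exporter treats their sequence dimensions as one shared symbolic size. `ToyCacheModel`, `seq_dim`, and the shapes are illustrative stand-ins, not names from `stateless_llama.py`.

```python
# Sketch only: shows sharing one torch.export.Dim across multiple inputs,
# analogous to reusing `state_arg0_dim` for each "arg{n}_1" entry when the
# diff builds `dynamic_shapes_forw`. Names/shapes here are hypothetical.
import torch


class ToyCacheModel(torch.nn.Module):
    def forward(self, x, k_cache, v_cache):
        # k_cache and v_cache must have the same (dynamic) length on dim 1.
        pooled = (k_cache + v_cache).mean(dim=1, keepdim=True)  # (1, 1, 16)
        return x + pooled


mod = ToyCacheModel()
example_args = (
    torch.randn(1, 1, 16),  # x: fully static
    torch.randn(1, 7, 16),  # k_cache: dynamic on dim 1
    torch.randn(1, 7, 16),  # v_cache: dynamic on dim 1
)

# One Dim object reused for every cache input; both entries resolve to the
# same symbolic size, mirroring the loop over state_arg in the diff.
seq_dim = torch.export.Dim("seq_dim", max=511)
dynamic_shapes = {
    "x": None,                # static shape
    "k_cache": {1: seq_dim},  # dim 1 is dynamic
    "v_cache": {1: seq_dim},  # same Dim object -> same symbol as k_cache
}

exported = torch.export.export(mod, example_args, dynamic_shapes=dynamic_shapes)
print(exported.graph_signature)
```

Reusing one `Dim` keeps every state tensor tied to the same symbolic sequence length, which is why the loop in the diff inserts the same `state_arg0_dim` for each `arg{n}_1` key rather than creating a fresh `Dim` per argument.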