
Commit

formatting
saienduri committed Aug 7, 2024
1 parent 15ee5ea commit d0da328
Showing 2 changed files with 14 additions and 5 deletions.
models/turbine_models/custom_models/resnet_18.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ class CompiledResnet18Model(CompiledModule):
     params = export_parameters(resnet_model.model)

     def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
-        dynamic_shapes={"arg0_1": {0: torch.export.Dim("dim", max=15)}}
+        dynamic_shapes = {"arg0_1": {0: torch.export.Dim("dim", max=15)}}
         return jittable(resnet_model.forward)(x, dynamic_shapes=dynamic_shapes)

 import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
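For context, the reformatted line is a standard torch.export dynamic-shape spec: the same Dim pattern works outside Turbine's CompiledModule/jittable wrappers. A minimal standalone sketch, where SmallNet and the example sizes are illustrative and not from this commit:

import torch

class SmallNet(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# dim 0 (the batch) may vary at runtime, up to 15, matching the spec above
batch = torch.export.Dim("batch", max=15)
exported = torch.export.export(
    SmallNet(),
    (torch.randn(4, 3, 224, 224),),
    dynamic_shapes={"x": {0: batch}},  # keyed by the forward() argument name
)
print(exported)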
models/turbine_models/custom_models/stateless_llama.py (17 changes: 13 additions & 4 deletions)
@@ -237,7 +237,9 @@ class StateUpdateModule(CompiledModule):
     def run_initialize(
         self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
     ):
-        dynamic_shapes_init={"arg0_1": {1: torch.export.Dim("dim", max=MAX_STEP_SEQ - 1)}}
+        dynamic_shapes_init = {
+            "arg0_1": {1: torch.export.Dim("dim", max=MAX_STEP_SEQ - 1)}
+        }
         token, *state = self.initialize(x, dynamic_shapes=dynamic_shapes_init)
         self.global_seq_step = IREE.tensor_dim(
             state[0], 1
@@ -267,7 +269,9 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             HIDDEN_DIM,
             NUM_LAYERS,
         )
-        state_arg0_dim = torch.export.Dim("state_arg0_dim", max=MAX_STEP_SEQ - 1)
+        state_arg0_dim = torch.export.Dim(
+            "state_arg0_dim", max=MAX_STEP_SEQ - 1
+        )
         dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
         for dim_number in range(1, len(state_arg)):
             current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim}}
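The loop in this hunk builds one dynamic_shapes entry per cached state tensor, and all entries share the single state_arg0_dim object, which ties every cache to the same symbolic sequence length. A hedged sketch of the pattern in isolation; the constants below are placeholders, not the module's real values:

import torch

MAX_STEP_SEQ = 4095   # placeholder for the module-level constant
NUM_STATE_ARGS = 8    # placeholder, e.g. k/v caches across layers

seq_dim = torch.export.Dim("state_arg0_dim", max=MAX_STEP_SEQ - 1)
# arg0_1 is the fully static token input; each state arg shares seq_dim at index 1
dynamic_shapes = {"arg0_1": None, "arg1_1": {1: seq_dim}}
for dim_number in range(1, NUM_STATE_ARGS):
    current_dim_dict = {f"arg{dim_number + 1}_1": {1: seq_dim}}
    dynamic_shapes = {**dynamic_shapes, **current_dim_dict}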
@@ -340,9 +344,14 @@ def run_cached_initialize(
             HIDDEN_DIM,
             NUM_LAYERS,
         )
-        state_arg0_dim1 = torch.export.Dim("state_arg0_dim1", max=MAX_STEP_SEQ - 1)
+        state_arg0_dim1 = torch.export.Dim(
+            "state_arg0_dim1", max=MAX_STEP_SEQ - 1
+        )
         x_dim = torch.export.Dim("x_dim", max=MAX_STEP_SEQ - 1)
-        dynamic_shapes_forw = {"arg0_1": {1: x_dim}, "arg1_1": {1: state_arg0_dim1}}
+        dynamic_shapes_forw = {
+            "arg0_1": {1: x_dim},
+            "arg1_1": {1: state_arg0_dim1},
+        }
         for dim_number in range(1, len(state_arg)):
             current_dim_dict = {f"arg{dim_number + 1}_1": {1: state_arg0_dim1}}
             dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
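Unlike run_forward, run_cached_initialize declares two separate Dim objects. In torch.export, reusing one Dim object across specs asserts those dimensions are always equal, so distinct x_dim and state_arg0_dim1 objects let the input prefix length and the cached state length vary independently, each within the same MAX_STEP_SEQ - 1 bound. A minimal sketch with a placeholder constant:

import torch

MAX_STEP_SEQ = 4095  # placeholder for the module-level constant

x_dim = torch.export.Dim("x_dim", max=MAX_STEP_SEQ - 1)
state_dim = torch.export.Dim("state_arg0_dim1", max=MAX_STEP_SEQ - 1)
# independent symbolic dims: prefix length and cache length need not match
dynamic_shapes = {"arg0_1": {1: x_dim}, "arg1_1": {1: state_dim}}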
