Skip to content

Commit

Permalink
Split models list up into groups, pipe through CLI option
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 17, 2024
1 parent 896c221 commit 5815e57
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 8 deletions.
7 changes: 7 additions & 0 deletions models/turbine_models/custom_models/torchbench/cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def is_valid_file(arg):
help="model ID as it appears in the torchbench models text file lists, or 'all' for batch export",
default="all",
)
p.add_argument(
"--model_lists",
type=Path,
nargs="*"
help="path to a JSON list of models to benchmark. One or more paths.",
default=["torchbench_models.json", "timm_models.json", "torchvision_models.json"],
)
p.add_argument(
"--external_weights_dir",
type=str,
Expand Down
23 changes: 15 additions & 8 deletions models/turbine_models/custom_models/torchbench/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import csv

torchbench_models_dict = {
torchbench_models_all = {
# "BERT_pytorch": {
# "dim": 128,
# }, # Dynamo Export Issue
Expand Down Expand Up @@ -420,10 +420,17 @@ def run_main(model_id, args, tb_dir, tb_args):

if __name__ == "__main__":
from turbine_models.custom_models.torchbench.cmd_opts import args, unknown

tb_dir = setup_torchbench_cwd()
if args.model_id.lower() == "all":
for name in torchbench_models_dict.keys():
run_main(name, args, tb_dir, unknown)
else:
run_main(args.model_id, args, tb_dir, unknown)
import json

torchbench_models_dict = json.load(args.model_list_json
for list in args.model_lists:
torchbench_models_dict = json.load(list)
with open(args.models_json, "r") as f:
torchbench_models_dict = json.load(file)

tb_dir = setup_torchbench_cwd()
if args.model_id.lower() == "all":
for name in torchbench_models_dict.keys():
run_main(name, args, tb_dir, unknown)
else:
run_main(args.model_id, args, tb_dir, unknown)
14 changes: 14 additions & 0 deletions models/turbine_models/custom_models/torchbench/timm_models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"timm_efficientnet": {
"dim": 128
},
"timm_regnet": {
"dim": 128
},
"timm_resnest": {
"dim": 256
},
"timm_vovnet": {
"dim": 128
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"pytorch_unet": {
"dim": 8
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"LearningToPaint": {
"dim": 1024
},
"alexnet": {
"dim": 1024
},
"densenet121": {
"dim": 64
},
"mnasnet1_0": {
"dim": 256
},
"mobilenet_v2": {
"dim": 128
},
"mobilenet_v3_large": {
"dim": 256
},
"resnet18": {
"dim": 512
},
"resnet50": {
"dim": 128
},
"resnext50_32x4d": {
"dim": 128
},
"shufflenet_v2_x1_0": {
"dim": 512
},
"squeezenet1_1": {
"dim": 512
}
}

0 comments on commit 5815e57

Please sign in to comment.