Skip to content

Commit dc15d53

Browse files
authored
Add ability to gather and analyse some model metadata (nod-ai#376)
## Usage: When running `python run.py <other-args> --get-metadata`, this will save a dictionary with the model size and op frequencies to the log directory. After a run, you can use `python utils/find_duplicate_models.py` to save or print a json dump of redundant models. ### Options: - "-s" "--simplified" will only return the list of model names (doesn't include the corresponding metadata). - "-o" "--output" allows specifying the name of a json file you want to save the result to. - "-r" "--rundirectory" allows specifying a different run directory to search, if `run.py` was run with a non-default run directory arg. ## Sample: I saved the tests below to a file called `sample.txt`. ``` add_test model--bart-base-booksum--KamilAin model--bart-base-cnn--ainize model--bart-base-few-shot-k-1024-finetuned-squad-seed-2--anas-awadalla model--bart-base-few-shot-k-1024-finetuned-squad-seed-4--anas-awadalla ``` With a clean `test-run` directory, I ran ```shell python run.py --testsfile=sample.txt --stages "setup" --get-metadata ``` The result of running ```shell python utils/find_duplicate_models.py -s ``` was: ```json [ [ "model--bart-base-booksum--KamilAin", "model--bart-base-cnn--ainize" ], [ "model--bart-base-few-shot-k-1024-finetuned-squad-seed-4--anas-awadalla", "model--bart-base-few-shot-k-1024-finetuned-squad-seed-2--anas-awadalla" ] ] ``` and without the `-s` arg, it includes the metadata for each grouping: ```json [ { "models": [ "model--bart-base-booksum--KamilAin", "model--bart-base-cnn--ainize" ], "shared_metadata": { "model_size": 712772272, "op_frequency": { "Add": 227, "Cast": 13, "Concat": 188, "Constant": 886, "ConstantOfShape": 6, "Div": 44, "Equal": 5, "Erf": 12, "Expand": 5, "Gather": 64, "Less": 1, "MatMul": 133, "Mul": 99, "Pow": 32, "Range": 3, "ReduceMean": 64, "Reshape": 187, "Shape": 67, "Slice": 2, "Softmax": 18, "Sqrt": 32, "Squeeze": 2, "Sub": 35, "Transpose": 90, "Unsqueeze": 325, "Where": 8 } } }, { "models": [ 
"model--bart-base-few-shot-k-1024-finetuned-squad-seed-4--anas-awadalla", "model--bart-base-few-shot-k-1024-finetuned-squad-seed-2--anas-awadalla" ], "shared_metadata": { "model_size": 558176646, "op_frequency": { "Add": 229, "Cast": 17, "Concat": 193, "Constant": 937, "ConstantOfShape": 12, "Div": 44, "Equal": 10, "Erf": 12, "Expand": 11, "Gather": 70, "Less": 1, "MatMul": 133, "Mul": 103, "Pow": 32, "Range": 6, "ReduceMean": 64, "Reshape": 191, "ScatterND": 2, "Shape": 83, "Slice": 7, "Softmax": 18, "Split": 1, "Sqrt": 32, "Squeeze": 4, "Sub": 35, "Transpose": 90, "Unsqueeze": 333, "Where": 13 } } } ] ```
1 parent e7ce1bc commit dc15d53

File tree

4 files changed

+122
-2
lines changed

4 files changed

+122
-2
lines changed

alt_e2eshark/e2e_testing/framework.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ def update_opset_version_and_overwrite(self):
128128
og_model, self.opset_version
129129
)
130130
onnx.save(model, self.model)
131+
132+
def get_metadata(self):
    """Gather basic metadata for this model: its on-disk size and ONNX op frequencies.

    Returns:
        dict with keys "model_size" (bytes, from os.path.getsize) and
        "op_frequency" (mapping of op name -> count, from get_op_frequency).
    """
    # NOTE(review): assumes self.model is a path to an existing model file -- confirm.
    return {
        "model_size": os.path.getsize(self.model),
        "op_frequency": get_op_frequency(self.model),
    }
137+
138+
131139

132140
# TODO: extend TestModel to a union, or make TestModel a base class when supporting other frontends
133141
TestModel = OnnxModelInfo
@@ -161,6 +169,7 @@ def benchmark(self, artifact: CompiledOutput, input: TestTensors, repetitions: i
161169
"""returns a float representing inference time in ms"""
162170
pass
163171

172+
164173
class Test(NamedTuple):
165174
"""Used to store the name and TestInfo constructor for a registered test"""
166175

alt_e2eshark/e2e_testing/logging_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def scan_dir_del_not_logs(dir):
7171
for root, dirs, files in os.walk(dir):
7272
for name in files:
7373
curr_file = os.path.join(root, name)
74-
if not name.endswith(".log") and name != "benchmark.json":
74+
if not name.endswith(".log") and not name.endswith(".json"):
7575
removed_files.append(curr_file)
7676
for file in removed_files:
7777
os.remove(file)

alt_e2eshark/run.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def main(args):
133133
stages,
134134
args.load_inputs,
135135
int(args.cleanup),
136+
args.get_metadata,
136137
)
137138

138139
if args.report:
@@ -142,7 +143,7 @@ def main(args):
142143

143144

144145
def run_tests(
145-
test_list: List[Test], config: TestConfig, parent_log_dir: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool, cleanup: int,
146+
test_list: List[Test], config: TestConfig, parent_log_dir: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool, cleanup: int, get_metadata=bool,
146147
) -> Dict[str, Dict]:
147148
"""runs tests in test_list based on config. Returns a dictionary containing the test statuses."""
148149
# TODO: multi-process
@@ -190,6 +191,10 @@ def run_tests(
190191
# TODO: Figure out how to factor this out of run.py
191192
if not os.path.exists(inst.model):
192193
inst.construct_model()
194+
if get_metadata:
195+
metadata = inst.get_metadata()
196+
metadata_file = Path(log_dir) / "metadata.json"
197+
save_dict(metadata, metadata_file)
193198

194199
artifact_save_to = None if no_artifacts else log_dir
195200
# generate mlir from the instance using the config
@@ -449,6 +454,12 @@ def _get_argparse():
449454
default="report.md",
450455
help="output filename for the report summary.",
451456
)
457+
parser.add_argument(
458+
"--get-metadata",
459+
action="store_true",
460+
default=False,
461+
help="save some model metadata to log_dir/metadata.json"
462+
)
452463
# parser.add_argument(
453464
# "-d",
454465
# "--todtype",
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from pathlib import Path
2+
import argparse
3+
from typing import Union, Dict, Any, Optional
4+
import json
5+
import io
6+
7+
# Project root: the parent of the directory containing this script.
ROOT = Path(__file__).parents[1]
8+
9+
10+
class HashableDict(dict):
    """A dict subclass that is hashable, used to invert a dictionary with dict values.

    The hash is derived from the sorted (key, value) pairs, so two instances with
    equal contents hash equally regardless of insertion order. Every value stored
    in the dict must itself be hashable.
    """

    def __hash__(self):
        ordered_items = sorted(self.items())
        return hash(tuple(ordered_items))
15+
16+
17+
def load_json_dict(filepath: Union[str, Path]) -> Dict[str, Any]:
    """Read *filepath* and return its parsed JSON content."""
    source = Path(filepath)
    with source.open() as handle:
        return json.load(handle)
21+
22+
23+
def save_to_json(jsonable_object, name_json: Optional[str] = None):
    """Serialize *jsonable_object* to pretty-printed JSON.

    If *name_json* is falsy, the JSON text is printed to stdout and nothing is
    saved. Otherwise the text is written to ``ROOT / "<stem>.json"``, where
    <stem> is the filename (without extension) taken from *name_json*; any
    directory component of *name_json* is ignored.
    """
    dict_str = json.dumps(
        jsonable_object,
        indent=4,
        sort_keys=True,
        separators=(",", ": "),
        ensure_ascii=False,
    )
    if not name_json:
        print(dict_str)
        return
    # BUG FIX: name_json arrives from argparse as a plain str, and str has no
    # .stem attribute -- wrap it in Path before taking the stem.
    path_json = ROOT / f"{Path(name_json).stem}.json"
    with io.open(path_json, "w", encoding="utf8") as outfile:
        outfile.write(dict_str)
38+
39+
40+
def get_groupings(metadata_dicts: Dict[str, Dict]) -> Dict:
    """Return a multi-valued inverse of *metadata_dicts*.

    Maps each distinct metadata dict (as a HashableDict) to the list of model
    names sharing that exact metadata.

    NOTE: mutates the input in place -- each value's "op_frequency" entry is
    replaced with a HashableDict so the whole value can be hashed.
    """
    groupings: Dict = dict()
    for key, value in metadata_dicts.items():
        value["op_frequency"] = HashableDict(value["op_frequency"])
        hashable = HashableDict(value)
        # setdefault replaces the original two-lookup
        # `if hashable in groupings.keys()` pattern with a single lookup.
        groupings.setdefault(hashable, []).append(key)
    return groupings
51+
52+
53+
def main(args):
    """Scan the run directory for metadata.json files and report duplicate models."""
    run_dir = ROOT / args.rundirectory
    # One metadata dict per test, keyed by the test (directory) name.
    metadata_dicts = {
        candidate.parent.name: load_json_dict(candidate)
        for candidate in run_dir.glob("*/*.json")
        if candidate.name == "metadata.json"
    }

    groupings = get_groupings(metadata_dicts)
    found_redundancies = []
    for shared_metadata, model_names in groupings.items():
        if len(model_names) <= 1:
            continue
        if args.simplified:
            found_redundancies.append(model_names)
        else:
            found_redundancies.append(
                {"models": model_names, "shared_metadata": shared_metadata}
            )
    save_to_json(found_redundancies, args.output)
69+
70+
71+
def _get_argparse():
72+
msg = "After running run.py with the flag --get-metadata, use this tool to find duplicate models."
73+
parser = argparse.ArgumentParser(
74+
prog="find_duplicate_models.py", description=msg, epilog=""
75+
)
76+
77+
parser.add_argument(
78+
"-r",
79+
"--rundirectory",
80+
default="test-run",
81+
help="The directory containing run.py results",
82+
)
83+
parser.add_argument(
84+
"-o",
85+
"--output",
86+
help="specify an output json file",
87+
)
88+
parser.add_argument(
89+
"-s",
90+
"--simplified",
91+
action="store_true",
92+
default=False,
93+
help="pass this arg to only print redundant model lists, without the corresponding metadata.",
94+
)
95+
return parser
96+
97+
98+
if __name__ == "__main__":
    # Entry point: parse CLI args and run the duplicate-model scan.
    main(_get_argparse().parse_args())

0 commit comments

Comments
 (0)