-
Notifications
You must be signed in to change notification settings - Fork 11
/
run.py
68 lines (53 loc) · 2 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import sys
import os
import json
from src.pinsage import model
from src.pinsage import process_amazon
def _load_config(config_fn, config_dir="./config"):
    """Load and return a JSON config file.

    Args:
        config_fn: File name of the config inside ``config_dir``,
            e.g. ``"data-params.json"``.
        config_dir: Directory holding the config files. The default
            ``"./config"`` is resolved against the current working directory.

    Returns:
        The parsed JSON document.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(os.path.join(config_dir, config_fn), encoding="utf-8") as fh:
        return json.load(fh)


def main(targets):
    """Run the requested pipeline stages in the order data -> pinsage -> test -> graphsage.

    Args:
        targets: Target names taken from the command line (see the help
            message). An empty list, or one containing ``'all'``, runs
            ``data``, ``pinsage``, ``graphsage`` and ``test``.

    Raises:
        ValueError: If a target name is unknown, or if a stage is requested
            without the stage that produces its input. Both checks run before
            any work starts.
        NotImplementedError: When the ``'graphsage'`` stage is reached. This
            script imports no GraphSAGE training function.
    """
    if 'help' in targets:
        help_msg = (
            " Possible targets:\n"
            "help: program help message\n"
            "data: processes data\n"
            "pinsage: trains pinsage model on train data\n"
            "graphsage: trains graphsage model on train data\n"
            "test: runs model on test data\n"
            "all (default): trains and tests both models\n"
            "clean: removes all output files"
        )
        print(help_msg)
        return

    known_targets = {'help', 'all', 'data', 'pinsage', 'graphsage', 'test', 'clean'}
    unknown = [t for t in targets if t not in known_targets]
    if unknown:
        # A typo would otherwise silently run nothing.
        raise ValueError(
            "Unknown target(s): {0}; run with 'help' to list valid targets".format(
                ", ".join(unknown)))

    if len(targets) == 0 or 'all' in targets:
        targets = ['data', 'pinsage', 'graphsage', 'test']

    # Each stage consumes a value that an earlier stage produces in this same
    # run. Fail before doing any work instead of raising NameError part-way
    # through the pipeline.
    prerequisites = {'pinsage': 'data', 'graphsage': 'data', 'test': 'pinsage'}
    for stage, needed in prerequisites.items():
        if stage in targets and needed not in targets:
            raise ValueError(
                "Target '{0}' requires target '{1}' in the same run".format(stage, needed))

    if 'clean' in targets:
        # NOTE(review): the help text advertises 'clean', but this script does not
        # define which output files to remove. Report it instead of silently no-op'ing.
        print("Target 'clean' is not implemented; no files were removed.", file=sys.stderr)

    # Process the raw data using its config.
    if 'data' in targets:
        data_cfg = _load_config("data-params.json")
        dataset = process_amazon.main(data_cfg)

    # Train PinSage embeddings, then optionally evaluate them. The
    # prerequisite check above guarantees `dataset` is bound here.
    if 'pinsage' in targets:
        pinsage_model_cfg = _load_config("pinsage-model-params.json")
        print("Training model embeddings...")
        item_embeddings = model.train(dataset, pinsage_model_cfg)
        if 'test' in targets:
            print("Testing model embeddings...")
            model.test(dataset, pinsage_model_cfg, item_embeddings)

    if 'graphsage' in targets:
        # TODO: the previous code called an undefined `train(data_cfg, cfg)` and
        # then `.save()` on the result (the original comment says the model is
        # saved as .pth). Wire in a real GraphSAGE trainer here.
        raise NotImplementedError(
            "GraphSAGE training is not available: run.py imports no GraphSAGE train function")
    return
if __name__ == "__main__":
    # Every command-line argument after the script name is a target name.
    cli_targets = sys.argv[1:]
    main(cli_targets)