-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
82 lines (57 loc) · 2.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import importlib
importlib.import_module('gcn_dgl_wrapper')
importlib.import_module('gcn_spmm_cu')
import os
import argparse, json
def main():
    """Resolve run settings from a JSON config file and the command line, then run a GCN smoke test.

    ``--config`` (path to a JSON file) is required. ``--gpu_id``, ``--model``
    and ``--dataset`` are optional. When given, each takes precedence over the
    matching config entry: ``config['gpu']['use']`` / ``config['gpu']['id']``,
    ``config['model']`` and ``config['dataset']``. Passing ``--gpu_id`` also
    forces GPU usage to ``True``.

    After printing the resolved settings, the function does two things:

    * instantiates ``GCNdgl(16, 4)`` from the project module ``gcn_dgl_wrapper``,
      prints its ``inD`` / ``outD`` attributes and calls its ``deneme()`` method;
    * calls ``gcnlayer()`` from the project module ``gcn_spmm_cu``.

    Exits with status 2 (via ``argparse``) in any of these cases:

    * ``--config`` is missing;
    * the config file cannot be read or decoded as JSON;
    * ``--gpu_id`` is not an integer.

    A config file lacking one of the keys above raises ``KeyError`` when that
    key is needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config')
    # type=int lets argparse reject a non-numeric GPU id with a usage error
    # instead of a ValueError traceback further down.
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--model')
    parser.add_argument('--dataset')
    args = parser.parse_args()

    # parser.error() prints usage plus the message to stderr and exits with
    # status 2. The previous bare exit() reported success (status 0) even
    # though the required argument was missing.
    if args.config is None:
        parser.error('Please call function with parameter --config file.json')
    try:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(args.config, encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        parser.error(f'cannot load config file {args.config!r}: {err}')

    print('Current configurations from config file:')

    ### GPU Configurations ###
    # A GPU id on the command line implies GPU usage and overrides the file.
    if args.gpu_id is not None:
        conf_gpu_usage, conf_gpu_id = True, args.gpu_id
    else:
        conf_gpu_usage = config['gpu']['use']
        conf_gpu_id = config['gpu']['id']
    print('GPU usage: ' + str(conf_gpu_usage))
    print('GPU ID: ' + str(conf_gpu_id))

    ### GNN Model Configurations ###
    conf_model = args.model if args.model is not None else config['model']
    print(f'GNN Model: {conf_model}')

    ### Dataset Configurations ###
    conf_dataset = args.dataset if args.dataset is not None else config['dataset']
    print(f'Dataset: {conf_dataset}')

    # NOTE(review): conf_model, conf_dataset and the GPU settings are only
    # printed. The model below is always GCNdgl with fixed sizes (16, 4).
    # Confirm whether these settings are meant to drive model construction.
    dgl_wrapper_module = importlib.import_module('gcn_dgl_wrapper')
    gcn = dgl_wrapper_module.GCNdgl(16, 4)
    print("Number of input dimensions of GCN model is: " + str(gcn.inD))
    print("Number of output dimensions of GCN model is: " + str(gcn.outD))
    # "deneme" is Turkish for "test/trial"; presumably a smoke-test hook on
    # the wrapper. Verify against gcn_dgl_wrapper.
    gcn.deneme()

    spmm_module = importlib.import_module('gcn_spmm_cu')
    spmm_module.gcnlayer()
# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()