-
Notifications
You must be signed in to change notification settings - Fork 1
/
transe_eval.py
76 lines (61 loc) · 1.82 KB
/
transe_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransE
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader
import joblib
import torch
import numpy as np
from collections import defaultdict
import argparse
import os
import sys
import timeit
from data import (
TASK_REV_MEDIUMHAND,
TASK_LABELS,
)
import metrics
# Ensure the directory that receives the trained-model checkpoint exists.
# exist_ok=True replaces a separate exists() check, which was race-prone
# (check-then-create), and makes re-running the script harmless.
os.makedirs('checkpoint', exist_ok=True)
# Directory holding the OpenKE-formatted dataset files. It is shared by the
# training and test loaders so the two cannot silently point at different data.
DATA_PATH = "./data/kge/openke/"

# dataloader for training
train_dataloader = TrainDataLoader(
    in_path=DATA_PATH,
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,     # presumably enables OpenKE's Bernoulli negative sampling; confirm in OpenKE docs
    filter_flag=1,
    neg_ent=25,      # presumably negative entity samples per positive triple; confirm
    neg_rel=0,
)

# dataloader for test; "link" is passed as the loader's sampling mode
test_dataloader = TestDataLoader(DATA_PATH, "link")
# Build the TransE model. Entity and relation totals are taken from the
# training loader so the embedding tables are sized for this dataset.
transe = TransE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=200,          # embedding dimension
    p_norm=1,         # norm order used by TransE (presumably L1 distance; confirm)
    norm_flag=True,
)

# Wrap the model for training: NegativeSampling combines it with a
# MarginLoss (margin 5.0) and uses the loader's batch size.
model = NegativeSampling(
    model=transe,
    loss=MarginLoss(margin=5.0),
    batch_size=train_dataloader.get_batch_size(),
)
# Number of training passes given to Trainer as train_times. It is also the
# divisor for the per-pass average below, so the two cannot drift apart.
TRAIN_TIMES = 1000

start_time = timeit.default_timer()
# train the model; use the GPU only when CUDA is actually available, so the
# script does not crash on CPU-only hosts
trainer = Trainer(
    model=model,
    data_loader=train_dataloader,
    train_times=TRAIN_TIMES,
    alpha=1.0,
    use_gpu=torch.cuda.is_available(),
)
trainer.run()
# Stop the clock before saving so checkpoint I/O is not counted as training time.
stop_time = timeit.default_timer()
transe.save_checkpoint('./checkpoint/transe.ckpt')
print('average training time: {}'.format((stop_time - start_time) / TRAIN_TIMES))
start_time = timeit.default_timer()
# test the model: reload the parameters from the checkpoint file written after
# training (the load is included in the measured testing time)
transe.load_checkpoint('./checkpoint/transe.ckpt')
tester = Tester(
    model=transe,
    data_loader=test_dataloader,
    use_gpu=torch.cuda.is_available(),  # fall back to CPU when CUDA is unavailable
)
tester.run_link_prediction(type_constrain=False)
stop_time = timeit.default_timer()
print('link prediction testing time: {}'.format((stop_time - start_time)))