@@ -10,7 +10,7 @@


 def evaluate_model(model, test_loader, mlb_encoder, mt_labels, device):
-    meter = Meter(mt_labels=mt_labels, mlb_encoder=mlb_encoder)
+    meter = Meter(mt_labels=mt_labels, mlb_encoder=mlb_encoder, k=1)
     f1 = 0

     for i, (dev_x, dev_mask, dev_y) in enumerate(test_loader):
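The `Meter` internals are not part of this diff, but given the new `k=1` argument here and the `meter.pk` / `meter.rk` fields read further down, the added metrics are presumably precision@k and recall@k over the multi-label predictions. A minimal sketch of that computation (a hypothetical helper, not the project's `Meter` implementation):

```python
import numpy as np

def precision_recall_at_k(y_true, y_scores, k=1):
    """Precision@k and recall@k for a batch of multi-label predictions.

    y_true:   (n_samples, n_labels) binary indicator matrix
    y_scores: (n_samples, n_labels) real-valued prediction scores
    """
    # indices of the k highest-scoring labels per sample
    top_k = np.argsort(-y_scores, axis=1)[:, :k]
    # how many of those top-k labels are actually relevant
    hits = np.take_along_axis(y_true, top_k, axis=1).sum(axis=1)
    p_at_k = hits / k                                    # relevant among predicted top-k
    r_at_k = hits / np.maximum(y_true.sum(axis=1), 1)    # retrieved among relevant
    return p_at_k.mean(), r_at_k.mean()
```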
@@ -47,6 +47,8 @@ def evaluate():
         if not os.path.exists(os.path.join(args.output_path, lang)):
             os.makedirs(os.path.join(args.output_path, lang))

+        pk_scores = []
+        rk_scores = []
         f1k_scores = []
         f1k_mt_scores = []
         f1k_domain_scores = []
@@ -71,6 +73,8 @@ def evaluate():
                   format(meter.f1k, meter.f1k_mt, meter.f1k_domain,
                          meter.ndcg_1, meter.ndcg_3, meter.ndcg_5, meter.ndcg_10))

+            pk_scores.append(meter.pk)
+            rk_scores.append(meter.rk)
             f1k_scores.append(meter.f1k)
             f1k_mt_scores.append(meter.f1k_mt)
             f1k_domain_scores.append(meter.f1k_domain)
@@ -82,21 +86,25 @@ def evaluate():
         print("\nOverall results for language '{}' - "
               "F1@6: {:.2f} ± ({:.2f}), F1@6_MT: {:.2f} ± ({:.2f}), F1@6_DO: {:.2f} ± ({:.2f})\n"
               "      "
+              "P@K: {:.2f} ± ({:.2f}), R@K: {:.2f} ± ({:.2f})\n"
+              "      "
               "NDCG@1: {:.2f} ± ({:.2f}), NDCG@3: {:.2f} ± ({:.2f}), NDCG@5: {:.2f} ± ({:.2f}), NDCG@10: {:.2f} ± ({:.2f})".
-              format(lang,
-                     np.mean(f1k_scores), np.std(f1k_scores),
-                     np.mean(f1k_mt_scores), np.std(f1k_mt_scores),
-                     np.mean(f1k_domain_scores), np.std(f1k_domain_scores),
-                     np.mean(ndcg_1_scores), np.std(ndcg_1_scores),
-                     np.mean(ndcg_3_scores), np.std(ndcg_3_scores),
-                     np.mean(ndcg_5_scores), np.std(ndcg_5_scores),
-                     np.mean(ndcg_10_scores), np.std(ndcg_10_scores)))
+              format(lang,
+                     np.mean(f1k_scores), np.std(f1k_scores),
+                     np.mean(f1k_mt_scores), np.std(f1k_mt_scores),
+                     np.mean(f1k_domain_scores), np.std(f1k_domain_scores),
+                     np.mean(pk_scores), np.std(pk_scores),
+                     np.mean(rk_scores), np.std(rk_scores),
+                     np.mean(ndcg_1_scores), np.std(ndcg_1_scores),
+                     np.mean(ndcg_3_scores), np.std(ndcg_3_scores),
+                     np.mean(ndcg_5_scores), np.std(ndcg_5_scores),
+                     np.mean(ndcg_10_scores), np.std(ndcg_10_scores)))


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default="pyeurovoc/configs/models.yml", help="Tokenizer used for each language.")
-    parser.add_argument("--mt_labels", type=str, default="pyeurovoc/resources/mt_labels.json")
+    parser.add_argument("--config", type=str, default="configs/models.yml", help="Tokenizer used for each language.")
+    parser.add_argument("--mt_labels", type=str, default="configs/mt_labels.json")
     parser.add_argument("--data_path", type=str, default="data/eurovoc", help="Path to the EuroVoc data.")
     parser.add_argument("--device", type=str, default="cpu", help="Device to train on.")
     parser.add_argument("--models_path", type=str, default="models", help="Path of the saved models.")