Skip to content

Commit 9ef1ee5

Browse files
committed
Created prod version
1 parent 31116ca commit 9ef1ee5

File tree

4 files changed

+23
-14
lines changed

4 files changed

+23
-14
lines changed

configs/mt_labels.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

evaluate.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def evaluate_model(model, test_loader, mlb_encoder, mt_labels, device):
13-
meter = Meter(mt_labels=mt_labels, mlb_encoder=mlb_encoder)
13+
meter = Meter(mt_labels=mt_labels, mlb_encoder=mlb_encoder, k=1)
1414
f1 = 0
1515

1616
for i, (dev_x, dev_mask, dev_y) in enumerate(test_loader):
@@ -47,6 +47,8 @@ def evaluate():
4747
if not os.path.exists(os.path.join(args.output_path, lang)):
4848
os.makedirs(os.path.join(args.output_path, lang))
4949

50+
pk_scores = []
51+
rk_scores = []
5052
f1k_scores = []
5153
f1k_mt_scores = []
5254
f1k_domain_scores = []
@@ -71,6 +73,8 @@ def evaluate():
7173
format(meter.f1k, meter.f1k_mt, meter.f1k_domain,
7274
meter.ndcg_1, meter.ndcg_3, meter.ndcg_5, meter.ndcg_10))
7375

76+
pk_scores.append(meter.pk)
77+
rk_scores.append(meter.rk)
7478
f1k_scores.append(meter.f1k)
7579
f1k_mt_scores.append(meter.f1k_mt)
7680
f1k_domain_scores.append(meter.f1k_domain)
@@ -82,21 +86,25 @@ def evaluate():
8286
print("\nOverall results for language '{}' - "
8387
"F1@6: {:.2f} ± ({:.2f}), F1@6_MT: {:.2f} ± ({:.2f}), F1@6_DO: {:.2f} ± ({:.2f})\n"
8488
" "
89+
"P@K: {:.2f} ± ({:.2f}), R@K_DO: {:.2f} ± ({:.2f})"
90+
" "
8591
"NDCG@1: {:.2f} ± ({:.2f}), NDCG@3: {:.2f} ± ({:.2f}), NDCG@5: {:.2f} ± ({:.2f}), NDCG@10: {:.2f} ± ({:.2f})".
86-
format(lang,
87-
np.mean(f1k_scores), np.std(f1k_scores),
88-
np.mean(f1k_mt_scores), np.std(f1k_mt_scores),
89-
np.mean(f1k_domain_scores), np.std(f1k_domain_scores),
90-
np.mean(ndcg_1_scores), np.std(ndcg_1_scores),
91-
np.mean(ndcg_3_scores), np.std(ndcg_3_scores),
92-
np.mean(ndcg_5_scores), np.std(ndcg_5_scores),
93-
np.mean(ndcg_10_scores), np.std(ndcg_10_scores)))
92+
format(lang,
93+
np.mean(f1k_scores), np.std(f1k_scores),
94+
np.mean(f1k_mt_scores), np.std(f1k_mt_scores),
95+
np.mean(f1k_domain_scores), np.std(f1k_domain_scores),
96+
np.mean(pk_scores), np.std(pk_scores),
97+
np.mean(rk_scores), np.std(rk_scores),
98+
np.mean(ndcg_1_scores), np.std(ndcg_1_scores),
99+
np.mean(ndcg_3_scores), np.std(ndcg_3_scores),
100+
np.mean(ndcg_5_scores), np.std(ndcg_5_scores),
101+
np.mean(ndcg_10_scores), np.std(ndcg_10_scores)))
94102

95103

96104
if __name__ == "__main__":
97105
parser = argparse.ArgumentParser()
98-
parser.add_argument("--config", type=str, default="pyeurovoc/configs/models.yml", help="Tokenizer used for each language.")
99-
parser.add_argument("--mt_labels", type=str, default="pyeurovoc/resources/mt_labels.json")
106+
parser.add_argument("--config", type=str, default="configs/models.yml", help="Tokenizer used for each language.")
107+
parser.add_argument("--mt_labels", type=str, default="configs/mt_labels.json")
100108
parser.add_argument("--data_path", type=str, default="data/eurovoc", help="Path to the EuroVoc data.")
101109
parser.add_argument("--device", type=str, default="cpu", help="Device to train on.")
102110
parser.add_argument("--models_path", type=str, default="models", help="Path of the saved models.")

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# $ pip install sampleproject
1111
name='pyeurovoc', # Required
1212

13-
version='0.2.0', # Required
13+
version='1.0.0', # Required
1414

1515
description='Python API for multilingual legal document classification with EuroVoc descriptors using BERT models.', # Required
1616

@@ -27,7 +27,7 @@
2727
# 3 - Alpha
2828
# 4 - Beta
2929
# 5 - Production/Stable
30-
'Development Status :: 4 - Beta',
30+
'Development Status :: 5 - Production/Stable',
3131

3232
'Intended Audience :: Developers',
3333
'Intended Audience :: Education',

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def train():
106106

107107
if __name__ == '__main__':
108108
parser = argparse.ArgumentParser()
109-
parser.add_argument("--config", type=str, default="pyeurovoc/configs/models.yml", help="Tokenizer used for each language.")
109+
parser.add_argument("--config", type=str, default="configs/models.yml", help="Tokenizer used for each language.")
110110
parser.add_argument("--data_path", type=str, default="data/eurovoc", help="Path to the EuroVoc data.")
111111
parser.add_argument("--epochs", type=int, default=1, help="Number of epochs to train the model.")
112112
parser.add_argument("--batch_size", type=int, default=2, help="Batch size of the dataset.")

0 commit comments

Comments
 (0)