Skip to content

Commit

Permalink
Added Working EuroVocBERT class
Browse files Browse the repository at this point in the history
  • Loading branch information
avramandrei committed Jul 31, 2021
1 parent 39a36c7 commit 9b29d11
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,5 @@ dmypy.json

# Pyre type checker
.pyre/
data/tmp
data/tmp
data/test
File renamed without changes.
File renamed without changes.
94 changes: 82 additions & 12 deletions pyeurovoc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,112 @@
import os
import torch
import yaml
import pickle
import json
from transformers import AutoTokenizer


PYEUROVOC_PATH = os.path.join(os.path.expanduser("~"), ".cache", "pyeurovoc")
REPOSITORY_URL = ""

DICT_MODELS = {
"bg": "TurkuNLP/wikibert-base-bg-cased",
"cs": "TurkuNLP/wikibert-base-cs-cased",
"da": "Maltehb/danish-bert-botxo",
"de": "bert-base-german-cased",
"el": "nlpaueb/bert-base-greek-uncased-v1",
"en": "nlpaueb/legal-bert-base-uncased",
"es": "dccuchile/bert-base-spanish-wwm-cased",
"et": "tartuNLP/EstBERT",
"fi": "TurkuNLP/bert-base-finnish-cased-v1",
"fr": "camembert-base",
"hu": "SZTAKI-HLT/hubert-base-cc",
"it": "dbmdz/bert-base-italian-cased",
"lt": "TurkuNLP/wikibert-base-lt-cased",
"lv": "TurkuNLP/wikibert-base-lv-cased",
"mt": "bert-base-multilingual-cased",
"nl": "wietsedv/bert-base-dutch-cased",
"pl": "dkleczek/bert-base-polish-cased-v1",
"pt": "neuralmind/bert-base-portuguese-cased",
"ro": "dumitrescustefan/bert-base-romanian-cased-v1",
"sk": "TurkuNLP/wikibert-base-sk-cased",
"sl": "TurkuNLP/wikibert-base-sl-cased",
"sv": "KB/bert-base-swedish-cased"
}


class EuroVocBERT:
def __init__(self, lang="en"):
if lang not in DICT_MODELS.keys():
raise ValueError("Language parameter must be one of the following languages: {}".format(DICT_MODELS.keys()))

if not os.path.exists(PYEUROVOC_PATH):
os.makedirs(PYEUROVOC_PATH)

# model must be downloaded from the repository
if not os.path.exists(os.path.join(PYEUROVOC_PATH, f"model_{lang}.pt")):
print(f"Model 'model_{lang}.pt' not found in the .cache directory at '{PYEUROVOC_PATH}'")
print(f"Downloading 'model_{lang}.pt from {REPOSITORY_URL}...")
print(f"Model 'model_{lang}.pt' not found in the .cache directory at '{PYEUROVOC_PATH}'. "
f"Downloading from '{REPOSITORY_URL}'...")
# model already exists, loading from .cache directory
else:
print(f"Model 'model_{lang}.pt' found in the .cache directory at '{PYEUROVOC_PATH}'")
print("Loading model...")
print(f"Model 'model_{lang}.pt' found in the .cache directory at '{PYEUROVOC_PATH}'. "
f"Loading...")

# load the model
self.model = torch.load(os.path.join(PYEUROVOC_PATH, f"model_{lang}.pt"))

# load the model dictionary (e.g. language -> bert_model)
with open(os.path.join("configs", "models.yml"), "r") as yml_file:
dict_models = yaml.load(yml_file)
# load the multi-label encoder for the EuroVoc labels (y); download from repository if not found in .cache directory
if not os.path.exists(os.path.join(PYEUROVOC_PATH, f"mlb_encoder_{lang}.pickle")):
print(f"Label encoder 'mlb_encoder_{lang}.pickle' not found in the .cache directory at '{PYEUROVOC_PATH}'."
f" Downloading from '{REPOSITORY_URL}'...")
else:
print(f"Label encoder 'mlb_encoder_{lang}.pickle' found in the .cache directory at '{PYEUROVOC_PATH}'."
f" Loading...")

with open(os.path.join(PYEUROVOC_PATH, f"mlb_encoder_{lang}.pickle"), "rb") as pck_file:
self.mlb_encoder = pickle.load(pck_file)

# load MT descriptors dictionary, download from repository if not found in .cache directory
if not os.path.exists(os.path.join(PYEUROVOC_PATH, "mt_labels.json")):
print(f"MT descriptors dictionary 'mt_labels.json' not found in the .cache directory at '{PYEUROVOC_PATH}'."
f" Downloading from '{REPOSITORY_URL}'...")
else:
print(f"MT descriptors dictionary 'mt_labels.json' found in the .cache directory at '{PYEUROVOC_PATH}'. "
f"Loading...")

with open(os.path.join(PYEUROVOC_PATH, "mt_labels.json"), "r") as json_file:
self.dict_mt_labels = json.load(json_file)

# load the tokenizer according to the model dictionary
self.tokenizer = AutoTokenizer.from_pretrained(dict_models[lang])
self.tokenizer = AutoTokenizer.from_pretrained(DICT_MODELS[lang])

def __call__(self, document_text, num_id_labels=6, num_mt_labels=5, num_do_labels=4):
encoding_ids = self.tokenizer.encode(
def __call__(self, document_text, num_id_labels=6):
input_ids = self.tokenizer.encode(
document_text,
return_attention_mask=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
).reshape(1, -1)

with torch.no_grad():
logits = self.model(
input_ids,
torch.ones_like(input_ids)
)[0]

probs = torch.sigmoid(logits).detach().cpu()

probs_sorted, idx = torch.sort(probs, descending=True)

outputs = torch.zeros_like(logits)
outputs[idx[:num_id_labels]] = 1

id_labels = self.mlb_encoder.inverse_transform(outputs.reshape(1, -1))[0]
id_probs = probs[idx[:num_id_labels]]

result = {}

for id_label, id_prob in zip(id_labels, id_probs):
result[str(id_label)] = float(id_prob)

return result
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# $ pip install sampleproject
name='pyeurovoc', # Required

version='0.0.0', # Required
version='0.1.0', # Required

description='Python API for multilingual legal document classification with EuroVoc descriptors using BERT models.', # Required

Expand All @@ -27,7 +27,7 @@
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Production/Stable',
'Development Status :: 3 - Alpha',

'Intended Audience :: Developers',
'Intended Audience :: Education',
Expand All @@ -51,9 +51,9 @@
keywords='eurovoc bert legal document classification', # Optional

# packages=find_packages(exclude=['jupyter']), # Required
packages=find_packages("pyeurovoc"), # Required
packages=find_packages("."), # Required

install_requires=['transformers', 'sklearn', 'torch', 'scikit-multilearn', 'pyyaml', 'waitress', 'flask'], # Optional

zip_safe=False
zip_safe=False,
)
3 changes: 0 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def f1k_mt_scores(self, y_true, probs, eps=1e-10):
true_labels_domain = [label[label != 0] for label in true_labels_domain]
pred_labels_domain = [label[label != 0] for label in pred_labels_domain]

# print(true_labels_mt)
# print(pred_labels_mt)

pk_mt_scores = [np.intersect1d(true, pred).shape[0] / pred.shape[0] + eps if pred.shape[0] != 0 else eps for true, pred in
zip(true_labels_mt, pred_labels_mt)]
rk_mt_scores = [np.intersect1d(true, pred).shape[0] / true.shape[0] + eps if pred.shape[0] != 0 else eps for true, pred in
Expand Down

0 comments on commit 9b29d11

Please sign in to comment.