Skip to content

Commit

Permalink
Fix bug in inferece. Added text cleaning.
Browse files Browse the repository at this point in the history
  • Loading branch information
avramandrei committed Aug 15, 2021
1 parent 71ee21b commit 6bf43fc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions pyeurovoc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
from transformers import AutoTokenizer, BertTokenizer
from .util import download_file
import re


PYEUROVOC_PATH = os.path.join(os.path.expanduser("~"), ".cache", "pyeurovoc")
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(self, lang="en"):

# load the model
self.model = torch.load(os.path.join(PYEUROVOC_PATH, f"model_{lang}.pt"))
self.model.eval()

# load the multi-label encoder for eurovoc, y, download from repository if not found in .cache directory
if not os.path.exists(os.path.join(PYEUROVOC_PATH, f"mlb_encoder_{lang}.pickle")):
Expand All @@ -82,6 +84,10 @@ def __init__(self, lang="en"):
self.tokenizer = AutoTokenizer.from_pretrained(DICT_MODELS[lang])

def __call__(self, document_text, num_labels=6):
document_text = re.sub(r"<.*?>", "", document_text)
document_text = re.sub(r"\s+", " ", document_text)
document_text = document_text.strip()

input_ids = self.tokenizer.encode(
document_text,
return_attention_mask=True,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# $ pip install sampleproject
name='pyeurovoc', # Required

version='1.0.3', # Required
version='1.0.4', # Required

description='Python API for multilingual legal document classification with EuroVoc descriptors using BERT models.', # Required

Expand Down

0 comments on commit 6bf43fc

Please sign in to comment.