diff --git a/pyeurovoc/__init__.py b/pyeurovoc/__init__.py index 401c367..427d943 100644 --- a/pyeurovoc/__init__.py +++ b/pyeurovoc/__init__.py @@ -3,6 +3,7 @@ import pickle from transformers import AutoTokenizer, BertTokenizer from .util import download_file +import re PYEUROVOC_PATH = os.path.join(os.path.expanduser("~"), ".cache", "pyeurovoc") @@ -58,6 +59,7 @@ def __init__(self, lang="en"): # load the model self.model = torch.load(os.path.join(PYEUROVOC_PATH, f"model_{lang}.pt")) + self.model.eval() # load the multi-label encoder for eurovoc, y, download from repository if not found in .cache directory if not os.path.exists(os.path.join(PYEUROVOC_PATH, f"mlb_encoder_{lang}.pickle")): @@ -82,6 +84,10 @@ def __init__(self, lang="en"): self.tokenizer = AutoTokenizer.from_pretrained(DICT_MODELS[lang]) def __call__(self, document_text, num_labels=6): + document_text = re.sub(r"<.*?>", "", document_text) + document_text = re.sub(r"\s+", " ", document_text) + document_text = document_text.strip() + input_ids = self.tokenizer.encode( document_text, return_attention_mask=True, diff --git a/setup.py b/setup.py index bf545ea..c2502b8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ # $ pip install sampleproject name='pyeurovoc', # Required - version='1.0.3', # Required + version='1.0.4', # Required description='Python API for multilingual legal document classification with EuroVoc descriptors using BERT models.', # Required