diff --git a/deepmultilingualpunctuation/__init__.py b/deepmultilingualpunctuation/__init__.py index 5919050..a285e41 100644 --- a/deepmultilingualpunctuation/__init__.py +++ b/deepmultilingualpunctuation/__init__.py @@ -1 +1 @@ -from .punctuationmodel import PunctuationModel \ No newline at end of file +from .punctuationmodel import PunctuationModel diff --git a/deepmultilingualpunctuation/punctuationmodel.py b/deepmultilingualpunctuation/punctuationmodel.py index 14d11fa..e2b811b 100644 --- a/deepmultilingualpunctuation/punctuationmodel.py +++ b/deepmultilingualpunctuation/punctuationmodel.py @@ -1,79 +1,90 @@ -from concurrent.futures import process -from transformers import pipeline import re + import torch +from transformers import pipeline + + +class PunctuationModel: + def __init__( + self, + model="oliverguhr/fullstop-punctuation-multilang-large", + **kwargs, + ) -> None: + if "aggregation_strategy" not in kwargs: + kwargs["aggregation_strategy"] = "none" + self.pipe = pipeline("ner", model, **kwargs) -class PunctuationModel(): - def __init__(self, model = "oliverguhr/fullstop-punctuation-multilang-large") -> None: - if torch.cuda.is_available(): - self.pipe = pipeline("ner",model, aggregation_strategy="none", device=0) - else: - self.pipe = pipeline("ner",model, aggregation_strategy="none") - - def preprocess(self,text): - #remove markers except for markers in numbers - text = re.sub(r"(? result[result_index]["end"] : - label = result[result_index]['entity'] - score = result[result_index]['score'] - result_index += 1 - tagged_words.append([word,label, score]) - + while ( + result_index < len(result) + and char_index > result[result_index]["end"] + ): + label = result[result_index]["entity"] + score = result[result_index]["score"] + result_index += 1 + tagged_words.append([word, label, score]) + assert len(tagged_words) == len(words) return tagged_words - def prediction_to_text(self,prediction): + def prediction_to_text(self, prediction): result = "" for word, label, _ in prediction: result += word if label == "0": result += " " if label in ".,?-:": - result += label+" " + result += label + " " return result.strip() -if __name__ == "__main__": + +if __name__ == "__main__": model = PunctuationModel() text = "das , ist fies " diff --git a/setup.py b/setup.py index f5e01a0..fc0be03 100644 --- a/setup.py +++ b/setup.py @@ -12,15 +12,15 @@ long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/oliverguhr/deepmultilingualpunctuation", - packages=setuptools.find_packages(), + packages=setuptools.find_packages(), classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], install_requires=[ - "transformers", - "torch>=1.8.1", + "transformers", + "torch>=1.8.1", ], - python_requires='>=3.6', + python_requires=">=3.6", )