Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Cleanup, Match acronyms, expose pipeline kwargs #21

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmultilingualpunctuation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .punctuationmodel import PunctuationModel
from .punctuationmodel import PunctuationModel
83 changes: 47 additions & 36 deletions deepmultilingualpunctuation/punctuationmodel.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,90 @@
from concurrent.futures import process
from transformers import pipeline
import re

import torch
from transformers import pipeline


class PunctuationModel:
def __init__(
self,
model="oliverguhr/fullstop-punctuation-multilang-large",
**kwargs,
) -> None:
if "aggregation_strategy" not in kwargs:
kwargs["aggregation_strategy"] = "none"
self.pipe = pipeline("ner", model, **kwargs)

class PunctuationModel():
def __init__(self, model = "oliverguhr/fullstop-punctuation-multilang-large") -> None:
if torch.cuda.is_available():
self.pipe = pipeline("ner",model, aggregation_strategy="none", device=0)
else:
self.pipe = pipeline("ner",model, aggregation_strategy="none")

def preprocess(self,text):
#remove markers except for markers in numbers
text = re.sub(r"(?<!\d)[.,;:!?](?!\d)","",text)
#todo: match acronyms https://stackoverflow.com/questions/35076016/regex-to-match-acronyms
def preprocess(self, text):
# remove punctuation except dots
text = re.sub(r"[,;:!?]", "", text)
# remove dots that are not in acronyms or decimal points
text = re.sub(r"(?<!\b[a-zA-Z])(?<!\d)\.", "", text)
text = text.split()
return text

def restore_punctuation(self, text, chunk_size=230):
def restore_punctuation(self, text, chunk_size=230):
result = self.predict(self.preprocess(text), chunk_size)
return self.prediction_to_text(result)
def overlap_chunks(self,lst, n, stride=0):

def overlap_chunks(self, lst, n, stride=0):
"""Yield successive n-sized chunks from lst with stride length of overlap."""
for i in range(0, len(lst), n-stride):
yield lst[i:i + n]
for i in range(0, len(lst), n - stride):
yield lst[i : i + n]

def predict(self, words, chunk_size=230):
overlap = 5
if len(words) <= chunk_size:
overlap = 0

batches = list(self.overlap_chunks(words,chunk_size,overlap))
batches = list(self.overlap_chunks(words, chunk_size, overlap))

# if the last batch is smaller than the overlap,
# if the last batch is smaller than the overlap,
# we can just remove it
if len(batches[-1]) <= overlap:
batches.pop()

tagged_words = []
tagged_words = []
for batch in batches:
# use last batch completely
if batch == batches[-1]:
if batch == batches[-1]:
overlap = 0
text = " ".join(batch)
result = self.pipe(text)
assert len(text) == result[-1]["end"], "chunk size too large, text got clipped"

result = self.pipe(text)
assert (
len(text) == result[-1]["end"]
), "chunk size too large, text got clipped"

char_index = 0
result_index = 0
for word in batch[:len(batch)-overlap]:
for word in batch[: len(batch) - overlap]:
char_index += len(word) + 1
# if any subtoken of an word is labled as sentence end
# we label the whole word as sentence end
# we label the whole word as sentence end
label = "0"
while result_index < len(result) and char_index > result[result_index]["end"] :
label = result[result_index]['entity']
score = result[result_index]['score']
result_index += 1
tagged_words.append([word,label, score])

while (
result_index < len(result)
and char_index > result[result_index]["end"]
):
label = result[result_index]["entity"]
score = result[result_index]["score"]
result_index += 1
tagged_words.append([word, label, score])

assert len(tagged_words) == len(words)
return tagged_words

def prediction_to_text(self,prediction):
def prediction_to_text(self, prediction):
result = ""
for word, label, _ in prediction:
result += word
if label == "0":
result += " "
if label in ".,?-:":
result += label+" "
result += label + " "
return result.strip()

if __name__ == "__main__":

if __name__ == "__main__":
model = PunctuationModel()

text = "das , ist fies "
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/oliverguhr/deepmultilingualpunctuation",
packages=setuptools.find_packages(),
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
install_requires=[
"transformers",
"torch>=1.8.1",
"transformers",
"torch>=1.8.1",
],
python_requires='>=3.6',
python_requires=">=3.6",
)