Skip to content

Commit

Permalink
#62 - Bugfix for FlairTagger when using certain input formats
Browse files Browse the repository at this point in the history
* The whitespace fix was pretty hackish and was not working as intended with certain input formats. It also limited the flair tagger to only tagging around sentence boundaries. This commit replaces the hack with a conversion from the CAS tokens to flair token objects, which is much cleaner. Multiple whitespaces between two tokens can now also be processed properly.
  • Loading branch information
raykyn committed Apr 9, 2024
1 parent 9cca892 commit 02a7f7e
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions ariadne/contrib/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,28 @@
from cassis import Cas

from flair.nn import Classifier as Tagger
from flair.data import Sentence
from flair.data import Sentence, Token

from ariadne.classifier import Classifier
from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE, TOKEN_TYPE


def fix_whitespaces(cas_tokens):
    """Convert CAS token annotations into flair ``Token`` objects.

    The number of whitespace characters following each token is derived from
    the gap between its ``end`` offset and the next token's ``begin`` offset,
    so runs of multiple whitespaces between tokens are preserved. The final
    token is assumed to be followed by a single whitespace.

    :param cas_tokens: iterable of CAS token annotations exposing
        ``begin``, ``end`` and ``get_covered_text()``.
    :return: list of ``flair.data.Token`` objects with correct spacing.
    """
    # Materialize first: cassis select/select_covered results may be
    # generators in some versions, and we need random access below.
    cas_tokens = list(cas_tokens)
    tokens = []
    for i, cas_token in enumerate(cas_tokens):
        if i + 1 < len(cas_tokens):
            # NOTE(review): assumes only whitespace separates adjacent
            # tokens, so the offset gap equals the whitespace count.
            dist = cas_tokens[i + 1].begin - cas_token.end
        else:
            dist = 1
        tokens.append(
            Token(
                cas_token.get_covered_text(),
                whitespace_after=dist,
                start_position=cas_token.begin,
            )
        )
    return tokens


class FlairNERClassifier(Classifier):
def __init__(self, model_name: str, model_directory: Path = None, split_sentences: bool = True):
super().__init__(model_directory=model_directory)
Expand All @@ -33,43 +49,36 @@ def __init__(self, model_name: str, model_directory: Path = None, split_sentence
def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):
# Extract the sentences from the CAS
if self._split_sentences:
sentences = []
cas_sents = cas.select(SENTENCE_TYPE)
sents = [Sentence(sent.get_covered_text(), use_tokenizer=False) for sent in cas_sents]
offsets = [sent.begin for sent in cas_sents]
for cas_sent in cas_sents:
# transform cas tokens to flair tokens with correct spacing
cas_tokens = cas.select_covered(TOKEN_TYPE, cas_sent)
tokens = fix_whitespaces(cas_tokens)
sentences.append(Sentence(tokens))

# Find the named entities
self._model.predict(sents)
self._model.predict(sentences)

for offset, sent in zip(offsets, sents):
for sentence in sentences:
# For every entity returned by spacy, create an annotation in the CAS
for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"] + offset
end = named_entity["end_pos"] + offset
label = named_entity["labels"][0]["value"]
for named_entity in sentence.get_spans():
begin = named_entity.start_position
end = named_entity.end_position
label = named_entity.tag
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)

else:
cas_tokens = cas.select(TOKEN_TYPE)

# build sentence with correct whitespaces
# (when using sentences, this should not be a problem afaik)
text = ""
last_end = 0
for cas_token in cas_tokens:
if cas_token.begin == last_end:
text += cas_token.get_covered_text()
else:
text += " " + cas_token.get_covered_text()
last_end = cas_token.end

sent = Sentence(text, use_tokenizer=False)
text = fix_whitespaces(cas_tokens)
sent = Sentence(text)

self._model.predict(sent)

for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"]
end = named_entity["end_pos"]
label = named_entity["labels"][0]["value"]
for named_entity in sent.get_spans():
begin = named_entity.start_position
end = named_entity.end_position
label = named_entity.tag
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)
cas.add(prediction)

0 comments on commit 02a7f7e

Please sign in to comment.