diff --git a/ariadne/contrib/flair.py b/ariadne/contrib/flair.py
index d36070f..c8bd9cd 100644
--- a/ariadne/contrib/flair.py
+++ b/ariadne/contrib/flair.py
@@ -18,12 +18,28 @@
 from cassis import Cas
 
 from flair.nn import Classifier as Tagger
-from flair.data import Sentence
+from flair.data import Sentence, Token
 
 from ariadne.classifier import Classifier
 from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE, TOKEN_TYPE
 
 
+def fix_whitespaces(cas_tokens):
+    tokens = []
+    for cas_token, following_cas_token in zip(cas_tokens, cas_tokens[1:] + [None]):
+        if following_cas_token is not None:
+            dist = following_cas_token.begin - cas_token.end
+        else:
+            dist = 1
+        token = Token(
+            cas_token.get_covered_text(),
+            whitespace_after=dist,
+            start_position=cas_token.begin
+        )
+        tokens.append(token)
+    return tokens
+
+
 class FlairNERClassifier(Classifier):
     def __init__(self, model_name: str, model_directory: Path = None, split_sentences: bool = True):
         super().__init__(model_directory=model_directory)
@@ -33,43 +49,36 @@ def __init__(self, model_name: str, model_directory: Path = None, split_sentence
     def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):
         # Extract the sentences from the CAS
        if self._split_sentences:
+            sentences = []
             cas_sents = cas.select(SENTENCE_TYPE)
-            sents = [Sentence(sent.get_covered_text(), use_tokenizer=False) for sent in cas_sents]
-            offsets = [sent.begin for sent in cas_sents]
+            for cas_sent in cas_sents:
+                # transform cas tokens to flair tokens with correct spacing
+                cas_tokens = cas.select_covered(TOKEN_TYPE, cas_sent)
+                tokens = fix_whitespaces(cas_tokens)
+                sentences.append(Sentence(tokens))
 
             # Find the named entities
-            self._model.predict(sents)
+            self._model.predict(sentences)
 
-            for offset, sent in zip(offsets, sents):
+            for sentence in sentences:
                 # For every entity returned by spacy, create an annotation in the CAS
-                for named_entity in sent.to_dict()["entities"]:
-                    begin = named_entity["start_pos"] + offset
-                    end = named_entity["end_pos"] + offset
-                    label = named_entity["labels"][0]["value"]
+                for named_entity in sentence.get_spans():
+                    begin = named_entity.start_position
+                    end = named_entity.end_position
+                    label = named_entity.tag
                     prediction = create_prediction(cas, layer, feature, begin, end, label)
                     cas.add(prediction)
         else:
             cas_tokens = cas.select(TOKEN_TYPE)
-
-            # build sentence with correct whitespaces
-            # (when using sentences, this should not be a problem afaik)
-            text = ""
-            last_end = 0
-            for cas_token in cas_tokens:
-                if cas_token.begin == last_end:
-                    text += cas_token.get_covered_text()
-                else:
-                    text += " " + cas_token.get_covered_text()
-                last_end = cas_token.end
-
-            sent = Sentence(text, use_tokenizer=False)
+            text = fix_whitespaces(cas_tokens)
+            sent = Sentence(text)
 
             self._model.predict(sent)
 
-            for named_entity in sent.to_dict()["entities"]:
-                begin = named_entity["start_pos"]
-                end = named_entity["end_pos"]
-                label = named_entity["labels"][0]["value"]
+            for named_entity in sent.get_spans():
+                begin = named_entity.start_position
+                end = named_entity.end_position
+                label = named_entity.tag
                 prediction = create_prediction(cas, layer, feature, begin, end, label)
-                cas.add(prediction)
+                cas.add(prediction)
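
Usage sketch (not part of the patch): the snippet below illustrates what the new fix_whitespaces helper is expected to produce. FakeToken is a hypothetical stand-in for a cassis token annotation, and the token texts and offsets are invented for illustration; it only assumes flair's Token exposes the whitespace_after and start_position values it was constructed with, as the patch itself relies on.

from dataclasses import dataclass

from ariadne.contrib.flair import fix_whitespaces


@dataclass
class FakeToken:
    # Mimics the parts of a cassis token annotation that fix_whitespaces reads:
    # begin/end offsets and the covered text.
    begin: int
    end: int
    text: str

    def get_covered_text(self) -> str:
        return self.text


# "Hello, world" tokenized as ["Hello", ",", "world"] with document offsets.
cas_tokens = [FakeToken(0, 5, "Hello"), FakeToken(5, 6, ","), FakeToken(7, 12, "world")]
tokens = fix_whitespaces(cas_tokens)

# No gap between "Hello" and ",", one space before "world",
# and a default gap of 1 after the final token.
assert [t.whitespace_after for t in tokens] == [0, 1, 1]
assert [t.start_position for t in tokens] == [0, 5, 7]

Because each Token keeps its document-level start_position and the gap to the next token, the spans returned by get_spans() carry offsets that map straight back onto the CAS text, which is what removes the need for the old manual offset bookkeeping.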