Skip to content

Commit

Permalink
#55 - Add FlairNLP Sequence Tagging
Browse files Browse the repository at this point in the history
* Implemented script to use the Flair SequenceTagger.
* Added option to not use sentences to the script.
* Fixed the whitespace problem!
* Added licenser header. Removed debug statement.
* Added class to contrib models table.
* Upgraded dkpro-cassis to 0.9.1 to ensure dependency compatibility with flair 0.13.1 (both can use more-itertools 0.8.14 now)
  • Loading branch information
raykyn authored Mar 5, 2024
1 parent d4fea9a commit 23d7d8f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ following table provides an overview about them:
| SpacyPosClassifier | Part-of-speech prediction with [spaCy](https://spacy.io/) | no |
| AdapterSequenceTagger | Sequence tagger using [Adapters](https://adapterhub.ml/) | no |
| AdapterSentenceClassifier | Sentence classifier using [Adapters](https://adapterhub.ml/) | no |
| FlairNERClassifier | Sequence tagger using [Flair](https://flairnlp.github.io/) | no |

For using trainable recommenders it is important to check the checkbox *Trainable* when adding
the external recommender to your project. To be able to get predictions of a added trainable
Expand Down
75 changes: 75 additions & 0 deletions ariadne/contrib/flair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Technische Universität Darmstadt under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The Technische Universität Darmstadt
# licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

from cassis import Cas

from flair.nn import Classifier as Tagger
from flair.data import Sentence

from ariadne.classifier import Classifier
from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE, TOKEN_TYPE


class FlairNERClassifier(Classifier):
def __init__(self, model_name: str, model_directory: Path = None, split_sentences: bool = True):
super().__init__(model_directory=model_directory)
self._model = Tagger.load(model_name)
self._split_sentences = split_sentences

def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):
# Extract the sentences from the CAS
if self._split_sentences:
cas_sents = cas.select(SENTENCE_TYPE)
sents = [Sentence(sent.get_covered_text(), use_tokenizer=False) for sent in cas_sents]
offsets = [sent.begin for sent in cas_sents]

# Find the named entities
self._model.predict(sents)

for offset, sent in zip(offsets, sents):
# For every entity returned by spacy, create an annotation in the CAS
for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"] + offset
end = named_entity["end_pos"] + offset
label = named_entity["labels"][0]["value"]
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)

else:
cas_tokens = cas.select(TOKEN_TYPE)

# build sentence with correct whitespaces
# (when using sentences, this should not be a problem afaik)
text = ""
last_end = 0
for cas_token in cas_tokens:
if cas_token.begin == last_end:
text += cas_token.get_covered_text()
else:
text += " " + cas_token.get_covered_text()
last_end = cas_token.end

sent = Sentence(text, use_tokenizer=False)

self._model.predict(sent)

for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"]
end = named_entity["end_pos"]
label = named_entity["labels"][0]["value"]
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
HOMEPAGE = "https://inception-project.github.io/"
EMAIL = "inception-users@googlegroups.com"
AUTHOR = "The INCEpTION team"
REQUIRES_PYTHON = ">=3.6.0"
REQUIRES_PYTHON = ">=3.8.0"

install_requires = [
"flask",
"filelock",
"dkpro-cassis>=0.7.6",
"dkpro-cassis>=0.9.1",
"joblib",
"gunicorn",
"deprecation",
Expand All @@ -50,7 +50,8 @@
"sentence-transformers~=2.2.2",
"lightgbm~=4.2.0",
"diskcache~=5.2.1",
"simalign~=0.4"
"simalign~=0.4",
"flair>=0.13.1"
]

test_dependencies = [
Expand Down

0 comments on commit 23d7d8f

Please sign in to comment.