diff --git a/ariadne/contrib/transformers.py b/ariadne/contrib/transformers.py new file mode 100644 index 0000000..5280fae --- /dev/null +++ b/ariadne/contrib/transformers.py @@ -0,0 +1,42 @@ +# Licensed to the Technische Universität Darmstadt under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The Technische Universität Darmstadt +# licenses this file to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification +from ariadne.classifier import Classifier +from ariadne.contrib.inception_util import create_prediction +from cassis import Cas + +class TransformerNerClassifier(Classifier): + def __init__(self, model_name: str): + super().__init__() + # Load the Hugging Face model and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512) + self.model = AutoModelForTokenClassification.from_pretrained(model_name) + self.ner_pipeline = pipeline("ner", model=self.model, tokenizer=self.tokenizer, aggregation_strategy="first") + + + + def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str): + + document_text = cas.sofa_string + predictions = self.ner_pipeline(document_text) + for prediction in predictions: + start_char = prediction['start'] + end_char = prediction['end'] + label = prediction['entity_group'] + cas_prediction = create_prediction(cas, layer, feature, start_char, end_char, label) + cas.add(cas_prediction) + \ No newline at end of file diff --git a/setup.py b/setup.py index ad4645a..65f26da 100644 --- a/setup.py +++ b/setup.py @@ -51,14 +51,15 @@ "lightgbm~=4.2.0", "diskcache~=5.2.1", "simalign~=0.4", - "flair>=0.13.1" + "flair>=0.13.1", + "transformers[torch]~=4.41.1", # TransformerNerClassifier ] test_dependencies = [ "tox", "pytest", "codecov", - "pytest-cov", + "pytest-cov", ] dev_dependencies = [ diff --git a/tests/test_transformer_recommender.py b/tests/test_transformer_recommender.py new file mode 100644 index 0000000..54c5bff --- /dev/null +++ b/tests/test_transformer_recommender.py @@ -0,0 +1,34 @@ +# Licensed to the Technische Universität Darmstadt under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The Technische Universität Darmstadt +# licenses this file to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +pytest.importorskip("transformers") + +from ariadne.contrib.transformers import TransformerNerClassifier +from tests.util import load_obama, PREDICTED_TYPE, PREDICTED_FEATURE, PROJECT_ID, USER + + +def test_predict_ner(tmpdir_factory): + cas = load_obama() + sut = TransformerNerClassifier("lfcc/lusa_events") + + sut.predict(cas, PREDICTED_TYPE, PREDICTED_FEATURE, PROJECT_ID, "doc_42", USER) + predictions = list(cas.select(PREDICTED_TYPE)) + + assert len(predictions) + + for prediction in predictions: + assert getattr(prediction, PREDICTED_FEATURE) is not None \ No newline at end of file