diff --git a/ariadne/contrib/transformers.py b/ariadne/contrib/transformers.py index 158fded..5280fae 100644 --- a/ariadne/contrib/transformers.py +++ b/ariadne/contrib/transformers.py @@ -16,7 +16,7 @@ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification from ariadne.classifier import Classifier -from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE +from ariadne.contrib.inception_util import create_prediction from cassis import Cas class TransformerNerClassifier(Classifier): @@ -31,13 +31,12 @@ def __init__(self, model_name: str): def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str): - #document_text = cas.sofa_string - for sentence in cas.select(SENTENCE_TYPE): - predictions = self.ner_pipeline(sentence) - for prediction in predictions: - start_char = prediction['start'] - end_char = prediction['end'] - label = prediction['entity_group'] - cas_prediction = create_prediction(cas, layer, feature, start_char, end_char, label) - cas.add(cas_prediction) + document_text = cas.sofa_string + predictions = self.ner_pipeline(document_text) + for prediction in predictions: + start_char = prediction['start'] + end_char = prediction['end'] + label = prediction['entity_group'] + cas_prediction = create_prediction(cas, layer, feature, start_char, end_char, label) + cas.add(cas_prediction) \ No newline at end of file