Skip to content

Commit

Permalink
fix the transformer classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
lfcc1 committed Sep 20, 2024
1 parent 7a3b5c3 commit bd88246
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions ariadne/contrib/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from ariadne.classifier import Classifier
from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE
from ariadne.contrib.inception_util import create_prediction
from cassis import Cas

class TransformerNerClassifier(Classifier):
Expand All @@ -31,13 +31,12 @@ def __init__(self, model_name: str):

def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):

#document_text = cas.sofa_string
for sentence in cas.select(SENTENCE_TYPE):
predictions = self.ner_pipeline(sentence)
for prediction in predictions:
start_char = prediction['start']
end_char = prediction['end']
label = prediction['entity_group']
cas_prediction = create_prediction(cas, layer, feature, start_char, end_char, label)
cas.add(cas_prediction)
document_text = cas.sofa_string
predictions = self.ner_pipeline(document_text)
for prediction in predictions:
start_char = prediction['start']
end_char = prediction['end']
label = prediction['entity_group']
cas_prediction = create_prediction(cas, layer, feature, start_char, end_char, label)
cas.add(cas_prediction)

0 comments on commit bd88246

Please sign in to comment.