Merge pull request #65 from NeotomaDB/22-fine-tune-allenai-specter2-model-for-ner

HuggingFace Model Training
brabbit61 committed Jun 27, 2023
2 parents 9679e20 + 8a7d6d3 commit 45f5e8a
Showing 13 changed files with 309 additions and 377 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
@@ -1,3 +1,5 @@

# python version 3.10
dash~=2.10
dash_bootstrap_components~=1.4
dash-testing-stub==0.0.2
@@ -34,6 +36,6 @@ seqeval==1.2.2
spacy~=3.5
torch~=1.12
tqdm~=4.65
transformers~=4.24
transformers~=4.28
# to use the spacy model for baseline NER
https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.5.0/en_core_web_lg-3.5.0-py3-none-any.whl
166 changes: 0 additions & 166 deletions sample_pipeline_output.json

This file was deleted.

18 changes: 10 additions & 8 deletions src/entity_extraction/hf_entity_extraction.py
@@ -6,7 +6,7 @@

import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from torch import cuda
import torch

from src.logs import get_logger

@@ -32,22 +32,24 @@ def load_ner_model_pipeline(model_path: str):
The loaded tokenizer.
"""

device = "cuda" if cuda.is_available() else "cpu"
if device == "cuda":
logger.info("Using GPU for predictions, batch size of 8")
batch_size = 8
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
if "cuda" in device_str:
logger.info("Using GPU for predictions, batch size of 32")
batch_size = 32
else:
logger.info("Using CPU for predictions, batch size of 1")
batch_size = 1

# load the model
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
tokenizer = AutoTokenizer.from_pretrained(
model_path, model_max_length=512, padding=True, truncation=True
)
ner_pipe = pipeline(
"ner",
model=model,
tokenizer=tokenizer,
device=device,
device=torch.device(device_str),
batch_size=batch_size,
aggregation_strategy="simple",
)
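For context, a minimal standalone sketch of how a Hugging Face token-classification pipeline with aggregation_strategy="simple" behaves. The checkpoint name below is a placeholder, not the project's fine-tuned SPECTER2 model:

from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
import torch

# Placeholder checkpoint; substitute the fine-tuned model path used above.
model_path = "dslim/bert-base-NER"
model = AutoModelForTokenClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
ner_pipe = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    aggregation_strategy="simple",
)
# With aggregation_strategy="simple", sub-word pieces are merged back into whole
# entities, e.g. {"entity_group": "LOC", "word": "Lake Baikal", "start": ..., "end": ...}.
print(ner_pipe("Sediment cores were collected near Lake Baikal in 1998."))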
@@ -136,7 +138,7 @@ def get_predicted_labels(df, ner_pipe):
df["predicted_labels"] = predicted_labels

df[["split_text", "predicted_tokens"]] = df.apply(
lambda row: get_hf_token_labels(row.predicted_labels, " ".join(row.text)),
lambda row: get_hf_token_labels(row.predicted_labels, row.text),
axis="columns",
result_type="expand",
)
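The apply call above relies on pandas expanding a returned pair of values into two columns. A small illustrative sketch (the column names and lambda here are hypothetical, not project code):

import pandas as pd

# A function returning two values per row fills two columns when
# result_type="expand" is used, mirroring the split_text / predicted_tokens
# assignment above.
df = pd.DataFrame({"text": ["pollen sample from 1998"]})
df[["tokens", "n_chars"]] = df.apply(
    lambda row: (row.text.split(), len(row.text)),
    axis="columns",
    result_type="expand",
)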
36 changes: 27 additions & 9 deletions src/entity_extraction/ner_eval.py
@@ -5,6 +5,7 @@
# ensure src is in the path
import sys
import os

SRC_DIR = os.path.join(os.path.dirname(__file__), "..", "..")
if SRC_DIR not in sys.path:
sys.path.append(SRC_DIR)
@@ -18,14 +19,23 @@


class Evaluator:
def __init__(self, true, pred, tags):
""" """

if len(true) != len(pred):
def __init__(self, true_labels, pred_labels, tags):
"""Initializer for the Evaluator class
Parameters
----------
true_labels : list
A list of true label strings of the form ["B-LOC", "I-LOC", "O"]
pred_labels : list
A list of predicted label strings of the form ["B-LOC", "I-LOC", "O"]
tags : list
A list of all possible tags, e.g. ["LOC", "PER", "ORG"]
"""
if len(true_labels) != len(pred_labels):
raise ValueError("Number of predicted documents does not equal true")

self.true = true
self.pred = pred
self.true = true_labels
self.pred = pred_labels
self.tags = tags

# Setup dict into which metrics will be stored.
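A hedged usage sketch of the class as documented above; the per-document nesting of the label lists is an assumption based on the document-count check in __init__:

# Hypothetical example inputs: one list of BIO tags per document.
true_labels = [["B-LOC", "I-LOC", "O", "O"], ["O", "B-PER", "O"]]
pred_labels = [["B-LOC", "O", "O", "O"], ["O", "B-PER", "O"]]

evaluator = Evaluator(true_labels, pred_labels, tags=["LOC", "PER"])
results, results_by_tag = evaluator.evaluate()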
@@ -112,13 +122,21 @@ def evaluate(self):
return self.results, self.evaluation_agg_entities_type


def collect_named_entities(tokens):
def collect_named_entities(tokens: list) -> list[Entity]:
"""
Creates a list of Entity named-tuples, storing the entity type and the start and end
offsets of the entity.
:param tokens: a list of tags
:return: a list of Entity named-tuples
Parameters
----------
tokens : list
A list of tokens, where each token is of the form B-LOC
Returns
-------
named_entities : list
A list of Entity named-tuples, storing the entity type and the start and end
offsets of the entity in terms of B- I- tag counts
"""

named_entities = []
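Based on the docstring above, a hedged illustration of the expected behaviour; the Entity field layout (type plus start/end offsets) follows the conventional NER-evaluation named-tuple scheme and is an assumption:

# Hypothetical call: token-level BIO tags in, entity spans out.
tags = ["O", "B-LOC", "I-LOC", "O", "B-PER"]
entities = collect_named_entities(tags)
# Expected result, under the assumed Entity(e_type, start_offset, end_offset) layout:
# [Entity("LOC", 1, 2), Entity("PER", 4, 4)]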