Retrain with parquet files #87

Closed
src/entity_extraction/training/hf_token_classification/huggingface_preprocess.py
@@ -74,18 +74,22 @@
labelled_chunks = []

for file in os.listdir(data_folder):
# if file doesn't end with txt skip it
if not file.endswith(".txt"):
continue

with open(os.path.join(data_folder, file), "r") as f:
task = json.load(f)

try:
raw_text = task["task"]["data"]["text"]
annotation_result = task["result"]
gdd_id = task["task"]["data"]["gdd_id"]

if file.endswith(".txt"):
with open(os.path.join(data_folder, file), "r") as f:
task = json.load(f)
annotation_result = task["result"]
gdd_id = task["task"]["data"]["gdd_id"]
raw_text = task["task"]["data"]["text"]
elif file.endswith(".json"):
with open(os.path.join(data_folder, file), "r") as f:
task = json.load(f)
annotation_result = task["result"]
gdd_id = task["data"]["gdd_id"]
raw_text = task["data"]["text"]

else:
continue


labelled_entities = [
annotation["value"] for annotation in annotation_result
]
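As a reading aid for the hunk above, here is a minimal sketch of the two task-file layouts the loader now accepts. The key paths are taken from the code; every value, including the `SITE` label and the IDs, is invented for illustration.

```python
# Illustrative only: minimal task structures matching the keys read above.
# All values are made up; real .txt files are Label Studio exports and real
# .json files are produced by the parquet extraction step.

# Label Studio export saved with a .txt extension:
txt_task = {
    "task": {
        "data": {
            "gdd_id": "5c3a1b2e8a90e24a55555555",  # hypothetical xDD article ID
            "text": "Pollen was sampled at Lake Ontario in 1987.",
        }
    },
    "result": [
        {"value": {"text": "Lake Ontario", "start": 22, "end": 34, "labels": ["SITE"]}}
    ],
}

# File written from a reviewed parquet row and saved with a .json extension
# (gdd_id and text sit directly under "data", not under "task"):
json_task = {
    "data": {
        "gdd_id": "5c3a1b2e8a90e24a55555555",
        "text": "Pollen was sampled at Lake Ontario in 1987.",
        # plus doi, timestamp, chunk_hash, article_hash in the generated files
    },
    "result": [
        {"value": {"text": "Lake Ontario", "start": 22, "end": 34, "labels": ["SITE"]}}
    ],
}
```

Either layout yields the same `annotation_result`, `gdd_id`, and `raw_text` triple once loaded.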
3 changes: 1 addition & 2 deletions src/entity_extraction/training/spacy_ner/README.md
@@ -12,11 +12,10 @@ This folder contains the training and evaluation scripts for the SpaCy Transform
## Training Workflow

A bash script is used to initialize a training job. Model training is fully customizable, and users are encouraged to update the parameters in the `run_spacy_training.sh` and `spacy_transformer_train.cfg` files prior to training. The training workflow is as follows:
1. Create a new data directory and dump all the TXT files (contains annotations in the JSONLines format) from Label Studio.
1. Create a new data directory and dump into it all the JSON annotation files exported from Label Studio, along with any reviewed parquet files.
2. Most parameters can be used with their default values; open the `run_spacy_training.sh` bash script and update the following fields with absolute or relative paths from the root of the repository:
- `DATA_PATH`: path to directory with Label Studio labelled data
- `DATA_OUTPUT_PATH`: path to directory to store the split dataset (train/val/test) as well as other data artifacts required for training.
- `MODEL_PATH`: If retraining, specify path to model artifacts. If training a model from scratch, pass empty string `""`
- `MODEL_OUTPUT_PATH`: path to store new model artifacts
- `VERSION`: Version can be updated to keep track of different training runs.
- `--gpu-id`: While executing the `spacy train` command, GPU can be used, if available, by setting this flag to **0**.
46 changes: 14 additions & 32 deletions src/entity_extraction/training/spacy_ner/run_spacy_training.sh
@@ -9,7 +9,6 @@ echo "Current working directory: $(pwd)"

DATA_PATH="/path/to/sample input folder"
DATA_OUTPUT_PATH="/path/to/sample output folder"
MODEL_PATH="/path/to/model artifacts"
MODEL_OUTPUT_PATH="/path/to/new model artifacts"
VERSION="v1"
TRAIN_SPLIT=0.7
@@ -28,34 +27,17 @@ python3 src/preprocessing/labelling_data_split.py \

python3 src/preprocessing/spacy_preprocess.py --data_path $DATA_OUTPUT_PATH

if [ -z "$MODEL_PATH" ]; then
# If the model path is null, then start training from scratch

# Fill configuration with required fields
python -m spacy init fill-config \
src/entity_extraction/training/spacy_ner/spacy_transformer_train.cfg \
src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg

# Execute the training job by pointing to the new config file
python -m spacy train \
src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg \
--paths.train $DATA_OUTPUT_PATH/train.spacy \
--paths.dev $DATA_OUTPUT_PATH/val.spacy \
--output $MODEL_OUTPUT_PATH \
--gpu-id -1

else
# Else create a new config file to resume training
python src/entity_extraction/training/spacy_ner/create_config.py \
--model_path $MODEL_PATH \
--output_path src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg

python -m spacy train \
src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg \
--paths.train $DATA_OUTPUT_PATH/train.spacy \
--paths.dev $DATA_OUTPUT_PATH/val.spacy \
--components.ner.source $MODEL_PATH \
--components.transformer.source $MODEL_PATH \
--output $MODEL_OUTPUT_PATH \
--gpu-id -1
fi
# Start training from scratch

# Fill configuration with required fields
python -m spacy init fill-config \
src/entity_extraction/training/spacy_ner/spacy_transformer_train.cfg \
src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg

# Execute spacy CLI training
python -m spacy train \
src/entity_extraction/training/spacy_ner/spacy_transformer_$VERSION.cfg \
--paths.train $DATA_OUTPUT_PATH/train.spacy \
--paths.dev $DATA_OUTPUT_PATH/val.spacy \
--output $MODEL_OUTPUT_PATH \
--gpu-id -1
4 changes: 2 additions & 2 deletions src/preprocessing/README.md
@@ -94,7 +94,7 @@ This script takes labelled dataset in JSONLines format as input and splits it in
The resulting train, validation, and test sets can be used for training and evaluating machine learning models.

#### **Options**
- `--raw_label_path=<raw_label_path>`: Specify the path to the directory where the raw label files are located.
- `--raw_label_path=<raw_label_path>`: Specify the path to the directory containing the raw label files exported from Label Studio and the parquet files with the reviewed entities.

- `--output_path=<output_path>`: Specify the path to the directory where the output files will be written.

@@ -126,4 +126,4 @@ This script manages the creation of custom data artifacts required for training
4. Creates the custom data artifacts that can be used for training or fine-tuning spaCy models.

#### **Options**
- `--data_path=<data_path>`: Specify the path to the folder containing files in JSONLines format.
- `--data_path=<data_path>`: Specify the path to the folder containing the annotation files in TXT/JSON format.
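For context on the "custom data artifacts" mentioned above: `run_spacy_training.sh` points `spacy train` at binary `train.spacy`/`val.spacy` files, which are serialized `DocBin` objects. The sketch below shows how character-offset annotations are generally packed into such an artifact; it is a generic example under assumed names (`annotated`, the `SITE`/`AGE` labels, the output path), not this repository's `spacy_preprocess.py` implementation.

```python
# Sketch only: generic conversion of character-offset annotations into a
# spaCy DocBin artifact; names, labels, and paths here are illustrative.
import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
db = DocBin()

annotated = [
    # (text, [(start_char, end_char, label), ...]) -- invented example
    ("Pollen was sampled at Lake Ontario in 1987.", [(22, 34, "SITE"), (38, 42, "AGE")]),
]

for text, spans in annotated:
    doc = nlp.make_doc(text)
    ents = []
    for start, end, label in spans:
        # contract to token boundaries; drop spans that cannot be aligned
        span = doc.char_span(start, end, label=label, alignment_mode="contract")
        if span is not None:
            ents.append(span)
    doc.ents = ents
    db.add(doc)

db.to_disk("train.spacy")  # the artifact later passed to `spacy train --paths.train`
```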
143 changes: 118 additions & 25 deletions src/preprocessing/labelling_data_split.py
@@ -17,14 +17,16 @@
import numpy as np
import shutil
import json

from collections import defaultdict
from datetime import datetime
from docopt import docopt

sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

from src.logs import get_logger
logger = get_logger(__name__)
from src.preprocessing.labelling_preprocessing import get_hash

logger = get_logger(__name__)

def separate_labels_to_train_val_test(
labelled_file_path: str,
@@ -74,6 +76,9 @@
os.makedirs(os.path.join(output_path, "val"), exist_ok=True)
os.makedirs(os.path.join(output_path, "test"), exist_ok=True)

# Checks for parquet files and extracts them
extract_parquet_file(labelled_file_path)

gdd_ids = get_article_gdd_ids(labelled_file_path)

logger.info(f"Found {len(gdd_ids)} unique GDD IDs in the labelled data.")
@@ -156,20 +161,24 @@
},
}

# iterate through the files in the folder and convert them to the hf format
for file in os.listdir(labelled_file_path):
# if file doesn't end with txt skip it
if not file.endswith(".txt"):
continue

with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)

try:
gdd_id = task["task"]["data"]["gdd_id"]
raw_text = task["task"]["data"]["text"]
annotation_result = task["result"]

try:
if file.endswith(".txt"):
with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)
annotation_result = task["result"]
gdd_id = task["task"]["data"]["gdd_id"]
raw_text = task["task"]["data"]["text"]
elif file.endswith(".json"):
with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)
annotation_result = task["result"]
gdd_id = task["data"]["gdd_id"]
raw_text = task["data"]["text"]

else:
continue

# get the number of words in the article
num_words = len(raw_text.split())

@@ -229,8 +238,16 @@
json.dump(data_metrics, f, indent=2)

logger.info("Finished separating files into train, val and test sets.")


logger.info(
f"Found {data_metrics['train']['entity_counts']} entities in {data_metrics['train']['article_count']} articles in train set."
)
logger.info(
f"Found {data_metrics['val']['entity_counts']} entities in {data_metrics['val']['article_count']} articles in val set."
)
logger.info(
f"Found {data_metrics['test']['entity_counts']} entities in {data_metrics['test']['article_count']} articles in test set."
)

def get_article_gdd_ids(labelled_file_path: str):
"""
Parameters
@@ -256,24 +273,100 @@

# iterate through the files and get the unique gdd_ids
gdd_ids = []

for file in os.listdir(labelled_file_path):
# if file doesn't end with txt skip it
if not file.endswith(".txt"):
continue

with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)


try:
gdd_id = task["task"]["data"]["gdd_id"]
if file.endswith(".txt"):
with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)
gdd_id = task["task"]["data"]["gdd_id"]
elif file.endswith(".json"):
with open(os.path.join(labelled_file_path, file), "r") as f:
task = json.load(f)
gdd_id = task["data"]["gdd_id"]

else:
continue
except Exception as e:
logger.warning(f"Issue with file data: {file}, {e}")

continue


if gdd_id not in gdd_ids:
gdd_ids.append(gdd_id)

return gdd_ids

def extract_parquet_file(labelled_file_path: str):
"""Checks the directory for parquet files and extracts the corrected entities

Parameters
----------
labelled_file_path: str
Directory containing the data files
"""

files = os.listdir(labelled_file_path)

# Iterate through the files and check if they are parquet files
for fin in files:
if fin.endswith(".parquet"):
df = pd.read_parquet(os.path.join(labelled_file_path, fin))


logger.info(f"Read parquet file {fin} with {len(df)} rows.")


for index, row in df.iterrows():


output_files = defaultdict(list)
all_sentences = {}
gdd_id = row["gddid"]
if row["corrected_entities"] != "None":


logger.info(f"Entities found in xDD ID: {gdd_id}")


corrected_entities = json.loads(row["corrected_entities"])


for ent_type in corrected_entities.keys():
for entity in corrected_entities[ent_type].keys():
if corrected_entities[ent_type][entity]['corrected_name']:
entity_text = corrected_entities[ent_type][entity]['corrected_name']

else:
entity_text = entity
for sentence in corrected_entities[ent_type][entity]['sentence']:
if (sentence['char_index']['start'] != -1 and
sentence['char_index']['end'] != -1):
all_sentences[sentence['sentid']] = sentence['text']
output_files[sentence['sentid']].append({
"value": {
"text": entity_text,
"start": sentence['char_index']['start'],
"end": sentence['char_index']['end'],
"labels": [ent_type]
}
})

logger.info(f"Number of sentences extracted for training: {len(output_files)}")


# Iterate through each sentence and create a json file
for sentid in output_files.keys():
text = all_sentences[sentid]
article_data = {
"text": text,
"global_index": sentid,
"local_index": sentid,
"gdd_id": gdd_id,
"doi": row['DOI'],
"timestamp": str(datetime.today()),
"chunk_hash": get_hash(text),
"article_hash": get_hash(text),
}
output_data = {
"data": article_data,
"result": output_files[sentid]
}
file_name = os.path.join(labelled_file_path, f"{gdd_id}_{sentid}.json")

# Save the dictionary as a json file
with open(file_name, "w") as f:
json.dump(output_data, f, indent=2)


def main():
opt = docopt(__doc__)
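To make the new `extract_parquet_file` step easier to follow, here is a hedged sketch of one reviewed parquet row as the function reads it. The key names mirror the code above, while the entity, the `SITE` label, the IDs, and the DOI are invented for illustration.

```python
# Illustrative only: shape of a single reviewed parquet row consumed by
# extract_parquet_file(). Keys come from the code above; values are invented.
import json

example_row = {
    "gddid": "5c3a1b2e8a90e24a55555555",   # hypothetical xDD article ID
    "DOI": "10.1234/example.doi",          # hypothetical DOI
    # stored as a JSON string in the parquet file:
    "corrected_entities": json.dumps({
        "SITE": {                                    # entity type (label)
            "Lake Ontario": {                        # entity as originally extracted
                "corrected_name": "Ontario, Lake",   # reviewer's correction, may be empty
                "sentence": [
                    {
                        "sentid": 12,
                        "text": "Pollen was sampled at Lake Ontario in 1987.",
                        "char_index": {"start": 22, "end": 34},
                    }
                ],
            }
        }
    }),
}

# For each sentence with valid char_index values, the function writes a
# "<gddid>_<sentid>.json" task file shaped roughly like:
# {
#   "data":   {"text": ..., "gdd_id": ..., "doi": ..., "chunk_hash": ..., ...},
#   "result": [{"value": {"text": "Ontario, Lake", "start": 22, "end": 34,
#                         "labels": ["SITE"]}}]
# }
```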
1 change: 0 additions & 1 deletion src/preprocessing/labelling_preprocessing.py
@@ -33,7 +33,6 @@
from src.logs import get_logger
# logger = logging.getLogger(__name__)
logger = get_logger(__name__)
logger.setLevel(logging.INFO)

from src.entity_extraction.baseline_entity_extraction import baseline_extract_all
from src.entity_extraction.spacy_entity_extraction import spacy_extract_all