feat: Spellcheck (#345)
* feat(spellcheck): ⚡ Add feature: push to Argilla from an HF dataset

* fix(spellcheck): 🐛 Add fixes to T5 script

* fix(spellcheck): 🐛 Add guardrail to prevent computing metrics with empty strings

* feat(spellcheck): 🎨 Training pipeline using Metaflow

* feat(spellcheck): 🎨 LLM QLoRA TRL training script - Mistral - 7B - Instruct

* perf(spellcheck): 🧪 Normalize evaluation algorithm

"flavour" -> "flavor" - "ï" -> "i" - "â" -> "a" - "oe"

* feat(spellcheck): 🎨 Implement LLM training with Sagemaker & Metaflow

Scripts are customized to handle training in the cloud using Sagemaker Training Jobs

* feat(spellcheck): ⚡ Mistral 7b instruct v3 trained

* feat(spellcheck): 🎨 Update guidelines: accents

* refactor(spellcheck): ✨ Update Logging to consider script and src code for defining the level

* feat(spellcheck): ✨ Dataset processing methods & pipeline created

* build(spellcheck): ✨ Dataset processing (oe, percentage alignment): v3.1

* feat(spellcheck): ♻️ Training Mistral-7B-Instruct: instruction + unidecode normalization + scheduler linear

* Delete previous training LLM dag

* feat(spellcheck): ⚡ Add eval normalization: remove "\n"

* refactor(spellcheck): 👷 Foundational LLMs re-evaluated on the benchmark

The prompt was intentionally overfitted on the benchmark in order to later create the synthetic training dataset. Examples from the benchmark are now removed from the prompt.

* Modify overfitted prompt

* refactor(spellcheck): 🏷️ Refactor Argilla extraction: modules + unit tests + dag

* refactor(spellcheck): ⬆️ Refactor training job: add parameters to argparse + add cometML logs

* feat(spellcheck): 🎨 Fine-tune Mistral-7b with guidelines + training scripts improvements

* feat(spellcheck): ✨ Add args for training + Train Mistral-7b-base

* fix(spellcheck): 🐛 Correction error in Mistral-7B-Base fine-tuning script

* feat(spellcheck): 🎨 DPO dataset extraction and push

* fix(spellcheck): 🔖 small fixes

* feat(spellcheck): ⚡ DPO training script

* refactor(spellcheck): 🚧 Refactor training pipeline: WIP

* Update get_logger for Metaflow logging

* feat(spellcheck): ✨ Double the benchmark size: extraction and push to Argilla pipeline (WIP)

* docs(spellcheck): 📝 Document benchmark generation pipeline

* feat(spellcheck): 🐛 Remove legacy metadata in Argilla

* refactor(spellcheck): 🚧 Refactor training pipeline (WIP)

* refactor(spellcheck): 🚧 Refactor training script (WIP)

* refactor(spellcheck): 🚧 Refactor training pipeline

* chore(spellcheck): ✨ Update Python from 3.9 to 3.10

* refactor(spellcheck): ⚡ LLM training pipeline refactored

* feat(spellcheck): ✨ Pretraining before fine-tuning (WIP)

* feat(spellcheck): 🚑 Pretraining + Finetuning Mistral-7B

* refactor(spellcheck): ✨ Refactor training

* feat(spellcheck): 🚧 Batch processing  (WIP)

* feat(spellcheck): 🚧 Batch job with vllm and GCP (WIP)

* feat(spellcheck): ⚡ Batch job operational on GCP

* refactor(spellcheck): ✨ Clean code and add logging to batch job

* fix(spellcheck): 📦 Forgot to add batch dep requirements
jeremyarancio authored Aug 23, 2024
1 parent 60b07a2 commit 24adbb2
Showing 50 changed files with 5,070 additions and 1,471 deletions.
17 changes: 17 additions & 0 deletions spellcheck/Dockerfile
@@ -0,0 +1,17 @@
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel

ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PIP_DISABLE_PIP_VERSION_CHECK=on \
PYTHONPATH="/app/src"

WORKDIR /app

COPY ./src /app

COPY ./scripts/batch/. /app

RUN pip install --no-cache-dir -r requirements.txt

# Set the entrypoint to the batch job script
ENTRYPOINT ["python", "main.py"]
33 changes: 18 additions & 15 deletions spellcheck/README.md
@@ -14,7 +14,10 @@ From the different types of errors observed across products, we came up with the
* The only case when a whitespace involving a percentage should be modified is if the *digit* is stuck in the previous word (*ex: cheese1.9% -> cheese 1.9%*)
* Some ingredients are enclosed with `_`, such as `_milk_` or `_cacahuetes_`, to flag allergens. These markers should remain unchanged. However, when the enclosed text is not an ingredient, as in `_Cacahuetes_ con cáscara tostado. _Trazas de frutos de cáscara_.`, it needs to be modified into `_Cacahuetes_ con cáscara tostado. Trazas de frutos de cáscara.`;
* Some percentages were badly parsed by the OCR. Since we cannot be sure about what is the right value, it is preferable to keep it as it is.
* We're ok with accents modified or not.
* Accents and other language-specific punctuation:
    * In Romanian, the characters ["ş" (351), "ţ" (355)] (Unicode code points) should be restored by the Spellcheck when necessary.
    * Uppercase letters should remain unchanged => "ECOSSE" -> "ECOSSE"; "ÉCOSSE" -> "ÉCOSSE"
    * If lowercase, the accent should be added if missing.
* `*` should remain in the corrected text as much as possible (*ex: Schweinefleisch\* -> Schweinefleisch\**)
* Whitespaces shouldn't be modified except for these cases:
    * When two words are stuck to each other: *"rizbrun" -> "riz brun"*
@@ -233,30 +236,30 @@ We evaluated **Proprietary LLMs** such as OpenAI GPTs and Anthropic Claude 3 mod

Texts are normalized to not consider some typical corrections:
* lowercase-uppercase
* whitespaces between words
* words are stripped (whitespace)
* replace ("œ", "oe")
* replace ("flavour", "flavor") - ("colour", "color") - ("pasteurized", "pasteurised")
* remove all accents using the Unidecode library
* remove linebreaks: ("\n", "")
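
The normalization steps listed above could be sketched as follows. This is only an illustration, not the repository's evaluation code: the function name `normalize` and the order of the steps are assumptions, and the standard-library `unicodedata` module is used here as a stand-in for Unidecode so the snippet has no third-party dependency.

```python
import unicodedata

# Replacement pairs copied from the list above.
REPLACEMENTS = [
    ("œ", "oe"),
    ("flavour", "flavor"),
    ("colour", "color"),
    ("pasteurized", "pasteurised"),
]


def strip_accents(text: str) -> str:
    """Drop combining marks after decomposition (stand-in for Unidecode)."""
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))


def normalize(text: str) -> str:
    """Hypothetical helper applying the normalization steps before scoring."""
    text = text.lower()             # ignore lowercase/uppercase differences
    text = text.replace("\n", "")   # remove linebreaks
    for old, new in REPLACEMENTS:   # spelling variants and the "œ" ligature
        text = text.replace(old, new)
    text = strip_accents(text)      # remove accents
    return " ".join(text.split())   # strip words and collapse whitespace


print(normalize("Crème  FRAÎCHE, cœur\n"))
```

Applying the same function to both the reference and the prediction before comparing them means that differences in case, accents or spacing are not counted as corrections.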

In addition to computing metrics using the evaluation algorithm, predictions against the benchmark are pushed to Argilla for human evaluation. The proportion of good corrections is then calculated.

Benchmark version: **v5**
Prompt version: **v6**
Benchmark version: **v7.3** -- Prompt version: **v7**


| Model | Correction Precision | Correction Recall | Localisation Precision | Localisation Recall | Localisation F1 | Human evaluation
|----------|----------|----------|----------|----------|----------|----------|
| FlanT5-Small | **0.815** | 0.486 | **0.876** | 0.522 | 0.654 | - |
| GPT-3.5-Turbo | 0.729 | **0.779** | 0.767 | **0.820** | **0.793** | **0.894** |
| Gemini-1.0-pro | 0.499 | 0.586 | 0.561 | 0.658 | 0.605 | 0.717 |
| Gemini-1.5-flash | 0.514 | 0.693 | 0.590 | 0.795 | 0.677 | 0.790 |
| Gemini-1.5-pro | 0.364 | 0.658 | 0.415 | 0.750 | 0.534 | - |
| Mistral-7B-Instruct-v3 (not fine-tuned) | 0.381 | 0.501 | 0.488 | 0.641 | 0.554 | - |
| Model | Correction Precision | Correction Recall | Correction F1 | Human evaluation |
|----------|----------|----------|----------|----------|
| GPT-3.5-Turbo | 0.557 | 0.727 | 0.631 | - |
| GPT-4o | 0.311 | 0.702 | 0.431 | - |
| Gemini-1.5-flash | 0.544 | 0.596 | 0.569 | - |
| Claude3-Sonnet-3.5 | 0.178 | **0.810** | 0.292 | - |
| **Our model** | **0.664** | 0.630 | **0.647** | - |


Notes:
* **Correction Precision**: Proportion of correct modifications.
* **Correction Recall**: Proportion of errors found and corrected.
* **Localisation Precision**: Proportion of errors correctly detected by the model
* **Localisation Recall**: Proportion of errors found
* **Localisation F1**: Harmonic mean of Localisation Precision and Localisation Recall
* **Correction F1**: Harmonic mean of Correction Precision and Correction Recall
* **Human evaluation**: Proportion of good corrections after human analysis
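
As a sanity check on the F1 columns, here is a minimal sketch that assumes the scores are the standard F-beta with `beta = 1.0` (the `beta` and `f1_beta` fields in `metrics.jsonl` suggest this). The function name `f_beta` is hypothetical; the precision and recall values are copied from the GPT-3.5-Turbo record for benchmark v7.3 in this commit.

```python
def f_beta(precision: float, recall: float, beta: float = 1.0) -> float:
    """F-beta score; beta=1.0 gives the harmonic mean of precision and recall."""
    if precision <= 0.0 or recall <= 0.0:
        return 0.0  # avoid dividing by zero when nothing was found or corrected
    beta_sq = beta ** 2
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


# GPT-3.5-Turbo, benchmark v7.3: correction precision and correction recall.
gpt35_correction_f1 = f_beta(0.5573366214549939, 0.7266881028938906)
print(round(gpt35_correction_f1, 3))  # 0.631, the value reported in the table
```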

### 100 % known-ingredients products
9 changes: 9 additions & 0 deletions spellcheck/commands/extract_from_argilla.sh
@@ -0,0 +1,9 @@
python scripts/dags/extract_from_argilla.py run \
--deploy_to_hf true \
--local_path data/dataset/deployed_data.parquet \
--argilla_dataset_name training_dataset \
--dataset_hf_repo openfoodfacts/spellcheck-dataset \
--dataset_revision v4 \
--dataset_test_size 0.1 \
--dataset_version v4.3 \
--status submitted
6 changes: 6 additions & 0 deletions spellcheck/commands/run_training.sh
@@ -0,0 +1,6 @@
python scripts/dags/training/training.py run \
--do_human_eval False \
--evaluation_data_version v8.0 \
--training_data_version v5.2 \
--experiment_tag eval_loss --experiment_tag mistral-7b-v0.3 --experiment_tag eval-normalization --experiment_tag test

40 changes: 0 additions & 40 deletions spellcheck/config/training.yml

This file was deleted.

50 changes: 50 additions & 0 deletions spellcheck/config/training/pretraining_conf.yml
@@ -0,0 +1,50 @@
estimator:
  entry_point: "pretraining_llm.py" # train script
  source_dir: "scripts/training/llm/" # directory containing the training script and its requirements
  dependencies:
    - "src/" # Additional local library
  output_path: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to save the artifacts
  code_location: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to stage the code during the training job
  base_job_name: "mistral-7b-v03" # name of the training job
  instance_count: 1 # the number of instances used for training
  instance_type: "ml.g5.2xlarge" # instance type used for the training job
  transformers_version: "4.36" # transformers version used in the training job
  pytorch_version: "2.1" # pytorch version used in the training job
  py_version: "py310" # python version used in the training job
  disable_output_compression: true # do not compress the output, to save training time and cost
  volume_size: 300 # the size of the EBS volume in GB

hyperparameters:
  # Data
  training_data: "openfoodfacts/spellcheck-corpus"
  train_split: "train"

  # Trainer
  output_dir: "/opt/ml/model"
  pretrained_model_name: "mistralai/Mistral-7B-v0.3"
  num_train_epochs: 1
  per_device_train_batch_size: 4
  learning_rate: 0.0002 # Paper https://arxiv.org/pdf/2210.11416
  warmup_steps: 0
  warmup_ratio: 0.1
  weight_decay: 0.1
  gradient_checkpointing: true
  seed: 42
  optim: "adamw_torch_fused" # The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  lr_scheduler_type: "cosine"
  gradient_accumulation_steps: 8
  bf16: true
  tf32: true
  fp16: false
  logging_steps: 1
  save_total_limit: 1
  report_to: "none" # Important to avoid overlap between the Trainer callback and our custom callback
  max_seq_length: 2048
  packing: true
  dataset_text_field: "ingredients_text"
  # add_special_tokens: true # Add bos token and other special token from the tokenizer
  # append_concat_token: true # If true, appends eos_token_id at the end of each sample being packed.

  # Saving
  merge_weights: true
  max_shard_size: "2GB"
75 changes: 75 additions & 0 deletions spellcheck/config/training/training_conf.yml
@@ -0,0 +1,75 @@
estimator:
  entry_point: "refactored_llm.py" # train script
  source_dir: "scripts/training/llm/" # directory containing the training script and its requirements
  dependencies:
    - "src/" # Additional local library
  output_path: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to save the artifacts
  code_location: "s3://open-food-facts-robotoff/spellcheck/model-training/" # s3 path to stage the code during the training job
  base_job_name: "mistral-7b-v03" # name of the training job
  instance_count: 1 # the number of instances used for training
  instance_type: "ml.g5.2xlarge" # instance type used for the training job
  transformers_version: "4.36" # transformers version used in the training job
  pytorch_version: "2.1" # pytorch version used in the training job
  py_version: "py310" # python version used in the training job
  disable_output_compression: true # do not compress the output, to save training time and cost
  volume_size: 300 # the size of the EBS volume in GB

additional_conf:
  s3_evaluation_uri: "s3://open-food-facts-robotoff/spellcheck/evaluation_output/"

hyperparameters:
  # Data
  training_data: "openfoodfacts/spellcheck-dataset"
  evaluation_data: "openfoodfacts/spellcheck-benchmark"
  train_split: "train+test"
  eval_split: "train"
  train_text_feature: "original"
  train_label_feature: "reference"
  eval_text_feature: "original"
  eval_label_feature: "reference"
  train_data_revision: "v5"
  eval_data_revision: "v8"

  # TrainingArguments
  output_dir: "/opt/ml/model"
  pretrained_model_name: "mistralai/Mistral-7B-v0.3"
  num_train_epochs: 0.01
  per_device_train_batch_size: 8
  per_device_eval_batch_size: 4
  learning_rate: 0.0002 # Paper https://arxiv.org/pdf/2210.11416
  warmup_steps: 0
  warmup_ratio: 0.1
  weight_decay: 0.1
  gradient_checkpointing: true
  seed: 42
  optim: "adamw_torch_fused" # The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  lr_scheduler_type: "cosine"
  gradient_accumulation_steps: 4
  bf16: true
  tf32: true
  fp16: false
  logging_steps: 5
  evaluation_strategy: "steps"
  save_strategy: "steps"
  eval_steps: 10 # Careful, need to be a multiple of eval-steps: 500 by default
  save_total_limit: 1
  report_to: "none" # Important to avoid overlap between the Trainer callback and our custom callback

  # SFTConfig
  max_seq_length: 1024
  packing: true
  dataset_text_field: "text"
  # add_special_tokens: true # Add bos token and other special token from the tokenizer
  # append_concat_token: true # If true, appends eos_token_id at the end of each sample being packed.

  # Saving
  merge_weights: true
  max_shard_size: "2GB"

  # Inference
  max_new_token: 1024
  batch_size: 1

  # Data processing
  batched: false
  # instruction_template
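
Two of the hyperparameters in these configs interact in ways that are easy to miss, so here is a rough sketch of the arithmetic. It assumes a single GPU per instance (`instance_count: 1` on `ml.g5.2xlarge`) and assumes the trainer uses `warmup_ratio` only when `warmup_steps` is 0, which is how Hugging Face `TrainingArguments` is understood to behave. Both function names are hypothetical and do not come from the repository.

```python
import math


def effective_batch_size(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    num_devices: int = 1,  # assumption: one GPU on a single instance
) -> int:
    """Number of samples contributing to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def resolve_warmup_steps(
    num_training_steps: int, warmup_steps: int, warmup_ratio: float
) -> int:
    """An explicit warmup_steps wins; otherwise derive it from warmup_ratio."""
    if warmup_steps > 0:
        return warmup_steps
    return math.ceil(num_training_steps * warmup_ratio)


# pretraining_conf.yml: batch size 4 with 8 accumulation steps
print(effective_batch_size(4, 8))
# training_conf.yml: batch size 8 with 4 accumulation steps
print(effective_batch_size(8, 4))
```

Under these assumptions both configs reach the same effective batch size, with the fine-tuning config trading accumulation steps for a larger per-device batch at a shorter `max_seq_length`.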
85 changes: 85 additions & 0 deletions spellcheck/data/evaluation/metrics.jsonl
@@ -466,3 +466,88 @@
"prompt_version": "v6",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.7101024890190337,
"correction_recall": 0.7860615883306321,
"precision": 0.7481698389458272,
"recall": 0.8282009724473258,
"f1": 0.7861538461538463,
"f1_beta": 0.7861538461538463,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-3.5-turbo",
"date": "25/06/2024 10:19:48",
"benchmark_version": "v5",
"prompt_version": "v6",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.5573366214549939,
"correction_recall": 0.7266881028938906,
"precision": 0.6091245376078915,
"recall": 0.7942122186495176,
"f1": 0.6894626657362177,
"f1_beta": 0.6894626657362177,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-3.5-turbo",
"date": "01/07/2024 13:08:32",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.5439882697947214,
"correction_recall": 0.5964630225080386,
"precision": 0.6304985337243402,
"recall": 0.6913183279742765,
"f1": 0.6595092024539878,
"f1_beta": 0.6595092024539878,
"beta": 1.0,
"drop_count": 0
},
"model": "gemini-1.5-flash-preview-0514",
"date": "01/07/2024 15:19:04",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
{
"metrics": {
"correction_precision": 0.17844767844767845,
"correction_recall": 0.809748427672956,
"precision": 0.19889119889119888,
"recall": 0.9025157232704403,
"f1": 0.3259511641113004,
"f1_beta": 0.3259511641113004,
"beta": 1.0,
"drop_count": 0
},
"model": "claude-3-5-sonnet-20240620",
"date": "01/07/2024 16:07:23",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 152
}
{
"metrics": {
"correction_precision": 0.31130063965884863,
"correction_recall": 0.7019230769230769,
"precision": 0.35252309879175553,
"recall": 0.7948717948717948,
"f1": 0.4884293451501724,
"f1_beta": 0.4884293451501724,
"beta": 1.0,
"drop_count": 0
},
"model": "gpt-4o-2024-05-13",
"date": "01/07/2024 16:34:15",
"benchmark_version": "v7.3",
"prompt_version": "v7",
"benchmark_size": 151
}
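
The records in this file span several lines each, so parsing it one line at a time with `json.loads` would likely fail. Below is a minimal sketch for reading it, assuming the file is a sequence of concatenated JSON objects as displayed in this diff. The function name `iter_json_records` is hypothetical and not part of the repository.

```python
import json
from typing import Any, Dict, Iterator


def iter_json_records(text: str) -> Iterator[Dict[str, Any]]:
    """Yield each JSON object from a string of concatenated, multi-line objects."""
    decoder = json.JSONDecoder()
    pos, length = 0, len(text)
    while pos < length:
        # raw_decode does not skip leading whitespace, so do it manually.
        while pos < length and text[pos].isspace():
            pos += 1
        if pos >= length:
            break
        record, pos = decoder.raw_decode(text, pos)
        yield record


sample = """
{
"metrics": {"correction_precision": 0.557, "correction_recall": 0.727},
"model": "gpt-3.5-turbo",
"benchmark_version": "v7.3"
}
{
"metrics": {"correction_precision": 0.311, "correction_recall": 0.702},
"model": "gpt-4o-2024-05-13",
"benchmark_version": "v7.3"
}
"""

for record in iter_json_records(sample):
    print(record["model"], record["metrics"]["correction_precision"])
```

Filtering the yielded records on `benchmark_version` and `prompt_version` would then give the rows needed for a results table.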