fix(nutrisight): improve scripts
raphael0202 committed Dec 26, 2024
1 parent 405eb1d commit 6faa0c5
Showing 3 changed files with 67 additions and 30 deletions.
85 changes: 60 additions & 25 deletions nutrisight/dataset/6_push_dataset.py
@@ -33,6 +33,11 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:
task_data = task.data
annotations_data = task.annotations
task_id = task.id

split = task_data["split"]
if split not in ("train", "test"):
raise ValueError(f"Task {task_id} has an invalid split: {split}")

if len(annotations_data) == 0:
# No annotation, skip
return None
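
The validation added above assumes every Label Studio task now carries an explicit split assignment in its data. A minimal sketch of the expected payload shape, where only the keys come from the script and all values are hypothetical:

# Hypothetical task.data payload; keys match what create_sample reads,
# values are made up for illustration.
task_data = {
    "image_url": "https://images.openfoodfacts.org/images/products/301/076/800/1008/1.jpg",
    "split": "train",  # must be "train" or "test", otherwise a ValueError is raised
    "batch": "batch-12",
    "meta": {"ocr_url": "https://images.openfoodfacts.org/images/products/301/076/800/1008/1.json"},
}
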
@@ -45,9 +50,11 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:

ocr_url = task_data["meta"]["ocr_url"]
image_url = task_data["image_url"]

meta = {
"barcode": extract_barcode_from_url(image_url),
"image_id": Path(urlparse(image_url).path).stem,
"split": split,
"ocr_url": ocr_url,
"image_url": task_data["image_url"],
"batch": task_data["batch"],
@@ -130,7 +137,7 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:
):
meta[info_name.replace("-", "_")] = info_name in info_checkbox
elif result["from_name"] == "issues" and result["value"]["choices"]:
logger.info(
logger.debug(
"Task %s has issues: %s, skipping",
task_id,
result["value"]["choices"],
@@ -183,7 +190,8 @@ def get_tasks(
api_key: str,
project_id: int,
batch_ids: list[int] | None = None,
) -> Iterator[dict]:
limit: int | None = None,
) -> Iterator[Task]:
"""Yield tasks (annotations) from Label Studio."""

ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
@@ -198,22 +206,27 @@
"value": "batch-{}$".format("|".join(map(str, batch_ids))),
}
)
yield from ls.tasks.list(
project=project_id,
query=(
{
"filters": {
"conjunction": "and",
"items": filter_items,
for i, task in enumerate(
ls.tasks.list(
project=project_id,
query=(
{
"filters": {
"conjunction": "and",
"items": filter_items,
}
}
}
if filter_items
else None
),
# This view contains all annotated samples
view=88,
fields="all",
)
if filter_items
else None
),
# This view contains all annotated samples
view=88,
fields="all",
)
):
if limit is not None and i >= limit:
break
yield task


def sample_generator(dir_path: Path):
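
The reworked get_tasks now accepts an optional limit that caps how many tasks are yielded, which is useful for dry runs. A usage sketch with assumed values (the API key and batch IDs are placeholders; the URL and project ID are the script defaults):

# Sketch: fetch at most 10 annotated tasks from two batches.
for task in get_tasks(
    label_studio_url="https://annotate.openfoodfacts.org",
    api_key="<LABEL_STUDIO_API_KEY>",  # placeholder
    project_id=42,
    batch_ids=[1, 2],  # placeholder batch IDs
    limit=10,  # new parameter: stop after 10 tasks
):
    print(task.id)
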
@@ -230,25 +243,35 @@ def push_dataset(
project_id: Annotated[int, typer.Option(..., help="Label Studio project ID")] = 42,
batch_ids: Annotated[
Optional[str],
typer.Option(..., help="comma-separated list of batch IDs to include"),
typer.Option(
...,
help="comma-separated list of batch IDs to include. If not provided, "
"all batches are included",
),
] = None,
label_studio_url: Annotated[
str, typer.Option()
] = "https://annotate.openfoodfacts.org",
revision: Annotated[
str, typer.Option(help="Dataset revision on Hugging Face datasets")
] = "main",
test_split_count: Annotated[
int, typer.Option(help="Number of samples in test split")
] = 200,
only_checked: Annotated[
bool, typer.Option(help="Include only checked tasks", show_default=False)
bool,
typer.Option(
help="Include only checked tasks. If False, all annotated tasks "
"are included.",
show_default=False,
),
] = False,
):
"""Push the nutrition extraction dataset to Hugging Face datasets, from
Label Studio."""
logger.info("Fetching tasks from Label Studio, project %s", project_id)
if batch_ids:
batch_ids_int = list(map(int, batch_ids.split(",")))
logger.info("Fetching tasks for batches %s", batch_ids_int)
else:
batch_ids_int = None

created = 0
ner_tag_set: set[str] = set()
@@ -259,7 +282,12 @@

for i, task in enumerate(
tqdm.tqdm(
get_tasks(label_studio_url, api_key, project_id, batch_ids_int),
get_tasks(
label_studio_url=label_studio_url,
api_key=api_key,
project_id=project_id,
batch_ids=batch_ids_int,
),
desc="tasks",
)
):
@@ -298,6 +326,7 @@ def push_dataset(
"barcode": datasets.Value("string"),
"image_id": datasets.Value("string"),
"image_url": datasets.Value("string"),
"split": datasets.Value("string"),
"ocr_url": datasets.Value("string"),
"batch": datasets.Value("string"),
"label_studio_id": datasets.Value("int64"),
@@ -312,8 +341,14 @@
dataset = datasets.Dataset.from_generator(
functools.partial(sample_generator, tmp_dir), features=features
)
dataset = dataset.train_test_split(
test_size=test_split_count, shuffle=False, seed=42

train_subset = dataset.filter(lambda example: example["meta"]["split"] == "train")
test_subset = dataset.filter(lambda example: example["meta"]["split"] == "test")
dataset = datasets.DatasetDict(
{
"train": train_subset,
"test": test_subset,
}
)
logger.info(
"Pushing dataset to Hugging Face Hub under openfoodfacts/nutrient-detection-layout, revision %s",
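
With this change the train/test partition is no longer drawn by train_test_split at export time; it is read from each annotation's split field and materialized as a DatasetDict. A quick post-push check one could run, as a sketch that assumes the repository name and default revision used by the script:

# Sketch: verify that the pushed dataset exposes the annotation-defined splits.
from datasets import load_dataset

ds = load_dataset("openfoodfacts/nutrient-detection-layout", revision="main")
print(ds["train"].num_rows, ds["test"].num_rows)
print(ds["test"][0]["meta"]["split"])  # expected: "test"
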
6 changes: 3 additions & 3 deletions nutrisight/dataset/requirements.txt
@@ -1,8 +1,8 @@
typer==0.12.3
label-studio-sdk==1.0.2
openfoodfacts==0.3.0
Pillow==10.3.0
datasets==2.18.0
openfoodfacts==2.5.0
Pillow==11.0.0
datasets==3.2.0
redis==5.0.6
ratelimit==2.2.1
more_itertools==10.3.0
6 changes: 4 additions & 2 deletions nutrisight/train/launch.sh
@@ -1,5 +1,5 @@
#!/bin/bash
RUN_NAME='ds-v5-large'
RUN_NAME='ds-v6-large'
BASE_MODEL_NAME='microsoft/layoutlmv3-large'

DISABLE_MLFLOW_INTEGRATION="TRUE" WANDB_PROJECT=nutrition-detector WANDB_NAME=$RUN_NAME \
@@ -22,8 +22,10 @@ train.py \
--save_steps 15 \
--evaluation_strategy steps \
--save_strategy steps \
--evaluation_strategy steps \
--metric_for_best_model "eval_f1" \
--learning_rate 1e-5 \
--push_to_hub \
--hub_model_id "openfoodfacts/nutrition-extractor" \
--hub_strategy "end"
--hub_strategy "end" \
--load_best_model_at_end
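
The added --load_best_model_at_end flag, combined with --metric_for_best_model "eval_f1", makes the Hugging Face Trainer restore the checkpoint with the best eval F1 before the final push; it requires matching evaluation and save strategies, which the script already sets to "steps". A rough Python equivalent of the relevant flags, as a sketch rather than the actual train.py configuration (output_dir is a placeholder):

# Sketch of the flag combination above; output_dir is assumed.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="nutrition-extractor",
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=15,
    metric_for_best_model="eval_f1",
    load_best_model_at_end=True,  # flag added in this commit
    learning_rate=1e-5,
    push_to_hub=True,
    hub_model_id="openfoodfacts/nutrition-extractor",
    hub_strategy="end",
)
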
