From 6b7abe4ba467ed1f45a6715db039f423c1cf298d Mon Sep 17 00:00:00 2001
From: Jeremy Arancio <97704986+jeremyarancio@users.noreply.github.com>
Date: Mon, 4 Nov 2024 16:48:59 +0100
Subject: [PATCH] feat: add CLI command to pre-annotate object detection tasks
with Yolo-world (#350)
* feat: :heavy_plus_sign: Pre-annotation Yolo-world + label studio
* feat: :chart_with_upwards_trend: Annotation & Export
* feat: :zap: Add code to Crop-Detection
* refactor: :sparkles: Add __main__.py
* refactor: :art: Refactor cli command
* fix: :zap: Wrong model path fixed
---
ml_utils/ml_utils_cli/cli/annotate.py | 145 +++++++++++++
ml_utils/ml_utils_cli/cli/apps/datasets.py | 33 ++-
ml_utils/ml_utils_cli/cli/export.py | 20 +-
.../config_files/product-detection.xml | 6 +
object_detection/crop_detection/Makefile | 22 ++
object_detection/crop_detection/README.md | 114 ++++++++++
.../crop_detection/assets/cropped.jpg | Bin 0 -> 347933 bytes
.../crop_detection/assets/product.jpg | Bin 0 -> 882083 bytes
.../crop_detection/best_images/.gitkeep | 0
.../crop_detection/cli/__main__.py | 4 +
.../cli/inference_yolo_tflite.py | 203 ++++++++++++++++++
.../commands/load_images_from_aws.sh | 41 ++++
object_detection/crop_detection/data/.gitkeep | 0
.../crop_detection/images/.gitkeep | 0
.../crop_detection/models/.gitkeep | 0
.../crop_detection/requirements.txt | 4 +
16 files changed, 582 insertions(+), 10 deletions(-)
create mode 100644 ml_utils/ml_utils_cli/cli/annotate.py
create mode 100644 ml_utils/ml_utils_cli/config_files/product-detection.xml
create mode 100644 object_detection/crop_detection/Makefile
create mode 100644 object_detection/crop_detection/README.md
create mode 100644 object_detection/crop_detection/assets/cropped.jpg
create mode 100644 object_detection/crop_detection/assets/product.jpg
create mode 100644 object_detection/crop_detection/best_images/.gitkeep
create mode 100644 object_detection/crop_detection/cli/__main__.py
create mode 100644 object_detection/crop_detection/cli/inference_yolo_tflite.py
create mode 100755 object_detection/crop_detection/commands/load_images_from_aws.sh
create mode 100644 object_detection/crop_detection/data/.gitkeep
create mode 100644 object_detection/crop_detection/images/.gitkeep
create mode 100644 object_detection/crop_detection/models/.gitkeep
create mode 100644 object_detection/crop_detection/requirements.txt
diff --git a/ml_utils/ml_utils_cli/cli/annotate.py b/ml_utils/ml_utils_cli/cli/annotate.py
new file mode 100644
index 00000000..90c4462d
--- /dev/null
+++ b/ml_utils/ml_utils_cli/cli/annotate.py
@@ -0,0 +1,145 @@
+import os
+import uuid
+from typing import List, Iterable, Dict, Iterator
+from pathlib import Path
+import tqdm
+
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
+
+from openfoodfacts.utils import get_logger
+
+
+logger = get_logger(__name__)
+
+IMAGE_FORMAT = [".jpg", ".jpeg", ".png"]
+MODEL_NAME = "yolov8x-worldv2.pt"
+LABELS = ["packaging"]
+
+
+def format_object_detection_sample_from_yolo(
+ images_dir: Path,
+ model_name: str,
+ labels: List[str],
+ batch_size: int,
+) -> Iterable[Dict]:
+ logger.info("Loading images from %s", images_dir)
+ image_paths = [image_path for image_path in images_dir.iterdir() if image_path.suffix in IMAGE_FORMAT]
+ logger.info("Found %d images in %s", len(image_paths), images_dir)
+ ls_data = generate_ls_data_from_images(image_paths=image_paths)
+ logger.info("Pre-annotating images with YOLO")
+ predictions = format_predictions_from_yolo(
+ image_paths=image_paths,
+ model_name=model_name,
+ labels=labels,
+ batch_size=batch_size,
+ )
+ return [
+ {
+ "data": {
+ "image_id": data["image_id"],
+ "image_url": data["image_url"],
+ "split": "train",
+ },
+ "predictions": [prediction] if prediction["result"] else [],
+ }
+ for data, prediction in zip(ls_data, predictions)
+ ]
+
+
+def generate_ls_data_from_images(image_paths: Iterable[Path]):
+ for image_path in image_paths:
+ yield {
+ "image_id": image_path.stem.replace("_", "-"),
+ "image_url": transform_id_to_url(image_path.name),
+ }
+
+
+def transform_id_to_url(image_id: str) -> str:
+ """Format image_id: 325_938_117_1114_1 => https://images.openfoodfacts.org/images/products/325/938/117/1114/1"""
+ off_base_url = "https://images.openfoodfacts.org/images/products/"
+ return os.path.join(off_base_url, "/".join(image_id.split("_")))
+
+
+def format_predictions_from_yolo(
+ image_paths: Iterable[Path],
+ model_name: str,
+ labels: List[str],
+ batch_size: int,
+) -> Iterator[Dict]:
+ results = pre_annotate_with_yolo(
+ image_paths=image_paths,
+ model_name=model_name,
+ labels=labels,
+ batch_size=batch_size,
+ )
+ for batch in results:
+ for result in batch:
+ annotation_results = []
+ orig_height, orig_width = result.orig_shape
+ model_version = model_name.split("/")[-1]
+ for xyxyn in result.boxes.xyxyn:
+ # Boxes found.
+ if len(xyxyn) > 0:
+ xyxyn = xyxyn.tolist()
+ x1 = xyxyn[0] * 100
+ y1 = xyxyn[1] * 100
+ x2 = xyxyn[2] * 100
+ y2 = xyxyn[3] * 100
+ width = x2 - x1
+ height = y2 - y1
+ annotation_results.append(
+ {
+ "id": str(uuid.uuid4())[:5],
+ "type": "rectanglelabels",
+ "from_name": "label",
+ "to_name": "image",
+ "original_width": orig_width,
+ "original_height": orig_height,
+ "image_rotation": 0,
+ "value": {
+ "rotation": 0,
+ "x": x1,
+ "y": y1,
+ "width": width,
+ "height": height,
+ "rectanglelabels": ["product"], # Label studio label
+ },
+ },
+ )
+ yield {
+ "model_version": model_version,
+ "result": annotation_results
+ }
+
+
+def pre_annotate_with_yolo(
+ image_paths: Iterable[Path],
+ model_name: str,
+ labels: List[str],
+ batch_size: int,
+ conf: float = 0.1,
+ max_det: int = 1,
+) -> Iterator[Iterable[Results]]:
+ """To fasten the annotation, we leveraged Yolo-World and its capacity to predict object using natural language.
+
+
+
+ https://docs.ultralytics.com/modes/predict/#working-with-results"""
+ model = YOLO(model_name)
+ model.set_classes(labels)
+ # Transform image_paths into batch
+ batches = _batch(image_paths, batch_size=batch_size)
+ for batch in tqdm.tqdm(batches, desc="Yolo-predictions"):
+ results = model.predict(
+ batch,
+ conf=conf,
+ max_det=max_det,
+ )
+ yield results
+
+
+def _batch(iterable: Iterable, batch_size: int) -> Iterator:
+ total = len(iterable)
+ for ndx in range(0, total, batch_size):
+ yield iterable[ndx:min(ndx + batch_size, total)]
diff --git a/ml_utils/ml_utils_cli/cli/apps/datasets.py b/ml_utils/ml_utils_cli/cli/apps/datasets.py
index 34650363..af99a9e9 100644
--- a/ml_utils/ml_utils_cli/cli/apps/datasets.py
+++ b/ml_utils/ml_utils_cli/cli/apps/datasets.py
@@ -1,3 +1,4 @@
+import os
import json
import random
import shutil
@@ -9,6 +10,8 @@
from ..config import LABEL_STUDIO_DEFAULT_URL
from ..types import ExportDestination, ExportSource, TaskType
+from ..annotate import MODEL_NAME, LABELS
+
app = typer.Typer()
@@ -176,9 +179,9 @@ def export(
raise typer.BadParameter("Output directory is required for Ultralytics export")
if from_ == ExportSource.ls:
+ ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
+ category_names_list = category_names.split(",")
if to == ExportDestination.hf:
- ls = LabelStudio(base_url=label_studio_url, api_key=api_key)
- category_names_list = category_names.split(",")
export_to_hf(ls, repo_id, category_names_list, project_id)
elif to == ExportDestination.ultralytics:
export_from_ls_to_ultralytics(
@@ -240,3 +243,29 @@ def create_dataset_file(
image_id, url, image.width, image.height, extra_meta
)
f.write(json.dumps(label_studio_sample) + "\n")
+
+
+@app.command()
+def create_dataset_file_from_yolo(
+ images_dir: Annotated[Path, typer.Option(exists=True)],
+ output_file: Annotated[Path, typer.Option(exists=False)],
+ model_name: str = MODEL_NAME,
+ models_dir: str = "models",
+ labels: list[str] = LABELS,
+ batch_size: int = 20,
+):
+ """Create a Label Studio object detection dataset file from a list of images.
+ Add pre-annotations using YOLO model (such as Yolo-World).
+ """
+ from cli.annotate import format_object_detection_sample_from_yolo
+ model_name = os.path.join(models_dir, model_name)
+ samples = format_object_detection_sample_from_yolo(
+ images_dir=images_dir,
+ model_name=model_name,
+ labels=labels,
+ batch_size=batch_size,
+ )
+ logger.info("Saving samples to %s", output_file)
+ with output_file.open("wt") as f:
+ for sample in samples:
+ f.write(json.dumps(sample) + "\n")
\ No newline at end of file
diff --git a/ml_utils/ml_utils_cli/cli/export.py b/ml_utils/ml_utils_cli/cli/export.py
index 3237c81c..1b7ea6dd 100644
--- a/ml_utils/ml_utils_cli/cli/export.py
+++ b/ml_utils/ml_utils_cli/cli/export.py
@@ -72,12 +72,13 @@ def export_from_ls_to_ultralytics(
data_dir = output_dir / "data"
data_dir.mkdir(parents=True, exist_ok=True)
-
+
+ # NOTE: before, all images were sent to val, the last split
+ label_dir = data_dir / "labels"
+ images_dir = data_dir / "images"
for split in ["train", "val"]:
- split_labels_dir = data_dir / "labels" / split
- split_labels_dir.mkdir(parents=True, exist_ok=True)
- split_images_dir = data_dir / "images" / split
- split_images_dir.mkdir(parents=True, exist_ok=True)
+ (label_dir / split).mkdir(parents=True, exist_ok=True)
+ (images_dir / split).mkdir(parents=True, exist_ok=True)
for task in tqdm.tqdm(
ls.tasks.list(project=project_id, fields="all"),
@@ -92,8 +93,11 @@ def export_from_ls_to_ultralytics(
continue
annotation = task.annotations[0]
- image_id = task.data["image_id"]
+ if annotation["was_cancelled"] is True:
+ logger.debug("Annotation was cancelled, skipping")
+ continue
+ image_id = task.data["image_id"]
image_url = task.data["image_url"]
download_output = download_image(image_url, return_bytes=True)
if download_output is None:
@@ -102,10 +106,10 @@ def export_from_ls_to_ultralytics(
_, image_bytes = download_output
- with (split_images_dir / f"{image_id}.jpg").open("wb") as f:
+ with (images_dir / split / f"{image_id}.jpg").open("wb") as f:
f.write(image_bytes)
- with (split_labels_dir / f"{image_id}.txt").open("w") as f:
+ with (label_dir / split / f"{image_id}.txt").open("w") as f:
for annotation_result in annotation["result"]:
if annotation_result["type"] != "rectanglelabels":
raise ValueError(
diff --git a/ml_utils/ml_utils_cli/config_files/product-detection.xml b/ml_utils/ml_utils_cli/config_files/product-detection.xml
new file mode 100644
index 00000000..edb31a40
--- /dev/null
+++ b/ml_utils/ml_utils_cli/config_files/product-detection.xml
@@ -0,0 +1,6 @@
+
+
+
+
}m}5d?EVK*x|e8oF__~QOylEoSd-zpqEKoA~JNPQs|@354cXV)SPW*;8<4J
zSxbqUkNKe^^X{+X4}Z%PkWEXNZm1 @Qn3$$oUV-+*g41y$cd(F-7
zI*_lC>X$!%&M_4*6wA`CR$$aDeG~@lk^k{7^Bf*;d6Yf;%a$XBuN0=;6QnFjO%Ok
zbv9(Kj4ur2v7inG`3ANThVP3UQ%9l_$JyCF7@(6S5v(eXN^ZGcWK)A)sZ3LIjbSF~
z`U;%$Pea
#OX+Q2U{x12J@$o9gL;d+Kquio
zuk1M$-xoTJpe5U$8(n#9MQWFs!7CGXx6WPXy;G