diff --git a/cocomltools/coco_ops.py b/cocomltools/coco_ops.py index d18556a..1bcdb60 100644 --- a/cocomltools/coco_ops.py +++ b/cocomltools/coco_ops.py @@ -5,6 +5,11 @@ from collections import defaultdict from PIL import Image from pathlib import Path +import numpy as np +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed +import asyncio +import time class CocoOps: @@ -83,24 +88,47 @@ def _crop_and_save_one_ann( x1, y1, w, h = ann.bbox x2, y2 = x1 + w, y1 + h - category_name = self.cat_ids_to_names[ann.category_id] + category_name = self.coco.cat_ids_to_names[ann.category_id] category_dir = Path(output_dir) / category_name category_dir.mkdir(exist_ok=True, parents=True) crop_out_file = category_dir / f"{ann.id}.jpg" crop = image.crop((x1, y1, x2, y2)) crop.save(crop_out_file) + def _crop_one_image(self, elem: Image, images_dir: Path, output_dir: Path): + file_image = Path(images_dir) / elem.file_name + annotations = self.coco.get_annotation_by_image_id(elem.id) + if len(annotations) == 0: # if no annotations, skip + return + image = Image.open(file_image).convert("RGB") + for ann in annotations: + self._crop_and_save_one_ann(image, ann, output_dir) + def crop( self, - images_dir: str, - output_dir: str, + images_dir: Path, + output_dir: Path, + max_workers: int = 1, ): + if isinstance(images_dir, str): + images_dir = Path(images_dir) + + if isinstance(output_dir, str): + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + start_time = time.time() + with ThreadPoolExecutor(max_workers=max_workers) as executor: - for elem in self.images: - file_image = Path(images_dir) / elem.file_name - annotations = self.coco.get_annotation_by_image_id(elem.id) - if len(annotations) == 0: # if no annotations, skip - continue - image = Image.open(file_image).convert("RGB") - for ann in annotations: - self._crop_and_save_one_ann(image, ann, output_dir) + futures = [ + executor.submit(self._crop_one_image, elem, images_dir, output_dir) + for elem in self.coco.images + ] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Error processing annotation: {e}") + elapsed_time = time.time() - start_time + print(f"Completed cropping dataset in {elapsed_time:.2f} seconds") diff --git a/cocomltools/models/base.py b/cocomltools/models/base.py index 78e630d..9d4c8c3 100644 --- a/cocomltools/models/base.py +++ b/cocomltools/models/base.py @@ -16,7 +16,7 @@ class Annotation(BaseModel): score: float = Field(default=1.0) bbox: List[float] segmentation: List[float] | List[List[float]] = Field(default=[]) - area: int + area: float iscrowd: int = Field(default=0) diff --git a/main_cli.py b/main_cli.py index 6276d8c..17841a7 100644 --- a/main_cli.py +++ b/main_cli.py @@ -54,7 +54,7 @@ def crop_cmd(args): Path(args.output_dir) if args.output_dir else images_dir.parent / "cropped" ) output_dir.mkdir(exist_ok=True, parents=True) - coco_ops.crop(images_dir, output_dir) + coco_ops.crop(images_dir, output_dir, max_workers=args.num_workers) def parse_args(): @@ -96,6 +96,13 @@ def parse_args(): "--images-dir", required=True, help="Path to coco image files" ) parser_crop.add_argument("--output-dir", required=False, help="Path to output dir") + parser_crop.add_argument( + "--num-workers", + required=False, + default=1, + type=int, + help="number of workers to crop the dataset", + ) args = parser.parse_args() return args