From 5ff5b683ab17db48d6be85cd01096ff165132fc1 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 8 Nov 2023 13:36:45 -0800 Subject: [PATCH] Continue testing api-only interface --- scripts/start_db_container.sh | 4 ++-- trapdata/api/queries.py | 16 +++++++++---- trapdata/ml/models/localization.py | 38 +++++++++++++++++++++++------- trapdata/ml/pipeline.py | 6 ++--- 4 files changed, 46 insertions(+), 18 deletions(-) diff --git a/scripts/start_db_container.sh b/scripts/start_db_container.sh index 307e5acb..abc8586a 100755 --- a/scripts/start_db_container.sh +++ b/scripts/start_db_container.sh @@ -4,7 +4,7 @@ set -o errexit set -o nounset CONTAINER_NAME=ami-db -HOST_PORT=5432 +HOST_PORT=5433 POSTGRES_VERSION=14 POSTGRES_DB=ami @@ -12,5 +12,5 @@ docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/ docker logs ami-db --tail 100 -echo 'Database started, Connection string: "postgresql://postgres@localhost:5432/ami"' +echo "Database started, Connection string: \"postgresql://postgres@localhost:${HOST_PORT}/${POSTGRES_DB}\"" echo "Stop (and destroy) database with 'docker stop $CONTAINER_NAME' && docker remove $CONTAINER_NAME" diff --git a/trapdata/api/queries.py b/trapdata/api/queries.py index 5c3cc79b..58880ed2 100644 --- a/trapdata/api/queries.py +++ b/trapdata/api/queries.py @@ -6,11 +6,16 @@ logger = logging.getLogger(__name__) +TEMPORARY_DEPLOYMENT_ID = 9 +TEMPORARY_EVENT_ID = 34 +TEMPORARY_COLLECTION_ID = 4 + def save_detected_objects( source_image_ids: list[int], detected_objects_data: list[dict], *args, **kwargs ): logger.info(f"Saving {len(source_image_ids)} detected objects via API") + print(f"Saving {len(source_image_ids)} detected objects via API") responses = {} path = "detections/" for source_image_id, detected_objects in zip( @@ -38,9 +43,11 @@ def get_next_source_images(num: int, *args, **kwargs) -> list[IncomingSourceImag path = "captures/" args = { "limit": num, - "deployment": 9, - "event": 34, + # "deployment": 
TEMPORARY_DEPLOYMENT_ID, + # "event": TEMPORARY_EVENT_ID, + "collections": TEMPORARY_COLLECTION_ID, "has_detections": False, + "order": "?", } # last_processed__isnull=True url = settings.api_base_url + path resp = get_session().get(url, params=args) @@ -53,8 +60,9 @@ def get_source_image_count(*args, **kwargs) -> int: path = "captures/" args = { - "deployment": 9, - "event": 34, + # "deployment": TEMPORARY_DEPLOYMENT_ID, + # "event": TEMPORARY_EVENT_ID, + "collections": TEMPORARY_COLLECTION_ID, "limit": 1, "has_detections": False, } diff --git a/trapdata/ml/models/localization.py b/trapdata/ml/models/localization.py index 2101237a..fa4e3419 100644 --- a/trapdata/ml/models/localization.py +++ b/trapdata/ml/models/localization.py @@ -7,11 +7,14 @@ import torchvision.models.detection.backbone_utils import torchvision.models.detection.faster_rcnn import torchvision.models.mobilenetv3 +from PIL import ImageFile from trapdata import TrapImage, db, logger from trapdata.ml.models.base import InferenceBaseClass from trapdata.ml.utils import get_or_download_file +ImageFile.LOAD_TRUNCATED_IMAGES = True + class LocalizationIterableDatabaseDataset(torch.utils.data.IterableDataset): def __init__(self, queue, image_transforms, batch_size=1): @@ -34,20 +37,29 @@ def __iter__(self): records = get_next_source_images(self.batch_size) logger.debug(f"Pulling records: {records}") if records: - item_ids = torch.utils.data.default_collate( - [record.id for record in records] - ) - batch_data = torch.utils.data.default_collate( - [self.transform(record.url) for record in records] - ) + items = [(record.id, self.transform(record.url)) for record in records] + items = [(item_id, item) for item_id, item in items if item is not None] + item_ids, item_objs = zip(*items) + logger.info(f"Processing items: {item_ids}") + item_ids = torch.utils.data.default_collate(item_ids) + batch_data = 
torch.utils.data.default_collate(item_objs) yield (item_ids, batch_data) def transform(self, url): url = url + "?width=5000&redirect=False" logger.info(f"Fetching and transforming: {url}") img_path = get_or_download_file(url, destination_dir="/tmp/today/") - return self.image_transforms(PIL.Image.open(img_path)) + try: + return self.image_transforms(PIL.Image.open(img_path)) + except PIL.UnidentifiedImageError: + logger.error(f"Unidentified image: {img_path}") + print(f"Unidentified image: {img_path}") + return None + except OSError: + logger.error(f"OSError: {img_path}") + print(f"OSError: {img_path}") + return None class LocalizationDatabaseDataset(torch.utils.data.Dataset): @@ -86,8 +98,16 @@ def __getitem__(self, idx): sesh.commit() img_path = img_path - pil_image = PIL.Image.open(img_path) - item = (item_id, self.transform(pil_image)) + try: + pil_image = PIL.Image.open(img_path) + except FileNotFoundError: + logger.error(f"File not found: {img_path}") + item = (item_id, None) + except PIL.UnidentifiedImageError: + logger.error(f"Unidentified image: {img_path}") + item = (item_id, None) + else: + item = (item_id, self.transform(pil_image)) return item diff --git a/trapdata/ml/pipeline.py b/trapdata/ml/pipeline.py index 9b5bf600..53556358 100644 --- a/trapdata/ml/pipeline.py +++ b/trapdata/ml/pipeline.py @@ -23,9 +23,9 @@ def start_pipeline( num_workers=settings.num_workers, single=single, ) - if object_detector.queue.queue_count() > 0: - object_detector.run() - logger.info("Localization complete") + # if object_detector.queue.queue_count() > 0: + object_detector.run() + logger.info("Localization complete") BinaryClassifier = ml.models.binary_classifiers[ settings.binary_classification_model.value