diff --git a/ami/exports/management/commands/import_pipeline_results.py b/ami/exports/management/commands/import_pipeline_results.py
new file mode 100644
index 000000000..04440c8a1
--- /dev/null
+++ b/ami/exports/management/commands/import_pipeline_results.py
@@ -0,0 +1,286 @@
+import datetime
+import json
+
+import pydantic
+from dateutil.parser import parse as parse_date
+from django.core.management.base import BaseCommand
+
+from ami.main.models import Classification, Deployment, Detection, Event, Occurrence, Project, SourceImage, Taxon
+from ami.ml.models import Algorithm
+
+
+class IncomingDetection(pydantic.BaseModel):
+    id: int
+    source_image_id: int
+    source_image_path: str
+    source_image_width: int
+    source_image_height: int
+    source_image_filesize: int
+    label: str
+    score: float
+    cropped_image_path: str | None = None
+    sequence_id: str | None = None  # This is the Occurrence ID on the ADC side (= detections in a sequence)
+    timestamp: datetime.datetime
+    detection_algorithm: str | None = None  # Name of the object detection algorithm used
+    classification_algorithm: str | None = None  # Classification algorithm used to generate the label & score
+    bbox: list[int]  # Bounding box in the format [x_min, y_min, x_max, y_max]
+
+
+class IncomingOccurrence(pydantic.BaseModel):
+    id: str
+    label: str
+    best_score: float
+    start_time: datetime.datetime
+    end_time: datetime.datetime
+    duration: datetime.timedelta
+    deployment: str
+    event: str
+    num_frames: int
+    # cropped_image_path: pathlib.Path
+    # source_image_id: int
+    examples: list[
+        IncomingDetection
+    ]  # These are the individual detections with source image data, bounding boxes and predictions
+    example_crop: str | None = None
+    # detections: list[object]
+    # deployment: object
+    # captures: list[object]
+
+
+class Command(BaseCommand):
+    r"""Import trap data from a JSON file exported from the AMI data companion.
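+
+    Example invocation (the file path and project ID are placeholders):
+
+        python manage.py import_pipeline_results occurrences.json <project_id>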
+ + occurrences.json + + # CURRENT EXAMPLE JSON STRUCTURE: + { + "id":"SEQ-91", + "label":"Azochis rufidiscalis", + "best_score":0.4857344627, + "start_time":"2023-01-25T03:49:59.000", + "end_time":"2023-01-25T03:49:59.000", + "duration":"P0DT0H0M0S", + "deployment":"snapshots", + "event":"2023-01-24", + "num_frames":1, + "examples":[ + { + "id":91, + "source_image_id":402, + "source_image_path":"2023_01_24\/257-20230125034959-snapshot.jpg", + "source_image_width":4096, + "source_image_height":2160, + "source_image_filesize":1276685, + "label":"Azochis rufidiscalis", + "score":0.4857344627, + "cropped_image_path":"\/media\/michael\/ZWEIBEL\/ami-ml-data\/trapdata\/crops\/820709c454b529d5cf44e59fea1f4b5b.jpg", + "sequence_id":"20230124-SEQ-91", + "timestamp":"2023-01-25T03:49:59.000", + "bbox":[ + 2191, + 413, + 2568, + 638 + ] + } + ], + "example_crop":null + }, + { + "id":"SEQ-86", + "label":"Sphinx canadensis", + "best_score":0.4561957121, + "start_time":"2023-01-24T20:11:59.000", + "end_time":"2023-01-24T20:11:59.000", + "duration":"P0DT0H0M0S", + "deployment":"snapshots", + "event":"2023-01-24", + "num_frames":1, + "examples":[ + { + "id":86, + "source_image_id":88, + "source_image_path":"2023_01_24\/55-20230124201159-snapshot.jpg", + "source_image_width":4096, + "source_image_height":2160, + "source_image_filesize":1013757, + "label":"Sphinx canadensis", + "score":0.4561957121, + "cropped_image_path":"\/media\/michael\/ZWEIBEL\/ami-ml-data\/trapdata\/crops\/839fd6565461939ef946751b87003eda.jpg", + "sequence_id":"20230124-SEQ-86", + "timestamp":"2023-01-24T20:11:59.000", + "bbox":[ + 1629, + 0, + 1731, + 25 + ] + } + ], + "example_crop":null + }, + """ + + help = "Import trap data from AMI data manager occurrences.json file" + + def add_arguments(self, parser): + parser.add_argument("occurrences", type=str) + parser.add_argument("project_id", type=str, help="Project to import to") + + def handle(self, *args, **options): + occurrences = json.load(open(options["occurrences"])) + project_id = options["project_id"] + + project = Project.objects.get(pk=project_id) + self.stdout.write(self.style.SUCCESS('Importing to project "%s"' % project.name)) + + """ + -) Collect all Deployments that need to be created or fetched + -) Collect all SourceImages that need to be created or fetched + -) Collect all Occurrences that need to be created or fetched + -) Create Deployments, linking them to the correct Project + -) Create SourceImages, linking them to the correct Occurrence and Deployment + -) Create Occurrences, linking them to the correct Deployment and Project + -) Generate events (save deployments to trigger event generation) + -) Create Detections, linking them to the correct Occurrence and SourceImage + -) Create Classifications, linking them to the correct Detection and Taxon + -) commit transaction, if transaction is possible + """ + + # Create a fallback algorithm for detections missing algorithm info + default_classification_algorithm, created = Algorithm.objects.get_or_create( + name="Unknown classifier from ADC import", + task_type="classification", + defaults={ + "description": "Unknown classification model imported from AMI data companion occurrences.json", + "version": 0, + }, + ) + if created: + self.stdout.write( + self.style.SUCCESS('Created fallback algorithm "%s"' % default_classification_algorithm.name) + ) + default_detection_algorithm, created = Algorithm.objects.get_or_create( + name="Unknown object detector from ADC import", + task_type="localization", + defaults={ + 
"description": "Unknown object detection model imported from AMI data companion occurrences.json", + "version": 0, + }, + ) + + # Process each occurrence from the JSON file + for occurrence_data in occurrences: + # Get or create deployment + deployment, created = Deployment.objects.get_or_create( + name=occurrence_data["deployment"], + project=project, + ) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created deployment "%s"' % deployment)) + + # Get or create taxon for the occurrence + best_taxon, created = Taxon.objects.get_or_create(name=occurrence_data["label"]) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created taxon "%s"' % best_taxon)) + + # Create occurrence + occurrence = Occurrence.objects.create( + event=None, # will be assigned when events are grouped + deployment=deployment, + project=project, + determination=best_taxon, + determination_score=occurrence_data["best_score"], + ) + self.stdout.write(self.style.SUCCESS('Successfully created occurrence "%s"' % occurrence)) + + # Process each detection example in the occurrence + for example in occurrence_data["examples"]: + try: + # Create or get source image + image, created = SourceImage.objects.get_or_create( + path=example["source_image_path"], + deployment=deployment, + defaults={ + "timestamp": parse_date(example["timestamp"]), + "event": None, # will be assigned when events are calculated + "project": project, + "width": example["source_image_width"], + "height": example["source_image_height"], + "size": example["source_image_filesize"], + }, + ) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created image "%s"' % image)) + + except KeyError as e: + self.stdout.write(self.style.ERROR('Error creating image - missing field: "%s"' % e)) + continue + + # Create detection + detection, created = Detection.objects.get_or_create( + occurrence=occurrence, + source_image=image, + bbox=example["bbox"], + defaults={ + "path": example.get("cropped_image_path"), + "timestamp": parse_date(example["timestamp"]), + }, + ) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created detection "%s"' % detection)) + + # Get or create taxon for this specific detection + detection_taxon, created = Taxon.objects.get_or_create(name=example["label"]) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created taxon "%s"' % detection_taxon)) + + # Determine which algorithm to use + algorithm_to_use = default_classification_algorithm + if example.get("classification_algorithm"): + # Try to find an algorithm with this name + try: + algorithm_to_use = Algorithm.objects.get(name=example["classification_algorithm"]) + except Algorithm.DoesNotExist: + # Create new algorithm if it doesn't exist + algorithm_to_use, created = Algorithm.objects.get_or_create( + name=example["classification_algorithm"], + task_type="classification", + defaults={ + "description": "Algorithm imported from AMI data companion: " + f"{example['classification_algorithm']}", + "version": 0, + }, + ) + if created: + self.stdout.write(self.style.SUCCESS('Created algorithm "%s"' % algorithm_to_use.name)) + + # Create classification + classification, created = Classification.objects.get_or_create( + detection=detection, + algorithm=algorithm_to_use, + taxon=detection_taxon, + defaults={ + "score": example["score"], + "timestamp": parse_date(example["timestamp"]), + "terminal": True, + }, + ) + if created: + self.stdout.write(self.style.SUCCESS('Successfully created classification "%s"' % 
classification)) + + # Regroup images into events for all deployments that were modified + self.stdout.write(self.style.SUCCESS("Regrouping images into events...")) + deployments_to_update = Deployment.objects.filter(project=project) + for deployment in deployments_to_update: + deployment.save(regroup_async=False) + self.stdout.write(self.style.SUCCESS('Updated events for deployment "%s"' % deployment)) + + # Update event timestamps + events_updated = 0 + for event in Event.objects.filter(project=project): + event.save() + events_updated += 1 + + self.stdout.write(self.style.SUCCESS("Updated %d events" % events_updated)) + self.stdout.write(self.style.SUCCESS("Import completed successfully!")) diff --git a/ami/exports/tests/__init__.py b/ami/exports/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/exports/tests/fixtures/api_occurrences_1754596036_batch_001.json b/ami/exports/tests/fixtures/api_occurrences_1754596036_batch_001.json new file mode 100644 index 000000000..43f1d4dc5 --- /dev/null +++ b/ami/exports/tests/fixtures/api_occurrences_1754596036_batch_001.json @@ -0,0 +1,131 @@ +{ + "pipeline": "panama_moths_2023", + "algorithms": { + "fasterrcnn_mobilenet_for_ami_moth_traps_2023": { + "name": "FasterRCNN - MobileNet for AMI Moth Traps 2023", + "key": "fasterrcnn_mobilenet_for_ami_moth_traps_2023", + "description": "Model trained on GBIF images and synthetic data in 2023. Slightly less accurate but much faster than other models.", + "task_type": "localization", + "version": 1, + "version_name": null, + "uri": null, + "category_map": null + }, + "moth_nonmoth_classifier": { + "name": "Moth / Non-Moth Classifier", + "key": "moth_nonmoth_classifier", + "description": "Trained on April 17, 2024", + "task_type": "classification", + "version": 1, + "version_name": null, + "uri": null, + "category_map": null + }, + "global_species_classifier_aug_2024": { + "name": "Global Species Classifier - Aug 2024", + "key": "global_species_classifier_aug_2024", + "description": "Trained on August 28th, 2024 for 29,176 species. 
https://wandb.ai/moth-ai/global-moth-classifier/runs/h0cuqrbc/overview", + "task_type": "classification", + "version": 1, + "version_name": null, + "uri": null, + "category_map": null + } + }, + "total_time": 0.0, + "source_images": [ + { + "id": "680", + "url": "2023_01_24/329-20230125060859-snapshot.jpg", + "deployment": { + "name": "snapshots", + "key": "snapshots" + } + }, + { + "id": "657", + "url": "2023_01_24/326-20230125055729-snapshot.jpg", + "deployment": { + "name": "snapshots", + "key": "snapshots" + } + } + ], + "detections": [ + { + "source_image_id": "680", + "bbox": { + "x1": 1057.0, + "y1": 1905.0, + "x2": 1287.0, + "y2": 2142.0 + }, + "inference_time": null, + "algorithm": { + "name": "FasterRCNN - MobileNet for AMI Moth Traps 2023", + "key": "fasterrcnn_mobilenet_for_ami_moth_traps_2023" + }, + "timestamp": "2023-01-25 06:08:59", + "crop_image_url": "/media/michael/ZWEIBEL/ami-ml-data/trapdata/crops/b4a2fd7e99f3995cac296f28b7a240a7.jpg", + "classifications": [ + { + "classification": "Megalopyge tharops", + "labels": null, + "scores": [ + 0.5804115533828735 + ], + "logits": [], + "inference_time": null, + "algorithm": { + "name": "Global Species Classifier - Aug 2024", + "key": "global_species_classifier_aug_2024" + }, + "terminal": true, + "timestamp": "2023-01-25 06:08:59" + } + ] + }, + { + "source_image_id": "657", + "bbox": { + "x1": 2915.0, + "y1": 1790.0, + "x2": 3130.0, + "y2": 2062.0 + }, + "inference_time": null, + "algorithm": { + "name": "FasterRCNN - MobileNet for AMI Moth Traps 2023", + "key": "fasterrcnn_mobilenet_for_ami_moth_traps_2023" + }, + "timestamp": "2023-01-25 05:57:29", + "crop_image_url": "/media/michael/ZWEIBEL/ami-ml-data/trapdata/crops/80af0289d7def3ab1278d66842b9c42a.jpg", + "classifications": [ + { + "classification": "Megalopyge tharops", + "labels": null, + "scores": [ + 0.7600370049476624 + ], + "logits": [], + "inference_time": null, + "algorithm": { + "name": "Global Species Classifier - Aug 2024", + "key": "global_species_classifier_aug_2024" + }, + "terminal": true, + "timestamp": "2023-01-25 05:57:29" + } + ] + } + ], + "deployments": [ + { + "name": "snapshots", + "key": "snapshots" + } + ], + "config": { + "example_config_param": null + } +} diff --git a/ami/exports/tests/fixtures/pipeline_response-complete_local_project.json b/ami/exports/tests/fixtures/pipeline_response-complete_local_project.json new file mode 100644 index 000000000..3a5176bc1 --- /dev/null +++ b/ami/exports/tests/fixtures/pipeline_response-complete_local_project.json @@ -0,0 +1,121 @@ +{ + "pipeline": "random-detection-random-species", + "algorithms": { + "random-detector": { + "name": "Random Detector", + "key": "random-detector", + "description": "", + "task_type": "localization", + "version": 1, + "version_name": null, + "uri": null, + "category_map": null + }, + "random-species-classifier": { + "name": "Random species classifier", + "key": "random-species-classifier", + "description": "", + "task_type": "classification", + "version": 1, + "version_name": null, + "uri": null, + "category_map": null + } + }, + "total_time": 0.0, + "source_images": [ + { + "id": "680", + "url": "2023_01_24/329-20230125060859-snapshot.jpg", + "deployment": { + "name": "Mothra-01", + "key": "mothra-01" + } + }, + { + "id": "657", + "url": "2023_01_24/326-20230125055729-snapshot.jpg", + "deployment": { + "name": "Mothra-01", + "key": "mothra-01" + } + } + ], + "detections": [ + { + "source_image_id": "680", + "bbox": { + "x1": 1057.0, + "y1": 1905.0, + "x2": 1287.0, + "y2": 
2142.0 + }, + "inference_time": null, + "algorithm": { + "name": "Random Detector", + "key": "random-detector" + }, + "timestamp": "2023-01-25 06:08:59", + "crop_image_url": "/media/michael/ZWEIBEL/ami-ml-data/trapdata/crops/b4a2fd7e99f3995cac296f28b7a240a7.jpg", + "classifications": [ + { + "classification": "Megalopyge tharops", + "labels": null, + "scores": [ + 0.5804115533828735 + ], + "logits": [], + "inference_time": null, + "algorithm": { + "name": "Random species classifier", + "key": "random-species-classifier" + }, + "terminal": true, + "timestamp": "2023-01-25 06:08:59" + } + ] + }, + { + "source_image_id": "657", + "bbox": { + "x1": 2915.0, + "y1": 1790.0, + "x2": 3130.0, + "y2": 2062.0 + }, + "inference_time": null, + "algorithm": { + "name": "Random Detector", + "key": "random-detector" + }, + "timestamp": "2023-01-25 05:57:29", + "crop_image_url": "/media/michael/ZWEIBEL/ami-ml-data/trapdata/crops/80af0289d7def3ab1278d66842b9c42a.jpg", + "classifications": [ + { + "classification": "Megalopyge tharops", + "labels": null, + "scores": [ + 0.7600370049476624 + ], + "logits": [], + "inference_time": null, + "algorithm": { + "name": "Random species classifier", + "key": "random-species-classifier" + }, + "terminal": true, + "timestamp": "2023-01-25 05:57:29" + } + ] + } + ], + "deployments": [ + { + "name": "Mothra-01", + "key": "mothra-01" + } + ], + "config": { + "example_config_param": null + } +} diff --git a/ami/exports/tests.py b/ami/exports/tests/test_exports.py similarity index 100% rename from ami/exports/tests.py rename to ami/exports/tests/test_exports.py diff --git a/ami/exports/tests/test_imports.py b/ami/exports/tests/test_imports.py new file mode 100644 index 000000000..ff9552a41 --- /dev/null +++ b/ami/exports/tests/test_imports.py @@ -0,0 +1,33 @@ +import logging + +from django.test import TestCase + +logger = logging.getLogger(__name__) + + +class DataImportTests(TestCase): + """ + Test importing from saved PipelineResponse json files. + + Uses fixtures in `ami/exports/tests/fixtures/*.json`. + """ + + def setUp(self): + from ami.main.models import Project, get_or_create_default_device, get_or_create_default_research_site + + # Create an empty project + self.project = Project.objects.create(name="Imported Project", create_defaults=False) + + # For some reason test _tear down_ fails if these don't exist! + # Even if the tests pass. + get_or_create_default_device(self.project) + get_or_create_default_research_site(self.project) + + def test_import_pipeline_results(self): + """Test importing from API JSON.""" + + results_fpath = "ami/exports/tests/fixtures/pipeline_response-complete_local_project.json" + + from django.core.management import call_command + + call_command("import_pipeline_results", results_fpath, project=self.project.pk) diff --git a/ami/main/management/commands/import_trapdata_project.py b/ami/main/management/commands/import_trapdata_project.py deleted file mode 100644 index fa43829a1..000000000 --- a/ami/main/management/commands/import_trapdata_project.py +++ /dev/null @@ -1,140 +0,0 @@ -import datetime -import json - -from dateutil.parser import parse as parse_date -from django.core.management.base import BaseCommand, CommandError # noqa - -from ...models import Algorithm, Classification, Deployment, Detection, Event, Occurrence, Project, SourceImage, Taxon - - -class Command(BaseCommand): - r"""Import trap data from a JSON file exported from the AMI data companion. 
- - occurrences.json - [ - { - "id":"20220620-SEQ-207259", - "label":"Baileya ophthalmica", - "best_score":0.6794486046, - "start_time":"2022-06-21T09:23:00.000Z", - "end_time":"2022-06-21T09:23:00.000Z", - "duration":"P0DT0H0M0S", - "deployment":"Vermont-Snapshots-Sample", - "event":{ - "id":19, - "day":"2022-06-20T00:00:00.000", - "url":null - }, - "num_frames":1, - "examples":[ - { - "id":207259, - "source_image_id":15050, - "source_image_path":"2022_06_21_snapshots\/20220621052300-301-snapshot.jpg", - "source_image_width":4096, - "source_image_height":2160, - "source_image_filesize":1599836, - "label":"Baileya ophthalmica", - "score":0.6794486046, - "cropped_image_path":"exports\/occurrences_images\/20220620-SEQ-207259-963edb524a59504392d4bec06717857a.jpg", - "sequence_id":"20220620-SEQ-207259", - "timestamp":"2022-06-21T09:23:00.000Z", - "bbox":[ - 3598, - 1074, - 3821, - 1329 - ] - } - ], - "url":null - }, - ] - """ - - help = "Import trap data from AMI data manager occurrences.json file" - - def add_arguments(self, parser): - parser.add_argument("occurrences", type=str) - - def handle(self, *args, **options): - occurrences = json.load(open(options["occurrences"])) - - project, created = Project.objects.get_or_create(name="Default Project") - if created: - self.stdout.write(self.style.SUCCESS('Successfully created project "%s"' % project)) - algorithm, created = Algorithm.objects.get_or_create(name="Latest Model", version="1.0") - for occurrence in occurrences: - deployment, created = Deployment.objects.get_or_create( - name=occurrence["deployment"], - project=project, - ) - if created: - self.stdout.write(self.style.SUCCESS('Successfully created deployment "%s"' % deployment)) - - event, created = Event.objects.get_or_create( - start=parse_date(occurrence["event"]["day"]), - deployment=deployment, - ) - if created: - self.stdout.write(self.style.SUCCESS('Successfully created event "%s"' % event)) - - best_taxon, created = Taxon.objects.get_or_create(name=occurrence["label"]) - occ = Occurrence.objects.create( - event=event, - deployment=deployment, - project=project, - determination=best_taxon, - ) - self.stdout.write(self.style.SUCCESS('Successfully created occurrence "%s"' % occ)) - - for example in occurrence["examples"]: - try: - image, created = SourceImage.objects.get_or_create( - path=example["source_image_path"], - timestamp=parse_date(example["timestamp"]), - event=event, - deployment=deployment, - width=example["source_image_width"], - height=example["source_image_height"], - size=example["source_image_filesize"], - ) - if created: - self.stdout.write(self.style.SUCCESS('Successfully created image "%s"' % image)) - except KeyError as e: - self.stdout.write(self.style.ERROR('Error creating image "%s"' % e)) - image = None - - if image: - detection, created = Detection.objects.get_or_create( - occurrence=occ, - source_image=image, - timestamp=parse_date(example["timestamp"]), - path=example["cropped_image_path"], - bbox=example["bbox"], - ) - if created: - self.stdout.write(self.style.SUCCESS('Successfully created detection "%s"' % detection)) - else: - detection = None - - taxon, created = Taxon.objects.get_or_create(name=example["label"]) - - if detection: - one_day_later = datetime.timedelta(seconds=60 * 60 * 24) - classification, created = Classification.objects.get_or_create( - score=example["score"], - determination=taxon, - detection=detection, - type="machine", - algorithm=algorithm, - timestamp=parse_date(example["timestamp"]) + one_day_later, - ) - if created: - 
self.stdout.write( - self.style.SUCCESS('Successfully created classification "%s"' % classification) - ) - - # Update event start and end times based on the first and last detections - for event in Event.objects.all(): - event.save() diff --git a/ami/main/models.py b/ami/main/models.py index 88da476ae..f44f9b038 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -111,16 +111,32 @@ def get_or_create_default_research_site(project: "Project") -> "Site": def get_or_create_default_deployment( - project: "Project", site: "Site | None" = None, device: "Device | None" = None + project: "Project", + site: "Site | None" = None, + device: "Device | None" = None, + name: str = "Default Deployment", ) -> "Deployment": - """Create a default deployment for a project.""" - deployment, _created = Deployment.objects.get_or_create( - name="Default Station", - project=project, - research_site=site, - device=device, + """ + Create a default deployment for a project. + + @TODO Require that the deployment name is unique per project. + """ + deployment = ( + Deployment.objects.filter( + project=project, + name=name, + ) + .order_by("-created_at") + .first() ) - logger.info(f"Created default deployment for project {project}") + if not deployment: + deployment = Deployment.objects.create( + name=name, + project=project, + research_site=site, + device=device, + ) + logger.info(f"Created default deployment for project {project}") return deployment @@ -1166,12 +1182,25 @@ def group_images_into_events( defaults={"start": start_date, "end": end_date}, ) events.append(event) - SourceImage.objects.filter(deployment=deployment, timestamp__in=group).update(event=event) + source_images = SourceImage.objects.filter(deployment=deployment, timestamp__in=group) + source_images.update(event=event) + event.save() # Update start and end times and other cached fields logger.info( f"Created/updated event {event} with {len(group)} images for deployment {deployment}. " f"Duration: {event.duration_label()}" ) + # Update occurrences to point to the new event + occurrences_updated = ( + Occurrence.objects.filter( + detections__source_image__in=source_images, + ) + .exclude(event=event) + .update(event=event) + ) + logger.info( + f"Updated {occurrences_updated} occurrences to point to event {event} for deployment {deployment}." + ) logger.info( f"Done grouping {len(image_timestamps)} captures into {len(events)} events " f"for deployment {deployment}" @@ -1186,9 +1215,21 @@ def group_images_into_events( logger.info(f"Setting image dimensions for event {event}") set_dimensions_for_collection(event) + # Warn if any occurrences belonging to the deployment are not assigned to an event + logger.info("Checking for ungrouped occurrences in deployment") + ungrouped_occurrences = Occurrence.objects.filter( + deployment=deployment, + event__isnull=True, + ) + if ungrouped_occurrences.exists(): + logger.warning( + f"Found {ungrouped_occurrences.count()} occurrences in deployment {deployment} " + "that are not assigned to any event. " + "This may indicate that some images were not grouped correctly." 
+ ) + logger.info("Updating relevant cached fields on deployment") - deployment.events_count = len(events) - deployment.save(update_calculated_fields=False, update_fields=["events_count"]) + deployment.update_calculated_fields(save=True) audit_event_lengths(deployment) diff --git a/ami/ml/management/commands/import_pipeline_results.py b/ami/ml/management/commands/import_pipeline_results.py new file mode 100644 index 000000000..1d8946ee9 --- /dev/null +++ b/ami/ml/management/commands/import_pipeline_results.py @@ -0,0 +1,120 @@ +import json +from pathlib import Path + +from django.core.management.base import BaseCommand, CommandError +from django.db import transaction + +from ami.main.models import Project +from ami.ml.models.pipeline import save_results +from ami.ml.schemas import PipelineResultsResponse + + +class Command(BaseCommand): + help = "Import pipeline results from a JSON file into the database" + + def add_arguments(self, parser): + parser.add_argument("json_file", type=str, help="Path to JSON file containing PipelineResultsResponse data") + parser.add_argument("--project", type=int, required=True, help="Project ID to import the data into") + parser.add_argument("--dry-run", action="store_true", help="Validate the data without saving to database") + parser.add_argument( + "--public-base-url", + type=str, + help="Base URL for images if paths are relative (e.g., http://0.0.0.0:7070/)", + ) + + def handle(self, *args, **options): + json_file_path = Path(options["json_file"]) + project_id = options["project"] + dry_run = options.get("dry_run", False) + public_base_url = options.get("public_base_url") + + # Validate that the JSON file exists + if not json_file_path.exists(): + raise CommandError(f"JSON file does not exist: {json_file_path}") + + # Validate that the project exists + try: + project = Project.objects.get(pk=project_id) + except Project.DoesNotExist: + raise CommandError(f"Project with ID {project_id} does not exist") + + self.stdout.write(f"Reading JSON file: {json_file_path}") + + # Read and parse the JSON file + try: + with open(json_file_path, encoding="utf-8") as f: + json_data = json.load(f) + except json.JSONDecodeError as e: + raise CommandError(f"Invalid JSON in file {json_file_path}: {e}") + except Exception as e: + raise CommandError(f"Error reading file {json_file_path}: {e}") + + # Validate the JSON data against the PipelineResultsResponse schema + try: + pipeline_results = PipelineResultsResponse(**json_data) + except Exception as e: + raise CommandError(f"Invalid PipelineResultsResponse data: {e}") + + self.stdout.write( + self.style.SUCCESS( + f"Successfully validated PipelineResultsResponse with:" + f"\n - Pipeline: {pipeline_results.pipeline}" + f"\n - Source images: {len(pipeline_results.source_images)}" + f"\n - Detections: {len(pipeline_results.detections)}" + f"\n - Algorithms: {len(pipeline_results.algorithms)}" + ) + ) + + if dry_run: + self.stdout.write(self.style.WARNING("Dry run mode - no data will be saved to database")) + return + + # Import the data using save_results function + self.stdout.write(f"Importing data into project: {project} (ID: {project_id})") + + try: + # Call the save_results function with create_missing_source_images=True + results_json = pipeline_results.json() + with transaction.atomic(): + result = save_results( + results_json=results_json, + job_id=None, + return_created=True, + create_missing_source_images=True, + project_id=project_id, + public_base_url=public_base_url, + ) + + if result: + self.stdout.write( + 
self.style.SUCCESS( + f"Successfully imported pipeline results:" + f"\n - Pipeline: {result.pipeline}" + f"\n - Source images processed: {len(result.source_images)}" + f"\n - Detections created: {len(result.detections)}" + f"\n - Classifications created: {len(result.classifications)}" + f"\n - Algorithms used: {len(result.algorithms)}" + f"\n - Deployments used: {len(result.deployments)}" + f"\n - Total processing time: {result.total_time:.2f} seconds" + ) + ) + + # Re-save all deployments in the results to ensure they are up-to-date + # Must loop through the source images + self.stdout.write(self.style.SUCCESS("Updating sessions and stations")) + deployments = { + source_image.deployment for source_image in result.source_images if source_image.deployment + } + for deployment in deployments: + deployment.save(regroup_async=False) + else: + self.stdout.write(self.style.WARNING("Import completed but no result object returned")) + + except Exception as e: + raise CommandError(f"Error importing pipeline results: {e}") + + self.stdout.write( + self.style.SUCCESS( + f"Pipeline results successfully imported into project '{project.name}' (ID: {project_id})" + ) + ) diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index ed014f779..d2e9d6b0c 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -41,6 +41,7 @@ AlgorithmConfigResponse, AlgorithmReference, ClassificationResponse, + DeploymentResponse, DetectionRequest, DetectionResponse, PipelineRequest, @@ -345,11 +346,12 @@ def get_or_create_algorithm_and_category_map( " Will attempt to create one from the classification results." ) + # @TODO update the unique constraint to use key & version instead of name & version algo, _created = Algorithm.objects.get_or_create( - key=algorithm_config.key, + name=algorithm_config.name, version=algorithm_config.version, defaults={ - "name": algorithm_config.name, + "key": algorithm_config.key, "task_type": algorithm_config.task_type, "version_name": algorithm_config.version_name, "uri": algorithm_config.uri, @@ -415,7 +417,7 @@ def get_or_create_detection( # A detection may have a pre-existing crop image URL or not. # If not, a new one will be created in a periodic background task. - if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"): + if detection_resp.crop_image_url and detection_resp.crop_image_url.startswith(("http://", "https://")): # Ensure that the crop image URL is not empty or only a slash. None is fine. crop_url = detection_resp.crop_image_url else: @@ -729,6 +731,113 @@ def create_classifications( return existing_classifications + new_classifications +def get_or_create_deployments( + deployments_data: list[DeploymentResponse], + project_id: int, + logger: logging.Logger = logger, +) -> dict[str, Deployment]: + """ + Create or get deployments from source images data. 
+
+    :param deployments_data: List of deployment info objects parsed from the pipeline results
+    :param project_id: Project ID to create deployments for
+    :param logger: Logger instance
+
+    :return: Dictionary mapping deployment names to Deployment objects
+    """
+    from ami.main.models import Project, get_or_create_default_deployment
+
+    project = Project.objects.get(pk=project_id)
+    deployments = {}
+
+    for deployment_data in deployments_data:
+        deployment_name = deployment_data.name
+
+        if deployment_name not in deployments:
+            deployment = get_or_create_default_deployment(
+                project=project,
+                name=deployment_name,
+            )
+            deployments[deployment_name] = deployment
+
+    return deployments
+
+
+def create_source_images(
+    source_images_data: list[SourceImageResponse],
+    deployments: dict[str, Deployment],
+    project_id: int,
+    public_base_url: str | None = None,
+    logger: logging.Logger = logger,
+) -> dict[str, int]:
+    """
+    Create source images from pipeline results data.
+
+    This assumes the IDs are external IDs from the pipeline results creator
+    and maps them to internal IDs in the database.
+
+    This was created for an initial use case and still needs to be tested for broader use cases.
+
+    :param source_images_data: List of SourceImageResponse objects from the pipeline results
+    :param deployments: Dictionary mapping deployment names to Deployment objects
+    :param project_id: Project ID
+    :param public_base_url: Base URL for images if paths are relative
+    :param logger: Logger instance
+
+    :return: Dictionary mapping external IDs to internal source image IDs
+    """
+    import ami.utils.dates
+    from ami.main.models import Project
+
+    project = Project.objects.get(pk=project_id)
+    id_mapping = {}
+
+    for source_image_data in source_images_data:
+        external_id = source_image_data.id
+        url = source_image_data.url
+        deployment_info = source_image_data.deployment
+
+        if not deployment_info:
+            logger.warning(
+                f"The incoming source image {external_id} does not have a deployment specified. "
+                "This is required to create a SourceImage."
+ ) + continue + else: + deployment_name = deployment_info.name + deployment = deployments[deployment_name] + + # Check if source image already exists by URL and deployment + existing_image = SourceImage.objects.filter(deployment=deployment, path=url).first() + + if existing_image: + source_image = existing_image + logger.debug(f"Using existing source image {source_image.pk} for path: {url}") + else: + # Extract timestamp from filename + timestamp = ami.utils.dates.get_image_timestamp_from_filename(url) + + # Set public_base_url if provided and path is relative + final_public_base_url = None + if public_base_url and not url.startswith(("http://", "https://")): + final_public_base_url = public_base_url.rstrip("/") + + source_image = SourceImage.objects.create( + path=url, + deployment=deployment, + project=project, + timestamp=timestamp, + public_base_url=final_public_base_url, + ) + logger.info(f"Created source image {source_image.pk} for path: {url}") + + # Map external ID to internal ID + id_mapping[external_id] = str(source_image.pk) + logger.debug(f"Mapped external ID {external_id} to internal ID {source_image.pk}") + + return id_mapping + + def create_and_update_occurrences_for_detections( detections: list[Detection], logger: logging.Logger = logger, @@ -802,6 +911,7 @@ class PipelineSaveResults: detections: list[Detection] classifications: list[Classification] algorithms: dict[str, Algorithm] + deployments: dict[str, Deployment] total_time: float @@ -811,6 +921,10 @@ def save_results( results_json: str | None = None, job_id: int | None = None, return_created=False, + create_missing_source_images: bool = False, + project_id: int | None = None, + public_base_url: str | None = None, + create_new_algorithms: bool = False, ) -> PipelineSaveResults | None: """ Save results from ML pipeline API. @@ -826,7 +940,9 @@ def save_results( pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline}) if _created: logger.warning(f"Pipeline choice returned by the Processing Service was not recognized! {pipeline}") - algorithms_used = set() + + algorithms_used: dict[str, Algorithm] = {} + deployments_used: dict[str, Deployment] = {} job_logger = logger start_time = time.time() @@ -844,6 +960,57 @@ def save_results( results = PipelineResultsResponse.parse_obj(results.dict()) assert results, "No results from pipeline to save" + + # Create missing source images and deployments if requested + if create_missing_source_images and project_id: + job_logger.info(f"Creating missing source images and deployments for project {project_id}") + + deployments_data = results.deployments + source_images_data = results.source_images + + if not deployments_data: + job_logger.warning( + "No deployments data found in results. " + "New source images will not be created without deployments data." + ) + else: + deployments_used = get_or_create_deployments( + deployments_data=deployments_data, + project_id=project_id, + logger=job_logger, + ) + deployments_map = {dep.name: dep for dep in Deployment.objects.filter(project_id=project_id)} + job_logger.info(f"Found {len(deployments_map)} existing deployments for project {project_id}") + + if not source_images_data: + raise ValueError( + "No source images data found in results. " + "New detections cannot be created without source images data." + ) + + # Create source images from the external results data + # where the IDs do not match the internal IDs. 
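+            # For example (hypothetical values): an incoming image with external id "680"
+            # might be stored as SourceImage pk 42, so id_mapping == {"680": "42"}, and the
+            # source image and detection references below are rewritten to use "42".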
+ id_mapping = create_source_images( + source_images_data=source_images_data, + deployments=deployments_map, + project_id=project_id, + public_base_url=public_base_url, + logger=job_logger, + ) + + # Update the results to use internal IDs + for i, source_image_data in enumerate(source_images_data): + external_id = source_image_data.id + if external_id in id_mapping: + results.source_images[i].id = str(id_mapping[external_id]) + + # Update detection source_image_ids to use internal IDs + for detection in results.detections: + if detection.source_image_id in id_mapping: + detection.source_image_id = str(id_mapping[detection.source_image_id]) + + job_logger.debug(f"Created/found {len(id_mapping)} source images with ID mapping: {id_mapping}") + source_images = SourceImage.objects.filter(pk__in=[int(img.id) for img in results.source_images]).distinct() pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline}) @@ -852,19 +1019,23 @@ def save_results( f"The pipeline returned by the ML backend was not recognized, created a placeholder: {pipeline}" ) - # Algorithms and category maps should be created in advance when registering the pipeline & processing service - # however they are also currently available in each pipeline results response as well. - # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of the results - algorithms_used = { - algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger) - for algo_key, algo_config in results.algorithms.items() - } - # Add all algorithms initially reported in the pipeline response to the pipeline - for algo in algorithms_used.values(): - pipeline.algorithms.add(algo) + if create_new_algorithms: + # Algorithms and category maps should be created in advance when registering the pipeline & processing service + # however they are also currently available in each pipeline results response as well. 
+ # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of + # the results + algorithms_used = { + algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger) + for algo_key, algo_config in results.algorithms.items() + } + # Add all algorithms initially reported in the pipeline response to the pipeline + for algo in algorithms_used.values(): + pipeline.algorithms.add(algo) - algos_reported = [f" {algo.task_type}: {algo_key} ({algo})\n" for algo_key, algo in algorithms_used.items()] - job_logger.info(f"Algorithms reported in pipeline response: \n{''.join(algos_reported)}") + algos_reported = [f" {algo.task_type}: {algo_key} ({algo})\n" for algo_key, algo in algorithms_used.items()] + job_logger.info(f"Algorithms reported in pipeline response: \n{''.join(algos_reported)}") + else: + algorithms_used = {algo.key: algo for algo in pipeline.algorithms.all()} detections = create_detections( detections=results.detections, @@ -931,6 +1102,7 @@ def save_results( detections=detections, classifications=classifications, algorithms=algorithms_used, + deployments=deployments_used, total_time=total_time, ) @@ -1107,7 +1279,7 @@ def choose_processing_service_for_pipeline( return processing_service_lowest_latency def process_images(self, images: typing.Iterable[SourceImage], project_id: int, job_id: int | None = None): - processing_service = self.choose_processing_service_for_pipeline(job_id, self.name, project_id) + processing_service = self.choose_processing_service_for_pipeline(job_id or 0, self.name, project_id) if not processing_service.endpoint_url: raise ValueError( diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 7f5a5c9a9..9fa87e7fe 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -119,9 +119,16 @@ class SourceImageRequest(pydantic.BaseModel): # b64: str | None = None +class DeploymentResponse(pydantic.BaseModel): + id: str | None = None + name: str + key: str | None = None + + class SourceImageResponse(pydantic.BaseModel): id: str url: str + deployment: DeploymentResponse | None = None class Config: extra = "ignore" @@ -188,6 +195,7 @@ class PipelineResultsResponse(pydantic.BaseModel): total_time: float source_images: list[SourceImageResponse] detections: list[DetectionResponse] + deployments: list[DeploymentResponse] | None = None errors: list | str | None = None diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 9f133a953..e4c501ada 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -5,7 +5,7 @@ from rest_framework.test import APIRequestFactory, APITestCase from ami.base.serializers import reverse_with_params -from ami.main.models import Classification, Detection, Project, SourceImage, SourceImageCollection +from ami.main.models import Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Algorithm, Pipeline, ProcessingService from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results from ami.ml.schemas import ( @@ -481,6 +481,13 @@ def test_skip_existing_per_batch_during_processing(self): pass def test_unknown_algorithm_returned_by_processing_service(self): + """ + Test that unknown algorithms returned by the processing service are handled correctly. + + Previously we allowed unknown algorithms to be returned by the pipeline, + now all algorithms must be registered first from the processing service's /info + endpoint. 
+ """ fake_results = self.fake_pipeline_results(self.test_images, self.pipeline) new_detector = AlgorithmConfigResponse( @@ -501,92 +508,18 @@ def test_unknown_algorithm_returned_by_processing_service(self): current_total_algorithm_count = Algorithm.objects.count() - # @TODO assert a warning was logged - save_results(fake_results) + # Ensure an exception is raised that a new algorithm was not + # pre-registered from the /info endpoint + with self.assertRaises(ValueError): + save_results(fake_results) - # Ensure new algorithms were added to the database + # Ensure no new algorithms were added to the database new_algorithm_count = Algorithm.objects.count() - self.assertEqual(new_algorithm_count, current_total_algorithm_count + 2) + self.assertEqual(new_algorithm_count, current_total_algorithm_count) # Ensure new algorithms were also added to the pipeline - self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name, key=new_detector.key).exists()) - self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name, key=new_classifier.key).exists()) - - @unittest.skip("Not implemented yet") - def test_reprocessing_after_unknown_algorithm_added(self): - # @TODO fix issue with "None" algorithm on some detections - - images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) - - save_results(self.fake_pipeline_results(images, self.pipeline)) - - new_detector = AlgorithmConfigResponse( - name="Unknown Detector 5.1b-mobile", key="unknown-detector", task_type="detection" - ) - new_classifier = AlgorithmConfigResponse( - name="Unknown Classifier 3.0b-mega", key="unknown-classifier", task_type="classification" - ) - - fake_results = self.fake_pipeline_results(images, self.pipeline) - - # Change the algorithm names to unknown ones - for detection in fake_results.detections: - detection.algorithm = AlgorithmReference(name=new_detector.name, key=new_detector.key) - - for classification in detection.classifications: - classification.algorithm = AlgorithmReference(name=new_classifier.name, key=new_classifier.key) - - fake_results.algorithms[new_detector.key] = new_detector - fake_results.algorithms[new_classifier.key] = new_classifier - - # print("FAKE RESULTS") - # print(fake_results) - # print("END FAKE RESULTS") - - saved_objects = save_results(fake_results, return_created=True) - assert saved_objects is not None - saved_detections = saved_objects.detections - saved_classifications = saved_objects.classifications - - for obj in saved_detections: - assert obj.detection_algorithm # For type checker, not the test - - # Ensure the new detector was used for the detection - self.assertEqual(obj.detection_algorithm.name, new_detector.name) - - # Ensure each detection has classification objects - self.assertTrue(obj.classifications.exists()) - - # Ensure detection has a correct classification object - for classification in obj.classifications.all(): - self.assertIn(classification, saved_classifications) - - for obj in saved_classifications: - assert obj.algorithm # For type checker, not the test - - # Ensure the new classifier was used for the classification - self.assertEqual(obj.algorithm.name, new_classifier.name) - - # Ensure each classification has the correct detection object - self.assertIn(obj.detection, saved_detections, "Wrong detection object for classification object.") - - # Ensure new algorithms were added to the pipeline - self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name).exists()) - 
self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name).exists()) - - detection_algos_used = Detection.objects.all().values_list("detection_algorithm__name", flat=True).distinct() - self.assertTrue(new_detector.name in detection_algos_used) - # Ensure None is not in the list - self.assertFalse(None in detection_algos_used) - classification_algos_used = Classification.objects.all().values_list("algorithm__name", flat=True) - self.assertTrue(new_classifier.name in classification_algos_used) - # Ensure None is not in the list - self.assertFalse(None in classification_algos_used) - - # The algorithms are new, but they were registered to the pipeline, so the images should be skipped. - images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) - remaining_images_to_process = len(images_again) - self.assertEqual(remaining_images_to_process, 0) + # self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name, key=new_detector.key).exists()) + # self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name, key=new_classifier.key).exists()) def test_yes_reprocess_if_new_terminal_algorithm_same_intermediate(self): """
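The new ami/ml management command reads a saved PipelineResultsResponse JSON file and imports it into a target project. A minimal invocation sketch, assuming a project with pk 1 already exists and using the fixture added above (note that the ami/exports command introduced in this change shares the same command name, so the command that manage.py resolves depends on app ordering):

    # Validate the file against the schema without writing anything
    python manage.py import_pipeline_results ami/exports/tests/fixtures/pipeline_response-complete_local_project.json --project 1 --dry-run

    # Import, creating missing deployments and source images; the base URL is applied only to relative image paths
    python manage.py import_pipeline_results ami/exports/tests/fixtures/pipeline_response-complete_local_project.json --project 1 --public-base-url http://0.0.0.0:7070/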
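For programmatic use, the extended save_results signature in ami/ml/models/pipeline.py can be driven directly with the same data. A minimal sketch, assuming the fixture above and an existing project; the file path, project pk, and base URL are placeholders:

    import json

    from django.db import transaction

    from ami.ml.models.pipeline import save_results
    from ami.ml.schemas import PipelineResultsResponse

    # Parse and validate the saved pipeline response (placeholder path).
    with open("ami/exports/tests/fixtures/pipeline_response-complete_local_project.json") as f:
        results = PipelineResultsResponse(**json.load(f))

    with transaction.atomic():
        saved = save_results(
            results,
            return_created=True,
            create_missing_source_images=True,  # create Deployments/SourceImages from the response data
            project_id=1,  # placeholder: pk of an existing Project
            public_base_url="http://0.0.0.0:7070/",  # only applied to relative image paths
        )

    if saved:
        print(f"Created {len(saved.detections)} detections from {len(saved.source_images)} source images")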