Skip to content

Commit

Permalink
Continue testing api-only interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow committed Nov 8, 2023
1 parent 5a5cc36 commit 5ff5b68
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 18 deletions.
4 changes: 2 additions & 2 deletions scripts/start_db_container.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ set -o errexit
set -o nounset

CONTAINER_NAME=ami-db
HOST_PORT=5432
HOST_PORT=5433
POSTGRES_VERSION=14
POSTGRES_DB=ami

docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION

docker logs ami-db --tail 100

echo 'Database started, Connection string: "postgresql://postgres@localhost:5432/ami"'
echo "Database started, Connection string: \"postgresql://postgres@localhost:${HOST_PORT}/${POSTGRES_DB}\""
echo "Stop (and destroy) database with 'docker stop $CONTAINER_NAME' && docker remove $CONTAINER_NAME"
16 changes: 12 additions & 4 deletions trapdata/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@

logger = logging.getLogger(__name__)

TEMPORARY_DEPLOYMENT_ID = 9
TEMPORARY_EVENT_ID = 34
TEMPORARY_COLLECTION_ID = 4


def save_detected_objects(
source_image_ids: list[int], detected_objects_data: list[dict], *args, **kwargs
):
logger.info(f"Saving {len(source_image_ids)} detected objects via API")
print(f"Saving {len(source_image_ids)} detected objects via API")
responses = {}
path = "detections/"
for source_image_id, detected_objects in zip(
Expand Down Expand Up @@ -38,9 +43,11 @@ def get_next_source_images(num: int, *args, **kwargs) -> list[IncomingSourceImag
path = "captures/"
args = {
"limit": num,
"deployment": 9,
"event": 34,
# "deployment": TEMPORARY_DEPLOYMENT_ID,
# "event": TEMPORARY_EVENT_ID,
"collections": TEMPORARY_COLLECTION_ID,
"has_detections": False,
"order": "?",
} # last_processed__isnull=True
url = settings.api_base_url + path
resp = get_session().get(url, params=args)
Expand All @@ -53,8 +60,9 @@ def get_next_source_images(num: int, *args, **kwargs) -> list[IncomingSourceImag
def get_source_image_count(*args, **kwargs) -> int:
path = "captures/"
args = {
"deployment": 9,
"event": 34,
# "deployment": TEMPORARY_DEPLOYMENT_ID,
# "event": TEMPORARY_EVENT_ID,
"collections": TEMPORARY_COLLECTION_ID,
"limit": 1,
"has_detections": False,
}
Expand Down
38 changes: 29 additions & 9 deletions trapdata/ml/models/localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import torchvision.models.detection.backbone_utils
import torchvision.models.detection.faster_rcnn
import torchvision.models.mobilenetv3
from PIL import ImageFile

from trapdata import TrapImage, db, logger
from trapdata.ml.models.base import InferenceBaseClass
from trapdata.ml.utils import get_or_download_file

ImageFile.LOAD_TRUNCATED_IMAGES = True


class LocalizationIterableDatabaseDataset(torch.utils.data.IterableDataset):
def __init__(self, queue, image_transforms, batch_size=1):
Expand All @@ -34,20 +37,29 @@ def __iter__(self):
records = get_next_source_images(self.batch_size)
logger.debug(f"Pulling records: {records}")
if records:
item_ids = torch.utils.data.default_collate(
[record.id for record in records]
)
batch_data = torch.utils.data.default_collate(
[self.transform(record.url) for record in records]
)
items = [(record.id, self.transform(record.url)) for record in records]
items = [(item_id, item) for item_id, item in items if item is not None]
item_ids, item_objs = zip(*items)
logger.info(f"Processing items: {item_ids}")

item_ids = torch.utils.data.default_collate(item_ids)
batch_data = torch.utils.data.default_collate(item_objs)
yield (item_ids, batch_data)

def transform(self, url):
url = url + "?width=5000&redirect=False"
logger.info(f"Fetching and transforming: {url}")
img_path = get_or_download_file(url, destination_dir="/tmp/today/")
return self.image_transforms(PIL.Image.open(img_path))
try:
return self.image_transforms(PIL.Image.open(img_path))
except PIL.UnidentifiedImageError:
logger.error(f"Unidentified image: {img_path}")
print(f"Unidentified image: {img_path}")
return None
except OSError:
logger.error(f"OSError: {img_path}")
print(f"OSError: {img_path}")
return None


class LocalizationDatabaseDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -86,8 +98,16 @@ def __getitem__(self, idx):
sesh.commit()

img_path = img_path
pil_image = PIL.Image.open(img_path)
item = (item_id, self.transform(pil_image))
try:
pil_image = PIL.Image.open(img_path)
except FileNotFoundError:
logger.error(f"File not found: {img_path}")
item = (item_id, None)
except PIL.UnidentifiedImageError:
logger.error(f"Unidentified image: {img_path}")
item = (item_id, None)
else:
item = (item_id, self.transform(pil_image))

return item

Expand Down
6 changes: 3 additions & 3 deletions trapdata/ml/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def start_pipeline(
num_workers=settings.num_workers,
single=single,
)
if object_detector.queue.queue_count() > 0:
object_detector.run()
logger.info("Localization complete")
# if object_detector.queue.queue_count() > 0:
object_detector.run()
logger.info("Localization complete")

BinaryClassifier = ml.models.binary_classifiers[
settings.binary_classification_model.value
Expand Down

0 comments on commit 5ff5b68

Please sign in to comment.