Skip to content

Commit

Permalink
Localization v2 (#52)
Browse files Browse the repository at this point in the history
* Add Leonard's new localization models

* Add back existing model

* Fix anchor sizes for mobilenet model

* Update new model names & descriptions.

* Use new object detector in tests

* Update expected test results
  • Loading branch information
mihow authored Aug 8, 2023
1 parent 666fc8a commit 868d2d2
Show file tree
Hide file tree
Showing 5 changed files with 568 additions and 201 deletions.
2 changes: 1 addition & 1 deletion trapdata/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def slugify(s):
# Quick method to make an acceptable attribute name or url part from a title
# install python-slugify for handling unicode chars, numbers at the beginning, etc.
separator = "_"
acceptable_chars = list(string.ascii_letters) + [separator]
acceptable_chars = list(string.ascii_letters) + list(string.digits) + [separator]
return (
"".join(
[
Expand Down
99 changes: 82 additions & 17 deletions trapdata/ml/models/localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import PIL.Image
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.models.detection.anchor_utils
import torchvision.models.detection.backbone_utils
import torchvision.models.detection.faster_rcnn
import torchvision.models.mobilenetv3

from trapdata import TrapImage, db, logger
from trapdata.db.models.detections import save_detected_objects
Expand Down Expand Up @@ -147,7 +150,7 @@ def save_results(self, item_ids, batch_output):
)


class MothObjectDetector_FasterRCNN(ObjectDetector):
class MothObjectDetector_FasterRCNN_2021(ObjectDetector):
name = "FasterRCNN for AMI Moth Traps 2021"
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/v1_localizmodel_2021-08-17-12-06.pt"
description = (
Expand All @@ -160,7 +163,11 @@ def get_model(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
num_classes = 2 # 1 class (object) + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = (
torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
in_features, num_classes
)
)
logger.debug(f"Loading weights: {self.weights}")
checkpoint = torch.load(self.weights, map_location=self.device)
state_dict = checkpoint.get("model_state_dict") or checkpoint
Expand All @@ -186,34 +193,92 @@ def post_process_single(self, output):
return bboxes


class GenericObjectDetector_FasterRCNN_MobileNet(ObjectDetector):
name = "Pre-trained FasterRCNN with MobileNet backend"
class MothObjectDetector_FasterRCNN_2023(ObjectDetector):
name = "FasterRCNN for AMI Moth Traps 2023"
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_resnet50_fpn_tz53qv9v.pt"
description = (
"Faster version of FasterRCNN but not trained on moth trap data. "
"Produces multiple overlapping bounding boxes. But helpful for testing on CPU machines."
"Model trained on GBIF images and synthetic data in 2023. "
"Accurate but can be slow on a machine without GPU."
)
bbox_score_threshold = 0.01
bbox_score_threshold = 0.80

def get_model(self):
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
weights="DEFAULT"
num_classes = 2 # 1 class (object) + background
logger.debug(f"Loading weights: {self.weights}")
model = torchvision.models.get_model(
name="fasterrcnn_resnet50_fpn",
num_classes=num_classes,
pretrained=False,
)
# @TODO can I use load_state_dict here with weights="DEFAULT"?
checkpoint = torch.load(self.weights, map_location=self.device)
state_dict = checkpoint.get("model_state_dict") or checkpoint
model.load_state_dict(state_dict)
model = model.to(self.device)
model.eval()
return model
self.model = model
return self.model

def post_process_single(self, output):
# This model does not use the labels from the object detection model
_ = output["labels"]
assert all([label == 1 for label in output["labels"]])

# Filter out objects if their score is under score threshold
bboxes = output["boxes"][
(output["scores"] > self.bbox_score_threshold) & (output["labels"] > 1)
]
bboxes = output["boxes"][output["scores"] > self.bbox_score_threshold]

# Filter out background label, if using pretrained model only!
bboxes = output["boxes"][output["labels"] > 1]
logger.debug(
f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found (threshold: {self.bbox_score_threshold})"
)

bboxes = bboxes.cpu().numpy().astype(int).tolist()
return bboxes


class MothObjectDetector_FasterRCNN_MobileNet_2023(ObjectDetector):
name = "FasterRCNN - MobileNet for AMI Moth Traps 2023"
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/localization/fasterrcnn_mobilenet_v3_large_fpn_uqfh7u9w.pt"
description = (
"Model trained on GBIF images and synthetic data in 2023. "
"Slightly less accurate but much faster than other models."
)
bbox_score_threshold = 0.50
trainable_backbone_layers = 6 # all layers are trained
anchor_sizes = (64, 128, 256, 512)
num_classes = 2

def get_model(self):
    """
    Assemble the MobileNetV3-Large based FasterRCNN detector and load its weights.

    The backbone, anchor generator and FasterRCNN head are built by hand
    without any pretrained weights; every parameter is then filled in from
    the checkpoint found at `self.weights`. The finished model is moved to
    `self.device`, switched to eval mode, stored on `self.model` and returned.
    """
    # Plain MobileNetV3-Large with regular BatchNorm2d layers, no weights.
    mobilenet = torchvision.models.mobilenetv3.mobilenet_v3_large(
        weights=None, norm_layer=torch.nn.BatchNorm2d
    )
    # NOTE(review): `_mobilenet_extractor` is a private torchvision helper
    # (leading underscore); the positional `True` is presumably its `fpn`
    # flag -- confirm against the installed torchvision version.
    feature_extractor = (
        torchvision.models.detection.backbone_utils._mobilenet_extractor(
            mobilenet, True, self.trainable_backbone_layers
        )
    )

    # The same anchor sizes are repeated three times, each paired with the
    # same set of aspect ratios.
    sizes_per_level = tuple(self.anchor_sizes for _ in range(3))
    ratios_per_level = ((0.5, 1.0, 2.0),) * len(sizes_per_level)
    anchor_generator = torchvision.models.detection.anchor_utils.AnchorGenerator(
        sizes_per_level, ratios_per_level
    )

    detector = torchvision.models.detection.faster_rcnn.FasterRCNN(
        feature_extractor,
        self.num_classes,
        rpn_anchor_generator=anchor_generator,
        rpn_score_thresh=0.05,
    )

    # NOTE(review): torch.load unpickles the file it is given -- confirm that
    # `self.weights` always points at a trusted checkpoint.
    loaded = torch.load(self.weights, map_location=self.device)
    # The checkpoint may wrap the weights under "model_state_dict" or be the
    # state dict itself; fall back to the whole object when the key is
    # missing or empty.
    weights = loaded.get("model_state_dict")
    if not weights:
        weights = loaded
    detector.load_state_dict(weights)

    detector = detector.to(self.device)
    detector.eval()
    self.model = detector
    return self.model

def post_process_single(self, output):
# This model does not use the labels from the object detection model
_ = output["labels"]
assert all([label == 1 for label in output["labels"]])

# Filter out objects if their score is under score threshold
bboxes = output["boxes"][output["scores"] > self.bbox_score_threshold]

logger.debug(
f"Keeping {len(bboxes)} out of {len(output['boxes'])} objects found (threshold: {self.bbox_score_threshold})"
Expand Down
Loading

0 comments on commit 868d2d2

Please sign in to comment.