Skip to content

Commit 94c77d3

Browse files
committed
Add class definition for convnext
1 parent 0696f1e commit 94c77d3

File tree

1 file changed

+57
-3
lines changed

1 file changed

+57
-3
lines changed

trapdata/ml/models/classification.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,62 @@ def forward(self, x):
107107
return x
108108

109109

110-
class ConvNeXtClassifier(InferenceBaseClass):
111-
pass
110+
class ConvNeXtOrderClassifier(InferenceBaseClass):
111+
"""ConvNeXt based insect order classifier"""
112+
input_size = 128
113+
114+
def get_model(self):
115+
num_classes = len(self.category_map)
116+
model = timm.create_model(
117+
"convnext_tiny.fb_in22k",
118+
weights=None,
119+
num_classes=num_classes,
120+
)
121+
model = model.to(self.device)
122+
checkpoint = torch.load(self.weights, map_location=self.device)
123+
# The model state dict is nested in some checkpoints, and not in others
124+
state_dict = checkpoint.get("model_state_dict") or checkpoint
125+
126+
model.load_state_dict(state_dict)
127+
model.eval()
128+
return model
129+
130+
131+
def _pad_to_square(self):
132+
"""Padding transformation to make the image square"""
133+
134+
width, height = self.image.size
135+
if height < width:
136+
return torchvision.transforms.Pad(padding=[0, 0, 0, width - height])
137+
elif height > width:
138+
return torchvision.transforms.Pad(padding=[0, 0, height - width, 0])
139+
else:
140+
return torchvision.transforms.Pad(padding=[0, 0, 0, 0])
141+
142+
143+
def get_transforms(self):
144+
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
145+
return torchvision.transforms.Compose(
146+
[
147+
self._pad_to_square(),
148+
torchvision.transforms.Resize((self.input_size, self.input_size)),
149+
torchvision.transforms.ToTensor(),
150+
torchvision.transforms.Normalize(mean, std),
151+
]
152+
)
153+
154+
155+
def post_process_batch(self, output):
156+
predictions = torch.nn.functional.softmax(output, dim=1)
157+
predictions = predictions.cpu().numpy()
158+
159+
categories = predictions.argmax(axis=1)
160+
labels = [self.category_map[cat] for cat in categories]
161+
scores = predictions.max(axis=1).astype(float)
162+
163+
result = list(zip(labels, scores))
164+
logger.debug(f"Post-processing result batch: {result}")
165+
return result
112166

113167

114168
class Resnet50Classifier_Turing(InferenceBaseClass):
@@ -507,7 +561,7 @@ class PanamaMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClassifier)
507561
)
508562

509563

510-
class InsectOrderClassifier2025(SpeciesClassifier, ConvNeXtClassifier):
564+
class InsectOrderClassifier2025(SpeciesClassifier, ConvNeXtOrderClassifier):
511565
name = "Insect Order Classifier"
512566
description = "ConvNeXt-T based insect order classifier for 16 classes trained by Mila in January 2025"
513567
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/insect_orders/convnext_tiny_in22k_worder0.5_wbinary0.5_run2_checkpoint.pt"

0 commit comments

Comments
 (0)