@@ -107,8 +107,62 @@ def forward(self, x):
107
107
return x
108
108
109
109
110
- class ConvNeXtClassifier (InferenceBaseClass ):
111
- pass
110
class ConvNeXtOrderClassifier(InferenceBaseClass):
    """ConvNeXt based insect order classifier.

    Builds a timm ``convnext_tiny.fb_in22k`` network with one output per
    entry of ``self.category_map``, loads fine-tuned weights from
    ``self.weights`` and converts batches of logits into
    ``(label, score)`` pairs.
    """

    # Edge length, in pixels, of the square image produced by the transforms.
    input_size = 128

    def get_model(self):
        """Create the network, load the checkpoint and return it in eval mode.

        Reads ``self.category_map`` (for the number of classes),
        ``self.device`` and ``self.weights``.

        Returns:
            The model, moved to ``self.device`` and switched to ``eval()``.
        """
        num_classes = len(self.category_map)
        model = timm.create_model(
            "convnext_tiny.fb_in22k",
            # timm selects ImageNet weights with ``pretrained``; it has no
            # ``weights`` argument. The fine-tuned weights are loaded from
            # the checkpoint below, so no pretrained download is wanted.
            pretrained=False,
            num_classes=num_classes,
        )
        model = model.to(self.device)
        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source.
        checkpoint = torch.load(self.weights, map_location=self.device)
        # The model state dict is nested in some checkpoints, and not in others
        state_dict = checkpoint.get("model_state_dict") or checkpoint

        model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def _pad_image_to_square(image):
        """Pad ``image`` on its bottom or right edge until it is square.

        Assumes ``image`` exposes ``.size`` as ``(width, height)`` like a PIL
        image; the transform pipeline applies this step before ``ToTensor``.
        """
        width, height = image.size
        side = max(width, height)
        # Padding order is [left, top, right, bottom]; a square image gets
        # all-zero padding.
        return torchvision.transforms.functional.pad(
            image, [0, 0, side - width, side - height]
        )

    def _pad_to_square(self):
        """Padding transformation to make the image square.

        The amount of padding is computed for each image when the transform
        is applied. The pipeline is built once and reused for every image, so
        the padding cannot be derived from a single image at build time.
        """
        return torchvision.transforms.Lambda(self._pad_image_to_square)

    def get_transforms(self):
        """Return the preprocessing pipeline.

        Steps: pad to square, resize to ``input_size`` x ``input_size``,
        convert to a tensor and normalize each channel with the mean/std
        constants below.
        """
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return torchvision.transforms.Compose(
            [
                self._pad_to_square(),
                torchvision.transforms.Resize((self.input_size, self.input_size)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

    def post_process_batch(self, output):
        """Convert a batch of model outputs into ``(label, score)`` pairs.

        Args:
            output: Model output for one batch. Softmax and argmax are taken
                over dimension 1 (presumably the class dimension of a
                ``(batch, num_classes)`` tensor; confirm against the caller).

        Returns:
            A list with one ``(label, score)`` tuple per row, where ``label``
            is ``self.category_map`` looked up with the arg-max index and
            ``score`` is the corresponding softmax probability.
        """
        predictions = torch.nn.functional.softmax(output, dim=1)
        predictions = predictions.cpu().numpy()

        categories = predictions.argmax(axis=1)
        labels = [self.category_map[cat] for cat in categories]
        scores = predictions.max(axis=1).astype(float)

        result = list(zip(labels, scores))
        logger.debug(f"Post-processing result batch: { result } ")
        return result
112
166
113
167
114
168
class Resnet50Classifier_Turing (InferenceBaseClass ):
@@ -507,7 +561,7 @@ class PanamaMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClassifier)
507
561
)
508
562
509
563
510
- class InsectOrderClassifier2025 (SpeciesClassifier , ConvNeXtClassifier ):
564
+ class InsectOrderClassifier2025 (SpeciesClassifier , ConvNeXtOrderClassifier ):
511
565
name = "Insect Order Classifier"
512
566
description = "ConvNeXt-T based insect order classifier for 16 classes trained by Mila in January 2025"
513
567
weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/insect_orders/convnext_tiny_in22k_worder0.5_wbinary0.5_run2_checkpoint.pt"
0 commit comments