3
3
4
4
import numpy as np
5
5
import torch
6
- from rich import print
7
6
8
7
from trapdata .common .logs import logger
9
8
from trapdata .ml .models .classification import (
@@ -37,6 +36,9 @@ def __init__(
37
36
self .detections = list (detections )
38
37
self .results : list [Detection ] = []
39
38
super ().__init__ (* args , ** kwargs )
39
+ logger .info (
40
+ f"Initialized { self .__class__ .__name__ } with { len (self .detections )} detections"
41
+ )
40
42
41
43
def get_dataset (self ):
42
44
return ClassificationImageDataset (
@@ -89,19 +91,32 @@ def save_results(
89
91
timestamp = datetime .datetime .now (),
90
92
)
91
93
self .update_classification (detection , classification )
92
- print (detection )
94
+ # print(detection)
93
95
self .results .extend (self .detections )
94
96
logger .info (f"Saving { len (self .results )} detections with classifications" )
95
97
return self .results
96
98
97
- def update_classification (self , detection : Detection , new_classification : Classification ) -> None :
99
+ def update_classification (
100
+ self , detection : Detection , new_classification : Classification
101
+ ) -> None :
98
102
# Remove all existing classifications from this algorithm
99
- detection .classifications = [c for c in detection .classifications if c .algorithm != self .name ]
103
+ detection .classifications = [
104
+ c for c in detection .classifications if c .algorithm != self .name
105
+ ]
100
106
# Add the new classification for this algorithm
101
107
detection .classifications .append (new_classification )
108
+ logger .debug (
109
+ f"Updated classification for detection { detection .bbox } . Total classifications: { len (detection .classifications )} "
110
+ )
102
111
103
112
def run (self ) -> list [Detection ]:
113
+ logger .info (
114
+ f"Starting { self .__class__ .__name__ } run with { len (self .results )} detections"
115
+ )
104
116
super ().run ()
117
+ logger .info (
118
+ f"Finished { self .__class__ .__name__ } run. Processed { len (self .results )} detections"
119
+ )
105
120
return self .results
106
121
107
122
@@ -134,8 +149,11 @@ def save_results(
134
149
# Specific to binary classification / the filter model
135
150
terminal = False ,
136
151
)
137
- print (detection )
138
- if not self .filter_results or classification .classification == self .positive_binary_label :
152
+ # print(detection)
153
+ if (
154
+ not self .filter_results
155
+ or classification .classification == self .positive_binary_label
156
+ ):
139
157
self .update_classification (detection , classification )
140
158
141
159
self .results .extend (self .detections )
@@ -149,15 +167,11 @@ class MothClassifierPanama(
149
167
pass
150
168
151
169
152
- class MothClassifierPanama2024 (
153
- MothClassifier , PanamaMothSpeciesClassifier2024
154
- ):
170
+ class MothClassifierPanama2024 (MothClassifier , PanamaMothSpeciesClassifier2024 ):
155
171
pass
156
172
157
173
158
- class MothClassifierUKDenmark (
159
- MothClassifier , UKDenmarkMothSpeciesClassifier2024
160
- ):
174
+ class MothClassifierUKDenmark (MothClassifier , UKDenmarkMothSpeciesClassifier2024 ):
161
175
pass
162
176
163
177
0 commit comments