Skip to content

Commit 8d0a733

Browse files
mihowDebian
authored andcommitted
fix: remove duplicate classifications if save is called twice
1 parent 292c139 commit 8d0a733

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

trapdata/api/models/classification.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,18 @@ def save_results(
8888
algorithm=self.name,
8989
timestamp=datetime.datetime.now(),
9090
)
91-
detection.classifications.append(classification)
91+
self.update_classification(detection, classification)
9292
print(detection)
9393
self.results.extend(self.detections)
9494
logger.info(f"Saving {len(self.results)} detections with classifications")
9595
return self.results
9696

97+
def update_classification(self, detection: Detection, new_classification: Classification) -> None:
98+
# Remove all existing classifications from this algorithm
99+
detection.classifications = [c for c in detection.classifications if c.algorithm != self.name]
100+
# Add the new classification for this algorithm
101+
detection.classifications.append(new_classification)
102+
97103
def run(self) -> list[Detection]:
98104
super().run()
99105
return self.results
@@ -129,11 +135,8 @@ def save_results(
129135
terminal=False,
130136
)
131137
print(detection)
132-
if self.filter_results:
133-
if classification.classification == self.positive_binary_label:
134-
detection.classifications.append(classification)
135-
else:
136-
detection.classifications.append(classification)
138+
if not self.filter_results or classification.classification == self.positive_binary_label:
139+
self.update_classification(detection, classification)
137140

138141
self.results.extend(self.detections)
139142
logger.info(f"Saving {len(self.results)} detections with classifications")

0 commit comments

Comments
 (0)