classification.py
import torch
from torchmetrics import Metric
from torchmetrics.classification import MultilabelAveragePrecision


class TopKAccuracy(Metric):
    """Top-k accuracy for multilabel predictions, with optional handling of
    "no-call" instances (samples whose target vector is all zeros)."""

    def __init__(self, topk=1, include_nocalls=False, threshold=0.5, **kwargs):
        super().__init__(**kwargs)
        self.topk = topk
        self.include_nocalls = include_nocalls
        self.threshold = threshold
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, targets):
        # Get the top-k predictions
        _, topk_pred_indices = preds.topk(self.topk, dim=1, largest=True, sorted=True)
        targets = targets.to(preds.device)
        no_call_targets = targets.sum(dim=1) == 0

        # Consider no-call instances (a threshold is needed here!)
        if self.include_nocalls:
            # A no-call instance counts as correct if all of its top-k scores fall below the threshold
            no_positive_predictions = preds.topk(self.topk, dim=1, largest=True).values < self.threshold
            correct_all_negative = no_call_targets & no_positive_predictions.all(dim=1)
        else:
            # No-call instances are excluded, so their contribution is set to 0
            correct_all_negative = torch.tensor(0, device=targets.device)

        # Look up the one-hot targets at the top-k predicted indices; a sample is
        # correct if any of its top-k predictions hits a positive label
        expanded_targets = targets.unsqueeze(1).expand(-1, self.topk, -1)
        correct_positive = expanded_targets.gather(2, topk_pred_indices.unsqueeze(-1)).any(dim=1)

        # Update correct and total, excluding all-negative instances if specified
        self.correct += correct_positive.sum() + correct_all_negative.sum()
        if not self.include_nocalls:
            self.total += targets.size(0) - no_call_targets.sum()
        else:
            self.total += targets.size(0)

    def compute(self):
        return self.correct.float() / self.total
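

# --- Illustrative usage sketch (not part of the original file) --------------
# The helper below is hypothetical and only demonstrates the metric; the
# scores and multi-hot targets are made up for a 4-class multilabel setup.
def _topk_accuracy_example():
    metric = TopKAccuracy(topk=1)
    preds = torch.tensor([[0.9, 0.1, 0.2, 0.3],
                          [0.2, 0.8, 0.1, 0.4]])   # sigmoid scores, batch of 2
    targets = torch.tensor([[1, 0, 0, 0],
                            [0, 0, 1, 0]])         # multi-hot labels
    metric.update(preds, targets)
    return metric.compute()  # tensor(0.5000): only the first sample's top-1 is a hit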


class cmAP(MultilabelAveragePrecision):
    """Class-wise mean average precision (cmAP): macro-averaged over labels."""

    def __init__(self, num_labels, thresholds=None):
        super().__init__(
            num_labels=num_labels,
            average="macro",
            thresholds=thresholds,
        )

    def __call__(self, logits, labels):
        macro_cmap = super().__call__(logits, labels)
        return macro_cmap


class mAP(MultilabelAveragePrecision):
    """Micro-averaged mean average precision over all label/sample pairs."""

    def __init__(self, num_labels, thresholds=None):
        super().__init__(
            num_labels=num_labels,
            average="micro",
            thresholds=thresholds,
        )

    def __call__(self, logits, labels):
        micro_cmap = super().__call__(logits, labels)
        return micro_cmap
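

# --- Illustrative usage sketch (not part of the original file) --------------
# Made-up logits and multi-hot labels for a 3-class multilabel problem;
# MultilabelAveragePrecision applies a sigmoid when scores fall outside [0, 1].
if __name__ == "__main__":
    logits = torch.tensor([[2.0, -1.0, 0.5],
                           [-0.5, 1.5, -2.0],
                           [0.3, 0.2, 1.8]])
    labels = torch.tensor([[1, 0, 0],
                           [0, 1, 0],
                           [0, 0, 1]])

    class_wise_map = cmAP(num_labels=3)   # macro average over labels
    micro_map = mAP(num_labels=3)         # micro average over all label/sample pairs
    print("cmAP (macro):", class_wise_map(logits, labels))
    print("mAP (micro):", micro_map(logits, labels))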