ranksim.py
# Copyright (c) 2023-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import random
import torch.nn.functional as F
from ranking import TrueRanker, rank_normalised


def batchwise_ranking_regularizer(features, targets, lambda_val):
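    """Compute the RankSim batch-wise ranking similarity regularizer.

    Encourages the ranking of in-batch samples by feature similarity to match
    their ranking by label similarity.

    Args:
        features: per-sample representations, shape (batch, ...); flattened internally.
        targets: regression labels; a (batch, 1) shape is assumed by the transpose below.
        lambda_val: interpolation strength passed through to TrueRanker.apply.

    Returns:
        Scalar ranking similarity loss computed over the (de-duplicated) batch.
    """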
    loss = 0

    # Reduce ties and boost relative representation of infrequent labels by computing the
    # regularizer over a subset of the batch in which each label appears at most once
    batch_unique_targets = torch.unique(targets)
    if len(batch_unique_targets) < len(targets):
        sampled_indices = []
        for target in batch_unique_targets:
            sampled_indices.append(random.choice((targets == target).nonzero()[:,0]).item())
        x = features[sampled_indices]
        y = targets[sampled_indices]
    else:
        x = features
        y = targets

    # Compute feature similarities
    xxt = torch.matmul(F.normalize(x.view(x.size(0),-1)), F.normalize(x.view(x.size(0),-1)).permute(1,0))

    # Compute ranking similarity loss
    for i in range(len(y)):
        label_ranks = rank_normalised(-torch.abs(y[i] - y).transpose(0,1))
        feature_ranks = TrueRanker.apply(xxt[i].unsqueeze(dim=0), lambda_val)
        loss += F.mse_loss(feature_ranks, label_ranks)

    return loss
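

if __name__ == "__main__":
    # Minimal smoke-test sketch, added for illustration only; it is not part of the
    # original file. The batch size, feature dimension, label range, and lambda value
    # below are arbitrary placeholders, not recommended settings.
    torch.manual_seed(0)
    features = torch.randn(16, 64, requires_grad=True)  # (batch, feature_dim) representations
    targets = torch.randint(0, 10, (16, 1)).float()     # (batch, 1) regression labels
    reg = batchwise_ranking_regularizer(features, targets, lambda_val=2.0)
    reg.backward()  # gradients flow back to `features` through TrueRanker
    print("ranking similarity regularizer:", reg.item())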