# ranking_losses.py
import torch
import numpy as np
from itertools import product
from torch.nn import BCEWithLogitsLoss

DEFAULT_EPS = 1e-10
PADDED_Y_VALUE = -1

####################################################################################################
# Factory function to obtain the ranking loss
####################################################################################################
def get_ranking_loss(name):
    if name.lower() == "listwise-weighted":
        return listMLE_weighted
    if name.lower() == "listwise":
        return listMLE
    if name.lower() == "pairwise":
        return rankNet
    if name.lower() == "pointwise":
        return pointwise_rmse  # Also referred to as Subset Regression
    return None
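
# A minimal usage sketch for the factory above; the helper name and the scores
# are illustrative assumptions for demonstration, not part of the original module.
def _demo_get_ranking_loss():
    loss_fn = get_ranking_loss("listwise")
    # One slate of three items, shape [batch_size=1, slate_length=3];
    # item 0 is the most relevant.
    y_pred = torch.tensor([[0.5, 1.2, -0.3]])
    y_true = torch.tensor([[2.0, 1.0, 0.0]])
    return loss_fn(y_pred, y_true)  # scalar torch.Tensor
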
####################################################################################################
# Actual loss functions
####################################################################################################
# The following loss implementations are adapted from
# https://github.com/allegro/allRank/blob/master/allrank/models/losses
# listMLE_weighted is an extended version of listMLE with rank-based weighting added to it.
def listMLE_weighted(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm",
    extended with per-rank weighting.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    # shuffle for randomised tie resolution
    random_indices = torch.randperm(y_pred.shape[-1])
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]
    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)
    mask = y_true_sorted == padded_value_indicator
    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")
    # subtract the per-slate maximum before exponentiating for numerical stability
    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values
    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])
    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
    observation_loss[mask] = 0.0
    ####### Weighting extension
    # Weighted ranking because it is more important to get the first ranks right than the rest.
    weight = np.log(np.arange(observation_loss.shape[-1]) + 2)  # Adding 2 to prevent using log(0) & log(1) as weights.
    weight = np.array(weight, dtype=np.float32)
    weight = torch.from_numpy(weight)[None, :].to(observation_loss.device)  # match device (e.g. GPU)
    observation_loss = observation_loss / weight
    #######
    return torch.mean(torch.sum(observation_loss, dim=1))
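
# A minimal sketch of the weighting effect (illustrative helper and values,
# not part of the original module). The weighted variant divides the loss term
# at rank i by log(i + 2), so errors at later, less important ranks count less
# than in plain listMLE.
def _demo_listMLE_weighted():
    y_pred = torch.tensor([[2.0, 1.0, 0.0, 0.0]])
    y_true = torch.tensor([[3.0, 2.0, 1.0, float(PADDED_Y_VALUE)]])  # last item is padding
    return listMLE_weighted(y_pred, y_true), listMLE(y_pred, y_true)
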
def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    # shuffle for randomised tie resolution
    random_indices = torch.randperm(y_pred.shape[-1])
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]
    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)
    mask = y_true_sorted == padded_value_indicator
    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")
    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values
    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])
    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
    observation_loss[mask] = 0.0
    return torch.mean(torch.sum(observation_loss, dim=1))
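
# A quick sanity check for listMLE (illustrative helper, not part of the
# original module): predictions that agree with the true ordering should score
# a lower loss than predictions that reverse it. With no ties in y_true, the
# random shuffle above does not affect the result.
def _demo_listMLE_ordering():
    y_true = torch.tensor([[3.0, 2.0, 1.0]])
    aligned = listMLE(torch.tensor([[3.0, 2.0, 1.0]]), y_true)
    reversed_ = listMLE(torch.tensor([[1.0, 2.0, 3.0]]), y_true)
    assert aligned < reversed_
    return aligned, reversed_
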
def rankNet(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, weight_by_diff=False, weight_by_diff_powed=False):
    """
    RankNet loss introduced in "Learning to Rank using Gradient Descent".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
    :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()
    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')
    # here we generate every pair of indices from the range of document length in the batch
    document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2))
    pairs_true = y_true[:, document_pairs_candidates]
    selected_pred = y_pred[:, document_pairs_candidates]
    # here we calculate the relative true relevance of every candidate pair
    true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
    pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]
    # here we filter just the pairs that are 'positive' and did not involve a padded instance;
    # we can do that since the candidate pairs are symmetric, so we can stick with the
    # positive ones for a simpler loss function formulation
    the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))
    pred_diffs = pred_diffs[the_mask]
    weight = None
    if weight_by_diff:
        abs_diff = torch.abs(true_diffs)
        weight = abs_diff[the_mask]
    elif weight_by_diff_powed:
        true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
        abs_diff = torch.abs(true_pow_diffs)
        weight = abs_diff[the_mask]
    # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know
    # whether one document is better than the other, not the actual difference in
    # their relevancy levels
    true_diffs = (true_diffs > 0).type(torch.float32)
    true_diffs = true_diffs[the_mask]
    return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)
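
# A minimal rankNet sketch (illustrative helper and values, not part of the
# original module). Every ordered pair whose true relevance differs becomes
# one BCE-with-logits term on the corresponding score difference.
def _demo_rankNet():
    y_pred = torch.tensor([[0.2, 1.5, -0.7]])
    y_true = torch.tensor([[1.0, 2.0, 0.0]])
    return rankNet(y_pred, y_true, weight_by_diff=True)
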
def pointwise_rmse(y_pred, y_true, no_of_levels=None, padded_value_indicator=PADDED_Y_VALUE):
    """
    Pointwise RMSE loss.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param no_of_levels: number of unique ground truth values (defaults to the slate length, assuming all labels are distinct)
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    ####### This section of code is a tailor-made extension for our purpose
    if no_of_levels is None:
        # Assuming that all values in y_true are distinct, the slate length
        # equals the number of relevance levels.
        no_of_levels = y_true.shape[-1]
    #######
    y_pred = y_pred.clone()
    y_true = y_true.clone()
    mask = y_true == padded_value_indicator
    valid_mask = (y_true != padded_value_indicator).type(torch.float32)
    # zero out padded positions so they contribute nothing to the error
    y_true[mask] = 0
    y_pred[mask] = 0
    errors = (y_true - no_of_levels * y_pred)
    squared_errors = errors ** 2
    mean_squared_errors = torch.sum(squared_errors, dim=1) / torch.sum(valid_mask, dim=1)
    rmses = torch.sqrt(mean_squared_errors)
    return torch.mean(rmses)
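
# A minimal pointwise_rmse sketch (illustrative helper and values, not part of
# the original module). Predictions are assumed to lie roughly in [0, 1] and
# are scaled by no_of_levels before comparison against the integer-like labels.
def _demo_pointwise_rmse():
    y_pred = torch.tensor([[0.9, 0.5, 0.1, 0.0]])
    y_true = torch.tensor([[3.0, 2.0, 1.0, float(PADDED_Y_VALUE)]])  # last item is padding
    return pointwise_rmse(y_pred, y_true, no_of_levels=3)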