38
38
recommendation. The user is responsible for converting the observation to the
39
39
syntax required by the agent.
40
40
"""
41
+
41
42
import enum
42
- from typing import Optional , Text
43
+ from typing import List , Optional , Text
43
44
44
45
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
45
46
from tf_agents .agents import tf_agent
@@ -127,6 +128,9 @@ class PositionalBiasType(enum.Enum):
127
128
# et al. `Correcting for Selection Bias in Learning-to-rank Systems`
128
129
# (WWW 2020).
129
130
EXPONENT = 2
131
+ # The bias weight for each slot position is `bias_weights[k]`, where
132
+ # `bias_weights` is the given bias weight array and `k` is the position.
133
+ FIXED_BIAS_WEIGHTS = 3
130
134
131
135
132
136
class RankingAgent (tf_agent .TFAgent ):
@@ -144,6 +148,7 @@ def __init__(
144
148
non_click_score : Optional [float ] = None ,
145
149
positional_bias_type : PositionalBiasType = PositionalBiasType .UNSET ,
146
150
positional_bias_severity : Optional [float ] = None ,
151
+ positional_bias_weights : Optional [List [float ]] = None ,
147
152
positional_bias_positive_only : bool = False ,
148
153
logits_temperature : float = 1.0 ,
149
154
summarize_grads_and_vars : bool = False ,
@@ -178,6 +183,8 @@ def __init__(
178
183
positional_bias_type: Type of positional bias to use when training.
179
184
positional_bias_severity: (float) The severity `s`, used for the `BASE`
180
185
and `EXPONENT` positional bias types.
186
+ positional_bias_weights: (float array) The positional bias weight for each
187
+ slot position.
181
188
positional_bias_positive_only: Whether to use the above defined bias
182
189
weights only for positives (that is, clicked items). If
183
190
`positional_bias_type` is unset, this parameter has no effect.
@@ -230,6 +237,22 @@ def __init__(
230
237
)
231
238
self ._positional_bias_type = positional_bias_type
232
239
self ._positional_bias_severity = positional_bias_severity
240
+ # Validate positional_bias_weights for FIXED_BIAS_WEIGHTS PositionalBiasType
241
+ if self ._positional_bias_type == PositionalBiasType .FIXED_BIAS_WEIGHTS :
242
+ if positional_bias_weights is None :
243
+ raise ValueError (
244
+ 'positional_bias_weights is None but should never be for'
245
+ ' FIXED_BIAS_WEIGHTS PositionalBiasType.'
246
+ )
247
+ elif len (positional_bias_weights ) != self ._num_slots :
248
+ raise ValueError (
249
+ 'The length of positional_bias_weights should be the same as the'
250
+ ' number of slots. The length of positional_bias_weights is {} and'
251
+ ' the number of slots is {}.' .format (
252
+ len (positional_bias_weights ), self ._num_slots
253
+ )
254
+ )
255
+ self ._positional_bias_weights = positional_bias_weights
233
256
self ._positional_bias_positive_only = positional_bias_positive_only
234
257
if policy_type == RankingPolicyType .UNKNOWN :
235
258
policy_type = RankingPolicyType .COSINE_DISTANCE
@@ -409,19 +432,27 @@ def _construct_sample_weights(self, reward, observation, weights):
409
432
chosen_index + 1 , self ._num_slots , dtype = tf .float32
410
433
)
411
434
weights = multiplier * weights
412
- if self ._positional_bias_type != PositionalBiasType .UNSET :
413
- batched_range = tf .broadcast_to (
414
- tf .range (self ._num_slots , dtype = tf .float32 ), tf .shape (weights )
435
+
436
+ if self ._positional_bias_type == PositionalBiasType .UNSET :
437
+ return weights
438
+
439
+ batched_range = tf .broadcast_to (
440
+ tf .range (self ._num_slots , dtype = tf .float32 ), tf .shape (weights )
441
+ )
442
+ if self ._positional_bias_type == PositionalBiasType .BASE :
443
+ position_bias_multipliers = tf .pow (
444
+ batched_range + 1 , self ._positional_bias_severity
415
445
)
416
- if self ._positional_bias_type == PositionalBiasType .BASE :
417
- position_bias_multipliers = tf .pow (
418
- batched_range + 1 , self ._positional_bias_severity
419
- )
420
- elif self ._positional_bias_type == PositionalBiasType .EXPONENT :
421
- position_bias_multipliers = tf .pow (
422
- self ._positional_bias_severity , batched_range
423
- )
424
- else :
425
- raise ValueError ('non-existing positional bias type' )
426
- weights = position_bias_multipliers * weights
446
+ elif self ._positional_bias_type == PositionalBiasType .EXPONENT :
447
+ position_bias_multipliers = tf .pow (
448
+ self ._positional_bias_severity , batched_range
449
+ )
450
+ elif self ._positional_bias_type == PositionalBiasType .FIXED_BIAS_WEIGHTS :
451
+ position_bias_multipliers = tf .tile (
452
+ tf .expand_dims (self ._positional_bias_weights , axis = 0 ),
453
+ [batch_size , 1 ],
454
+ )
455
+ else :
456
+ raise ValueError ('non-existing positional bias type' )
457
+ weights = position_bias_multipliers * weights
427
458
return weights
0 commit comments