
Commit c846013

TF-Agents Team authored and copybara-github committed
Add FIXED_BIAS_WEIGHTS positional bias type to ranking_agent
`FIXED_BIAS_WEIGHTS` is a new positional bias type in which fixed bias weights for all slots are given as an array. These weights can come from the user's prior knowledge, be learned from offline analysis (e.g. TopN randomization), and so on.

PiperOrigin-RevId: 617544784
Change-Id: I3cd891bffdc65626d3e0ba513f050827f74ce62a
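As a quick illustration (a standalone sketch, not the agent's actual code path; the weight values and batch size below are made up), the new FIXED_BIAS_WEIGHTS branch broadcasts the given per-slot weight array across the batch before multiplying it into the training sample weights:

import tensorflow as tf

# Hypothetical fixed weights, one per slot (e.g. from offline TopN randomization).
bias_weights = [1.0, 0.7, 0.5, 0.35, 0.25]
batch_size = 4

# Broadcast the per-slot weights to every example in the batch, mirroring the
# tf.expand_dims / tf.tile pattern added to _construct_sample_weights below.
position_bias_multipliers = tf.tile(
    tf.expand_dims(tf.constant(bias_weights), axis=0), [batch_size, 1]
)
print(position_bias_multipliers.shape)  # (4, 5): one multiplier per slot per example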
1 parent c2fcf32 commit c846013

File tree: 3 files changed (+96, -22 lines)

tf_agents/bandits/agents/examples/v2/train_eval_ranking.py

Lines changed: 12 additions & 1 deletion

@@ -64,12 +64,19 @@
     'bias_type',
     '',
     'Whether the agent models the positional '
-    'bias with the basis or the exponent changes. If unset, the'
+    'bias with the basis, the exponent or fixed bias weights. If unset, the'
     ' agent applies no positional bias.',
 )
 flags.DEFINE_float(
     'bias_severity', 1.0, 'The severity of the bias adjustment by the agent.'
 )
+flags.DEFINE_list(
+    'bias_weights',
+    [],
+    'The positional bias weights. For FIXED_BIAS_WEIGHTS type, the agent will'
+    ' use these weights to adjust the rewards. The length of the list must be'
+    ' equal to the number of slots.',
+)
 flags.DEFINE_bool(
     'bias_positive_only',
     False,
@@ -174,12 +181,15 @@ def _relevance_fn(global_obs, item_obs):
     positional_bias_type = ranking_agent.PositionalBiasType.BASE
   elif FLAGS.positional_bias_type == 'exponent':
     positional_bias_type = ranking_agent.PositionalBiasType.EXPONENT
+  elif FLAGS.positional_bias_type == 'fixed_bias_weights':
+    positional_bias_type = ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS
   else:
     raise NotImplementedError(
         'Positional bias type {} is not implemented'.format(
             FLAGS.positional_bias_type
         )
     )
+  positional_bias_weights = [float(w) for w in FLAGS.positional_bias_weights]

   agent = ranking_agent.RankingAgent(
       time_step_spec=environment.time_step_spec(),
@@ -190,6 +200,7 @@ def _relevance_fn(global_obs, item_obs):
       feedback_model=feedback_model,
       positional_bias_type=positional_bias_type,
       positional_bias_severity=FLAGS.bias_severity,
+      positional_bias_weights=positional_bias_weights,
       positional_bias_positive_only=FLAGS.bias_positive_only,
       summarize_grads_and_vars=True,
   )
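Note that `flags.DEFINE_list` hands the program a list of strings, which is why the script converts each entry with `float(...)` before passing it to the agent. A minimal standalone sketch of that parsing step (the flag name mirrors the `bias_weights` flag defined above; the argv value is made up):

from absl import flags

_BIAS_WEIGHTS = flags.DEFINE_list(
    'bias_weights', [], 'Positional bias weights, one per slot.'
)

# Parsing a hypothetical command line yields strings, not floats.
flags.FLAGS(['program', '--bias_weights=0.1,0.2,0.3'])
print(_BIAS_WEIGHTS.value)                      # ['0.1', '0.2', '0.3']
print([float(w) for w in _BIAS_WEIGHTS.value])  # [0.1, 0.2, 0.3]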

tf_agents/bandits/agents/ranking_agent.py

Lines changed: 46 additions & 15 deletions

@@ -38,8 +38,9 @@
 recommendation. The user is responsible for converting the observation to the
 syntax required by the agent.
 """
+
 import enum
-from typing import Optional, Text
+from typing import List, Optional, Text

 import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
 from tf_agents.agents import tf_agent
@@ -127,6 +128,9 @@ class PositionalBiasType(enum.Enum):
   # et al. `Correcting for Selection Bias in Learning-to-rank Systems`
   # (WWW 2020).
   EXPONENT = 2
+  # The bias weight for each slot position is `bias_weights[k]`, where
+  # `bias_weights` is the given bias weight array and `k` is the position.
+  FIXED_BIAS_WEIGHTS = 3


 class RankingAgent(tf_agent.TFAgent):
@@ -144,6 +148,7 @@ def __init__(
       non_click_score: Optional[float] = None,
       positional_bias_type: PositionalBiasType = PositionalBiasType.UNSET,
       positional_bias_severity: Optional[float] = None,
+      positional_bias_weights: Optional[List[float]] = None,
       positional_bias_positive_only: bool = False,
       logits_temperature: float = 1.0,
       summarize_grads_and_vars: bool = False,
@@ -178,6 +183,8 @@ def __init__(
       positional_bias_type: Type of positional bias to use when training.
       positional_bias_severity: (float) The severity `s`, used for the `BASE`
         positional bias type.
+      positional_bias_weights: (float array) The positional bias weight for each
+        slot position.
       positional_bias_positive_only: Whether to use the above defined bias
         weights only for positives (that is, clicked items). If
         `positional_bias_type` is unset, this parameter has no effect.
@@ -230,6 +237,22 @@ def __init__(
     )
     self._positional_bias_type = positional_bias_type
     self._positional_bias_severity = positional_bias_severity
+    # Validate positional_bias_weights for FIXED_BIAS_WEIGHTS PositionalBiasType
+    if self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS:
+      if positional_bias_weights is None:
+        raise ValueError(
+            'positional_bias_weights is None but should never be for'
+            ' FIXED_BIAS_WEIGHTS PositionalBiasType.'
+        )
+      elif len(positional_bias_weights) != self._num_slots:
+        raise ValueError(
+            'The length of positional_bias_weights should be the same as the'
+            ' number of slots. The length of positional_bias_weights is {} and'
+            ' the number of slots is {}.'.format(
+                len(positional_bias_weights), self._num_slots
+            )
+        )
+    self._positional_bias_weights = positional_bias_weights
     self._positional_bias_positive_only = positional_bias_positive_only
     if policy_type == RankingPolicyType.UNKNOWN:
       policy_type = RankingPolicyType.COSINE_DISTANCE
@@ -409,19 +432,27 @@ def _construct_sample_weights(self, reward, observation, weights):
           chosen_index + 1, self._num_slots, dtype=tf.float32
       )
       weights = multiplier * weights
-    if self._positional_bias_type != PositionalBiasType.UNSET:
-      batched_range = tf.broadcast_to(
-          tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)
+
+    if self._positional_bias_type == PositionalBiasType.UNSET:
+      return weights
+
+    batched_range = tf.broadcast_to(
+        tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)
+    )
+    if self._positional_bias_type == PositionalBiasType.BASE:
+      position_bias_multipliers = tf.pow(
+          batched_range + 1, self._positional_bias_severity
       )
-      if self._positional_bias_type == PositionalBiasType.BASE:
-        position_bias_multipliers = tf.pow(
-            batched_range + 1, self._positional_bias_severity
-        )
-      elif self._positional_bias_type == PositionalBiasType.EXPONENT:
-        position_bias_multipliers = tf.pow(
-            self._positional_bias_severity, batched_range
-        )
-      else:
-        raise ValueError('non-existing positional bias type')
-      weights = position_bias_multipliers * weights
+    elif self._positional_bias_type == PositionalBiasType.EXPONENT:
+      position_bias_multipliers = tf.pow(
+          self._positional_bias_severity, batched_range
+      )
+    elif self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS:
+      position_bias_multipliers = tf.tile(
+          tf.expand_dims(self._positional_bias_weights, axis=0),
+          [batch_size, 1],
+      )
+    else:
+      raise ValueError('non-existing positional bias type')
+    weights = position_bias_multipliers * weights
     return weights
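For intuition, the standalone sketch below (not the agent's code path; the severity and weight values are made up) contrasts the per-position multipliers produced by the three bias types handled in `_construct_sample_weights`: BASE yields `(k + 1) ** severity`, EXPONENT yields `severity ** k`, and the new FIXED_BIAS_WEIGHTS simply uses `bias_weights[k]`:

import tensorflow as tf

num_slots = 5
severity = 1.3
bias_weights = [1.0, 0.8, 0.6, 0.4, 0.2]  # hypothetical fixed weights

positions = tf.range(num_slots, dtype=tf.float32)     # k = 0, 1, 2, 3, 4
base_multipliers = tf.pow(positions + 1.0, severity)  # (k + 1) ** severity
exponent_multipliers = tf.pow(severity, positions)    # severity ** k
fixed_multipliers = tf.constant(bias_weights)         # bias_weights[k]

print(base_multipliers.numpy())      # approx. [1.   2.46 4.17 6.06 8.1 ]
print(exponent_multipliers.numpy())  # approx. [1.   1.3  1.69 2.2  2.86]
print(fixed_multipliers.numpy())     # [1.  0.8 0.6 0.4 0.2]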

tf_agents/bandits/agents/ranking_agent_test.py

Lines changed: 38 additions & 6 deletions

@@ -311,6 +311,8 @@ def testTrainAgentScoreFeedback(
          'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
          'positional_bias_severity': 1.2,
          'positional_bias_positive_only': False,
+          'positional_bias_weights': None,
+          'expected_second_weight': 2.2974,  # 2**positional_bias_severity
      },
      {
          'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
@@ -323,6 +325,8 @@ def testTrainAgentScoreFeedback(
          'positional_bias_type': ranking_agent.PositionalBiasType.EXPONENT,
          'positional_bias_severity': 1.3,
          'positional_bias_positive_only': False,
+          'positional_bias_weights': None,
+          'expected_second_weight': 1.3,  # positional_bias_severity
      },
      {
          'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
@@ -335,6 +339,36 @@ def testTrainAgentScoreFeedback(
          'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
          'positional_bias_severity': 1.0,
          'positional_bias_positive_only': True,
+          'positional_bias_weights': None,
+          'expected_second_weight': 2.0,  # 2**positional_bias_severity
+      },
+      {
+          'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
+          'policy_type': ranking_agent.RankingPolicyType.DESCENDING_SCORES,
+          'batch_size': 2,
+          'global_dim': 3,
+          'item_dim': 4,
+          'num_items': 13,
+          'num_slots': 11,
+          'positional_bias_type': (
+              ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS
+          ),
+          'positional_bias_severity': None,
+          'positional_bias_positive_only': True,
+          'positional_bias_weights': [
+              0.1,
+              0.2,
+              0.3,
+              0.4,
+              0.5,
+              0.6,
+              0.7,
+              0.8,
+              0.9,
+              1.0,
+              1.1,
+          ],
+          'expected_second_weight': 0.2,  # positional_bias_weights[1]
      },
  ])
  def testPositionalBiasParams(
@@ -349,6 +383,8 @@ def testPositionalBiasParams(
      positional_bias_type,
      positional_bias_severity,
      positional_bias_positive_only,
+      positional_bias_weights,
+      expected_second_weight,
  ):
    if not tf.executing_eagerly():
      self.skipTest('Only works in eager mode.')
@@ -386,6 +422,7 @@ def testPositionalBiasParams(
        positional_bias_type=positional_bias_type,
        positional_bias_severity=positional_bias_severity,
        positional_bias_positive_only=positional_bias_positive_only,
+        positional_bias_weights=positional_bias_weights,
        optimizer=optimizer,
    )
    global_obs = tf.reshape(
@@ -426,12 +463,7 @@ def testPositionalBiasParams(
    agent.train(experience)
    weights = agent._construct_sample_weights(scores, observations, None)
    self.assertAllEqual(weights.shape, [batch_size, num_slots])
-    expected = (
-        2**positional_bias_severity
-        if positional_bias_type == ranking_agent.PositionalBiasType.BASE
-        else positional_bias_severity
-    )
-    self.assertAllClose(weights[-1, 1], expected)
+    self.assertAllClose(weights[-1, 1], expected_second_weight)


 if __name__ == '__main__':
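The `expected_second_weight` values in the parameterized cases above can be checked by hand at slot index 1 (the second slot): BASE with severity 1.2 gives `(1 + 1) ** 1.2 ≈ 2.2974`, EXPONENT with severity 1.3 gives `1.3 ** 1 = 1.3`, and FIXED_BIAS_WEIGHTS reads `positional_bias_weights[1] = 0.2`. A quick standalone check (using only the first few entries of the test's weight list):

assert abs(2 ** 1.2 - 2.2974) < 1e-4  # BASE: (position + 1) ** severity
assert 1.3 ** 1 == 1.3                # EXPONENT: severity ** position
assert [0.1, 0.2, 0.3][1] == 0.2      # FIXED_BIAS_WEIGHTS: weights[position]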
