diff --git a/tf_agents/bandits/agents/ranking_agent_test.py b/tf_agents/bandits/agents/ranking_agent_test.py index 421d59d7f..435d77d40 100644 --- a/tf_agents/bandits/agents/ranking_agent_test.py +++ b/tf_agents/bandits/agents/ranking_agent_test.py @@ -463,7 +463,9 @@ def testPositionalBiasParams( agent.train(experience) weights = agent._construct_sample_weights(scores, observations, None) self.assertAllEqual(weights.shape, [batch_size, num_slots]) - self.assertAllClose(weights[-1, 1], expected_second_weight) + self.assertAllClose( + weights[-1, 1], expected_second_weight, atol=1e-3, rtol=1e-3 + ) if __name__ == '__main__':