From 671241c990649a39cd6b749add111160927c8045 Mon Sep 17 00:00:00 2001 From: Josh Purtell Date: Tue, 27 Feb 2024 06:14:43 -0500 Subject: [PATCH] Feature/safe random negatives (#161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * updated RAGTrainer data handling * remove unnec changes * add back 02 example notebook * only select as many negatives as you can * remove old changes * remove unnecessary changes --------- Co-authored-by: Benjamin Clavié --- ragatouille/data/training_data_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragatouille/data/training_data_processor.py b/ragatouille/data/training_data_processor.py index 39fb19a..ab9aaac 100644 --- a/ragatouille/data/training_data_processor.py +++ b/ragatouille/data/training_data_processor.py @@ -125,7 +125,7 @@ def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negativ else: new_negatives = [ x - for x in random.sample(self.collection, n_new_negatives) + for x in random.sample(self.collection, min(n_new_negatives, len(self.collection))) if x not in passages["positives"] and x not in passages["negatives"] ]