diff --git a/ragatouille/RAGTrainer.py b/ragatouille/RAGTrainer.py index 15ebbc8..8da5ba6 100644 --- a/ragatouille/RAGTrainer.py +++ b/ragatouille/RAGTrainer.py @@ -95,7 +95,7 @@ def prepare_training_data( data_out_path: Union[str, Path] - Path to the directory where the data has been exported. """ if all_documents is not None: - self.collection += all_documents + self.collection += [doc for doc in all_documents if isinstance(doc, str)] self.data_dir = Path(data_out_path) if len(raw_data[0]) == 2: @@ -107,21 +107,22 @@ def prepare_training_data( else: raise ValueError("Raw data must be a list of pairs or triplets of strings.") - if type(raw_data[0][1]) == str: - self.collection += [x[1] for x in raw_data] - else: - for x in raw_data: - for txt in x[1]: - self.collection.append(txt) - if data_type == "triplets": - if type(raw_data[0][2]) == str: - self.collection += [x[2] for x in raw_data] + self.queries = set() + for x in raw_data: + if isinstance(x[0], str): + self.queries.add(x[0]) else: - for x in raw_data: - for txt in x[2]: - self.collection.append(txt) - - self.queries = set([x[0] for x in raw_data]) + raise ValueError("Queries must be a strings.") + if isinstance(x[1], str): + self.collection.append(x[1]) + elif isinstance(x[1], list): + self.collection += [txt for txt in x[1] if isinstance(txt, str)] + + if len(x) == 3: # For triplets + if isinstance(x[2], str): + self.collection.append(x[2]) + elif isinstance(x[2], list): + self.collection += [txt for txt in x[2] if isinstance(txt, str)] self.collection = list(set(self.collection)) seeded_shuffle(self.collection)