From 3e2944fb264dd62ad315841fad73603eafe4d4e0 Mon Sep 17 00:00:00 2001
From: Thilina Rajapakse
Date: Fri, 10 May 2024 00:59:17 +0200
Subject: [PATCH] Removing more dead code

---
 .../retrieval/retrieval_model.py |   2 +
 .../retrieval/retrieval_utils.py | 150 +++++++-----
 2 files changed, 55 insertions(+), 97 deletions(-)

diff --git a/simpletransformers/retrieval/retrieval_model.py b/simpletransformers/retrieval/retrieval_model.py
index 57172d34..c0d9f3a7 100644
--- a/simpletransformers/retrieval/retrieval_model.py
+++ b/simpletransformers/retrieval/retrieval_model.py
@@ -3227,6 +3227,8 @@ def _get_loss(
             true_n_scores,
         )
 
+        # This is where the quartet loss will go
+
         if not (
             self.args.include_nll_loss
             or self.args.mse_loss
diff --git a/simpletransformers/retrieval/retrieval_utils.py b/simpletransformers/retrieval/retrieval_utils.py
index 61643590..920b49de 100644
--- a/simpletransformers/retrieval/retrieval_utils.py
+++ b/simpletransformers/retrieval/retrieval_utils.py
@@ -37,6 +37,41 @@
 # faiss.omp_set_num_threads(get_default_process_count())
 
 
+def add_titles_to_passages(dataset):
+    if "title" not in dataset.column_names:
+        raise ValueError(
+            "The dataset must contain a column named 'title' if args.include_title is True."
+        )
+    if "gold_passage" in dataset.column_names:
+        dataset = dataset.map(
+            lambda example: {
+                "gold_passage": (example["title"] + " " + example["gold_passage"]) if example["title"] is not None else example["gold_passage"]
+            }
+        )
+    if (
+        "passage_text_a" in dataset.column_names
+        and "passage_text_b" in dataset.column_names
+    ):
+        if (
+            "title_a" not in dataset.column_names
+            or "title_b" not in dataset.column_names
+        ):
+            raise ValueError(
+                "The dataset must contain columns named 'title_a' and 'title_b' if args.include_title is True."
+            )
+        dataset = dataset.map(
+            lambda example: {
+                "passage_text_a": example["title_a"] + " " + example["passage_text_a"]
+            }
+        )
+        dataset = dataset.map(
+            lambda example: {
+                "passage_text_b": example["title_b"] + " " + example["passage_text_b"]
+            }
+        )
+    return dataset
+
+
 def load_hf_dataset(
     data,
     context_tokenizer,
@@ -80,29 +115,12 @@ def load_hf_dataset(
         )
         dataset = dataset["train"]
         if args.include_title:
-            if "title" not in dataset.column_names:
-                raise ValueError(
-                    "The dataset must contain a column named 'title' if args.include_title is True."
-                )
-            dataset = dataset.map(
-                lambda example: {
-                    "gold_passage": example["title"] + " " + example["gold_passage"]
-                }
-            )
+            dataset = add_titles_to_passages(dataset)
+
     else:
         dataset = HFDataset.from_pandas(data)
         if args.include_title:
-            if "title" not in dataset.column_names:
-                raise ValueError(
-                    "The dataset must contain a column named 'title' if args.include_title is True."
- ) - dataset = dataset.map( - lambda example: { - "gold_passage": example["title"] + " " + example["gold_passage"] - if example["title"] is not None - else example["gold_passage"] - } - ) + dataset = add_titles_to_passages(dataset) # Assign an id to each unique gold_passage # passage_dict = {} @@ -179,14 +197,6 @@ def load_hf_dataset( # "passage_id", ] - if args.unified_cross_rr and teacher_tokenizer: - column_names += [ - "reranking_context_ids", - "reranking_context_mask", - "reranking_query_ids", - "reranking_query_mask", - ] - if args.include_margin_mse_loss and not evaluate: column_names += ["margin"] @@ -209,14 +219,14 @@ def load_hf_dataset( return dataset, gold_passages else: dataset.set_format(type="pt", columns=column_names) - if args.unified_cross_rr and not clustered_training and not evaluate: - dataset = dataset.to_pandas() - dataset = np.array_split( - dataset, math.ceil(len(dataset) / args.train_batch_size) - ) - batch_datasets = [HFDataset.from_pandas(df) for df in dataset] + # if args.unified_cross_rr and not clustered_training and not evaluate: + # dataset = dataset.to_pandas() + # dataset = np.array_split( + # dataset, math.ceil(len(dataset) / args.train_batch_size) + # ) + # batch_datasets = [HFDataset.from_pandas(df) for df in dataset] - dataset = ClusteredDataset(batch_datasets, len(batch_datasets)) + # dataset = ClusteredDataset(batch_datasets, len(batch_datasets)) return dataset @@ -230,11 +240,6 @@ def preprocess_batch_for_hf_dataset( teacher_tokenizer=None, n_hard_negatives=1, ): - if teacher_tokenizer is None: - unified_rr = False - else: - unified_rr = True - try: context_inputs = context_tokenizer( dataset["gold_passage"], @@ -272,26 +277,6 @@ def preprocess_batch_for_hf_dataset( context_mask = context_inputs["attention_mask"].squeeze() query_mask = query_inputs["attention_mask"].squeeze() - if unified_rr or (args.unified_cross_rr and teacher_tokenizer): - reranking_query_inputs = teacher_tokenizer( - dataset["query_text"], - padding=False, - return_tensors="np", - truncation=True, - ) - - reranking_context_inputs = teacher_tokenizer( - dataset["gold_passage"], - padding=False, - return_tensors="np", - truncation=True, - ) - - reranking_context_ids = reranking_context_inputs["input_ids"] - reranking_context_mask = reranking_context_inputs["attention_mask"] - reranking_query_ids = reranking_query_inputs["input_ids"] - reranking_query_mask = reranking_query_inputs["attention_mask"] - if args.cluster_concatenated: try: clustering_context_inputs = context_tokenizer( @@ -407,24 +392,12 @@ def preprocess_batch_for_hf_dataset( "clustering_context_mask": clustering_context_mask, } else: - if unified_rr: - return { - "context_ids": context_ids, - "query_ids": query_ids, - "context_mask": context_mask, - "query_mask": query_mask, - "reranking_context_ids": reranking_context_ids, - "reranking_context_mask": reranking_context_mask, - "reranking_query_ids": reranking_query_ids, - "reranking_query_mask": reranking_query_mask, - } - else: - return { - "context_ids": context_ids, - "query_ids": query_ids, - "context_mask": context_mask, - "query_mask": query_mask, - } + return { + "context_ids": context_ids, + "query_ids": query_ids, + "context_mask": context_mask, + "query_mask": query_mask, + } def get_output_embeddings( @@ -482,7 +455,6 @@ def embed( amp=None, pretokenized=False, cluster_concatenated=False, - unified_rr=False, passage_column="passages", args=None, autoencoder=None, @@ -643,13 +615,7 @@ def embed( n_cls_tokens=(1 + extra_cls_token_count), ) - if unified_rr: 
-            embeddings = embeddings.detach().cpu().numpy()
-            rerank_embeddings = embeddings[:, : embeddings.shape[1] // 2]
-            embeddings = embeddings[:, embeddings.shape[1] // 2 :]
-            return {"embeddings": embeddings, "rerank_embeddings": rerank_embeddings}
-        else:
-            return {"embeddings": embeddings.detach().cpu().numpy()}
+        return {"embeddings": embeddings.detach().cpu().numpy()}
 
 
 def add_hard_negatives_to_evaluation_dataset(dataset):
@@ -709,15 +675,7 @@ def get_evaluation_passage_dataset(
     else:
         passage_dataset = HFDataset.from_pandas(eval_data)
         if args.include_title_in_corpus:
-            if "title" not in passage_dataset.column_names:
-                raise ValueError(
-                    "The dataset must contain a column named 'title' if args.include_title_in_corpus is True."
-                )
-            passage_dataset = passage_dataset.map(
-                lambda example: {
-                    "gold_passage": example["title"] + " " + example["gold_passage"]
-                }
-            )
+            passage_dataset = add_titles_to_passages(passage_dataset)
 
     try:
         passage_dataset = passage_dataset.remove_columns("query_text")
@@ -996,7 +954,6 @@ def get_prediction_passage_dataset(
             device=device,
             fp16=args.fp16,
             amp=amp,
-            unified_rr=args.unified_rr,
             args=args,
             autoencoder=autoencoder,
         ),
@@ -2018,7 +1975,6 @@ def embed_passages_trec_format(
             fp16=args.fp16,
             amp=amp,
             passage_column="passage_text",
-            unified_rr=args.unified_rr,
             args=args,
             autoencoder=autoencoder,
         ),
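
Note: a minimal usage sketch of the consolidated add_titles_to_passages
helper, assuming the Hugging Face datasets package that retrieval_utils
already uses (HFDataset); the toy rows and values below are illustrative
only, not from the library's test suite.

    from datasets import Dataset

    from simpletransformers.retrieval.retrieval_utils import add_titles_to_passages

    # Toy data with a "title" column, as required when args.include_title is
    # True. A row whose title is None keeps its passage unchanged.
    toy = Dataset.from_dict(
        {
            "title": ["Eiffel Tower", None],
            "gold_passage": ["It is in Paris.", "An untitled passage."],
        }
    )

    toy = add_titles_to_passages(toy)
    print(toy["gold_passage"])
    # ['Eiffel Tower It is in Paris.', 'An untitled passage.']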