Removing more dead code
Thilina Rajapakse committed May 9, 2024
1 parent 093b84d commit 3e2944f
Showing 2 changed files with 55 additions and 97 deletions.
2 changes: 2 additions & 0 deletions simpletransformers/retrieval/retrieval_model.py
@@ -3227,6 +3227,8 @@ def _get_loss(
true_n_scores,
)

# This is where the quartet loss will go

if not (
self.args.include_nll_loss
or self.args.mse_loss
150 changes: 53 additions & 97 deletions simpletransformers/retrieval/retrieval_utils.py
@@ -37,6 +37,41 @@
# faiss.omp_set_num_threads(get_default_process_count())


def add_titles_to_passages(dataset):
if "title" not in dataset.column_names:
raise ValueError(
"The dataset must contain a column named 'title' if args.include_title is True."
)
if "gold_passage" in dataset.column_names:
dataset = dataset.map(
lambda example: {
"gold_passage": example["title"] + " " + example["gold_passage"]
}
)
if (
"passage_text_a" in dataset.column_names
and "passage_text_b" in dataset.column_names
):
if (
"title_a" not in dataset.column_names
or "title_b" not in dataset.column_names
):
raise ValueError(
"The dataset must contain columns named 'title_a' and 'title_b' if args.include_title is True."
)
dataset = dataset.map(
lambda example: {
"passage_text_a": example["title_a"] + " " + example["passage_text_a"]
}
)
dataset = dataset.map(
lambda example: {
"passage_text_b": example["title_b"] + " " + example["passage_text_b"]
}
)
return dataset
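
For reference, a minimal usage sketch of the new helper (the toy data below is hypothetical; the actual loaders call it on datasets built from the training and evaluation files whenever args.include_title is set):

    from datasets import Dataset as HFDataset

    # Toy data for illustration only; real datasets come from the input files.
    data = {
        "query_text": ["who wrote hamlet"],
        "title": ["Hamlet"],
        "gold_passage": ["Hamlet is a tragedy written by William Shakespeare."],
    }
    toy_dataset = HFDataset.from_dict(data)

    # Prepends each title to its gold_passage, as load_hf_dataset below now does.
    toy_dataset = add_titles_to_passages(toy_dataset)
    print(toy_dataset[0]["gold_passage"])
    # -> "Hamlet Hamlet is a tragedy written by William Shakespeare."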


def load_hf_dataset(
data,
context_tokenizer,
@@ -80,29 +115,12 @@ def load_hf_dataset(
)
dataset = dataset["train"]
if args.include_title:
if "title" not in dataset.column_names:
raise ValueError(
"The dataset must contain a column named 'title' if args.include_title is True."
)
dataset = dataset.map(
lambda example: {
"gold_passage": example["title"] + " " + example["gold_passage"]
}
)
dataset = add_titles_to_passages(dataset)

else:
dataset = HFDataset.from_pandas(data)
if args.include_title:
if "title" not in dataset.column_names:
raise ValueError(
"The dataset must contain a column named 'title' if args.include_title is True."
)
dataset = dataset.map(
lambda example: {
"gold_passage": example["title"] + " " + example["gold_passage"]
if example["title"] is not None
else example["gold_passage"]
}
)
dataset = add_titles_to_passages(dataset)

# Assign an id to each unique gold_passage
# passage_dict = {}
@@ -179,14 +197,6 @@ def load_hf_dataset(
# "passage_id",
]

if args.unified_cross_rr and teacher_tokenizer:
column_names += [
"reranking_context_ids",
"reranking_context_mask",
"reranking_query_ids",
"reranking_query_mask",
]

if args.include_margin_mse_loss and not evaluate:
column_names += ["margin"]

@@ -209,14 +219,14 @@
return dataset, gold_passages
else:
dataset.set_format(type="pt", columns=column_names)
if args.unified_cross_rr and not clustered_training and not evaluate:
dataset = dataset.to_pandas()
dataset = np.array_split(
dataset, math.ceil(len(dataset) / args.train_batch_size)
)
batch_datasets = [HFDataset.from_pandas(df) for df in dataset]
# if args.unified_cross_rr and not clustered_training and not evaluate:
# dataset = dataset.to_pandas()
# dataset = np.array_split(
# dataset, math.ceil(len(dataset) / args.train_batch_size)
# )
# batch_datasets = [HFDataset.from_pandas(df) for df in dataset]

dataset = ClusteredDataset(batch_datasets, len(batch_datasets))
# dataset = ClusteredDataset(batch_datasets, len(batch_datasets))

return dataset

@@ -230,11 +240,6 @@ def preprocess_batch_for_hf_dataset(
teacher_tokenizer=None,
n_hard_negatives=1,
):
if teacher_tokenizer is None:
unified_rr = False
else:
unified_rr = True

try:
context_inputs = context_tokenizer(
dataset["gold_passage"],
@@ -272,26 +277,6 @@
context_mask = context_inputs["attention_mask"].squeeze()
query_mask = query_inputs["attention_mask"].squeeze()

if unified_rr or (args.unified_cross_rr and teacher_tokenizer):
reranking_query_inputs = teacher_tokenizer(
dataset["query_text"],
padding=False,
return_tensors="np",
truncation=True,
)

reranking_context_inputs = teacher_tokenizer(
dataset["gold_passage"],
padding=False,
return_tensors="np",
truncation=True,
)

reranking_context_ids = reranking_context_inputs["input_ids"]
reranking_context_mask = reranking_context_inputs["attention_mask"]
reranking_query_ids = reranking_query_inputs["input_ids"]
reranking_query_mask = reranking_query_inputs["attention_mask"]

if args.cluster_concatenated:
try:
clustering_context_inputs = context_tokenizer(
@@ -407,24 +392,12 @@ def preprocess_batch_for_hf_dataset(
"clustering_context_mask": clustering_context_mask,
}
else:
if unified_rr:
return {
"context_ids": context_ids,
"query_ids": query_ids,
"context_mask": context_mask,
"query_mask": query_mask,
"reranking_context_ids": reranking_context_ids,
"reranking_context_mask": reranking_context_mask,
"reranking_query_ids": reranking_query_ids,
"reranking_query_mask": reranking_query_mask,
}
else:
return {
"context_ids": context_ids,
"query_ids": query_ids,
"context_mask": context_mask,
"query_mask": query_mask,
}
return {
"context_ids": context_ids,
"query_ids": query_ids,
"context_mask": context_mask,
"query_mask": query_mask,
}


def get_output_embeddings(
@@ -482,7 +455,6 @@ def embed(
amp=None,
pretokenized=False,
cluster_concatenated=False,
unified_rr=False,
passage_column="passages",
args=None,
autoencoder=None,
@@ -643,13 +615,7 @@ def embed(
n_cls_tokens=(1 + extra_cls_token_count),
)

if unified_rr:
embeddings = embeddings.detach().cpu().numpy()
rerank_embeddings = embeddings[:, : embeddings.shape[1] // 2]
embeddings = embeddings[:, embeddings.shape[1] // 2 :]
return {"embeddings": embeddings, "rerank_embeddings": rerank_embeddings}
else:
return {"embeddings": embeddings.detach().cpu().numpy()}
return {"embeddings": embeddings.detach().cpu().numpy()}


def add_hard_negatives_to_evaluation_dataset(dataset):
@@ -709,15 +675,7 @@ def get_evaluation_passage_dataset(
else:
passage_dataset = HFDataset.from_pandas(eval_data)
if args.include_title_in_corpus:
if "title" not in passage_dataset.column_names:
raise ValueError(
"The dataset must contain a column named 'title' if args.include_title_in_corpus is True."
)
passage_dataset = passage_dataset.map(
lambda example: {
"gold_passage": example["title"] + " " + example["gold_passage"]
}
)
            passage_dataset = add_titles_to_passages(passage_dataset)

try:
passage_dataset = passage_dataset.remove_columns("query_text")
@@ -996,7 +954,6 @@ def get_prediction_passage_dataset(
device=device,
fp16=args.fp16,
amp=amp,
unified_rr=args.unified_rr,
args=args,
autoencoder=autoencoder,
),
@@ -2018,7 +1975,6 @@ def embed_passages_trec_format(
fp16=args.fp16,
amp=amp,
passage_column="passage_text",
unified_rr=args.unified_rr,
args=args,
autoencoder=autoencoder,
),
