Skip to content

Commit

Permalink
add target normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Jul 17, 2024
1 parent 84785ce commit b6eba85
Showing 1 changed file with 25 additions and 56 deletions.
81 changes: 25 additions & 56 deletions lightning_ir/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from itertools import islice
from pathlib import Path
from typing import Dict, Iterator, Literal, Tuple
from typing import Any, Dict, Iterator, Literal, Tuple

import ir_datasets
import numpy as np
Expand Down Expand Up @@ -33,18 +33,13 @@ def __init__(self, dataset: str) -> None:

@property
def DASHED_DATASET_MAP(self) -> Dict[str, str]:
return {
dataset.replace("/", "-"): dataset
for dataset in ir_datasets.registry._registered
}
return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

@property
def queries(self) -> pd.Series:
if self._queries is None:
if self.ir_dataset is None:
raise ValueError(
f"Unable to find dataset {self.dataset} in ir-datasets"
)
raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
queries_iter = self.ir_dataset.queries_iter()
self._queries = pd.Series(
{query.query_id: query.default_text() for query in queries_iter},
Expand All @@ -57,9 +52,7 @@ def queries(self) -> pd.Series:
def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
if self._docs is None:
if self.ir_dataset is None:
raise ValueError(
f"Unable to find dataset {self.dataset} in ir-datasets"
)
raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
self._docs = self.ir_dataset.docs_store()
return self._docs

Expand All @@ -69,9 +62,7 @@ def qrels(self) -> pd.DataFrame | None:
return self._qrels
if self.ir_dataset is None:
return None
qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename(
{"subtopic_id": "iteration"}, axis=1
)
qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
if "iteration" not in qrels.columns:
qrels["iteration"] = 0
qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
Expand All @@ -90,8 +81,8 @@ def dataset_id(self) -> str:
def docs_dataset_id(self) -> str:
return ir_datasets.docs_parent_id(self.dataset_id)

def setup(self, stage: Literal["fit", "validate", "test"] | None) -> None:
pass
def setup(self, stage: Literal["fit", "validate", "test"] | None) -> "IRDataset":
return self


class DataParallelIterableDataset(IterableDataset):
Expand Down Expand Up @@ -169,9 +160,8 @@ def __init__(
depth: int,
sample_size: int,
sampling_strategy: Literal["single_relevant", "top", "random"],
targets: (
Literal["relevance", "subtopic_relevance", "rank", "score"] | None
) = None,
targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
normalize_targets: bool = False,
) -> None:
self.run_path = None
if Path(run_path_or_id).is_file():
Expand All @@ -184,6 +174,7 @@ def __init__(
self.sample_size = sample_size
self.sampling_strategy = sampling_strategy
self.targets = targets
self.normalize_targets = normalize_targets

self.run: pd.DataFrame

Expand All @@ -194,9 +185,7 @@ def __init__(
"in the run file, but that are present in the qrels."
)

def setup(
self, stage: Literal["fit", "validate", "test"] | None = None
) -> "RunDataset":
def setup(self, stage: Literal["fit", "validate", "test"] | None = None) -> "RunDataset":
super().setup(stage)
if stage == "fit":
if self.targets is None:
Expand All @@ -216,9 +205,7 @@ def setup(
if len(run_query_ids.difference(qrels_query_ids)):
self.run = self.run[self.run["query_id"].isin(query_ids)]
self.run = self.run.merge(
self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix(
"relevance_", axis=1
),
self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
on=["query_id", "doc_id"],
how=(
"outer" if self._docs is None else "left"
Expand Down Expand Up @@ -262,7 +249,7 @@ def load_parquet(path: Path) -> pd.DataFrame:

@staticmethod
def load_json(path: Path) -> pd.DataFrame:
kwargs = {}
kwargs: Dict[str, Any] = {}
if ".jsonl" in path.suffixes:
kwargs["lines"] = True
kwargs["orient"] = "records"
Expand Down Expand Up @@ -299,11 +286,9 @@ def load_run(self) -> pd.DataFrame:
pass
if run_path is not None and run_path.suffixes[-1] in suffix_load_map:
run = suffix_load_map[run_path.suffixes[-1]](run_path)
elif self.ir_dataset.has_scoreddocs():
elif self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
run["rank"] = run.groupby("query_id")["score"].rank(
"first", ascending=False
)
run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
run = run.sort_values(["query_id", "rank"])
else:
raise ValueError("Invalid run file format.")
Expand All @@ -312,18 +297,10 @@ def load_run(self) -> pd.DataFrame:
axis=1,
)
if "query" in run.columns:
self._queries = (
run.drop_duplicates("query_id")
.set_index("query_id")["query"]
.rename("text")
)
self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
run = run.drop("query", axis=1)
if "text" in run.columns:
self._docs = (
run.set_index("doc_id")["text"]
.map(lambda x: GenericDoc("", x))
.to_dict()
)
self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
run = run.drop("text", axis=1)
if self.depth != -1:
run = run[run["rank"] <= self.depth]
Expand All @@ -343,13 +320,9 @@ def qrels(self) -> pd.DataFrame | None:
qrels["iteration"] = self.run["iteration"]
else:
qrels["iteration"] = "0"
self._run = self.run.drop(
["relevance", "iteration"], axis=1, errors="ignore"
)
self._run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(
level=-1
)
qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
qrels = qrels.droplevel(0, axis=1)
self._qrels = qrels
return self._qrels
Expand Down Expand Up @@ -382,18 +355,17 @@ def __getitem__(self, idx: int) -> RunSample:

targets = None
if self.targets is not None:
filtered = (
group.set_index("doc_id")
.loc[list(doc_ids)]
.filter(like=self.targets)
.fillna(0)
)
filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
if filtered.empty:
raise ValueError(f"targets `{self.targets}` not found in run file")
targets = torch.from_numpy(filtered.values)
if self.targets == "rank":
# invert ranks to be higher is better (necessary for loss functions)
targets = self.depth - targets + 1
if self.normalize_targets:
targets_min = targets.min()
targets_max = targets.max()
targets = (targets - targets_min) / (targets_max - targets_min)
qrels = None
if self.qrels is not None:
qrels = (
Expand Down Expand Up @@ -436,10 +408,7 @@ def parse_sample(
elif self.targets == "order":
targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
else:
raise ValueError(
f"invalid value for targets, got {self.targets}, "
"expected one of (order, score)"
)
raise ValueError(f"invalid value for targets, got {self.targets}, " "expected one of (order, score)")
targets = targets[: self.num_docs]
else:
raise ValueError("Invalid sample type.")
Expand Down

0 comments on commit b6eba85

Please sign in to comment.