Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixes

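- (dataset/huggingface_bridge) add optional `warn` parameter (default `True`) to `Dataset.set_infos()` to allow silent replacement of infos; the replacement warning is now emitted only when non-`plaid` infos are being replaced; `huggingface_dataset_to_plaid` now calls `set_infos(infos, warn=False)` to prevent unnecessary warnings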

### Removed

## [0.1.10] - 2025-10-29
Expand Down
2 changes: 1 addition & 1 deletion src/plaid/bridges/huggingface_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,7 @@ def parallel_convert(shard_path, n_workers):

infos = huggingface_description_to_infos(ds.description)

dataset.set_infos(infos)
dataset.set_infos(infos, warn=False)

problem_definition = huggingface_description_to_problem_definition(ds.description)

Expand Down
7 changes: 5 additions & 2 deletions src/plaid/containers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,11 +931,12 @@ def add_infos(self, cat_key: str, infos: dict[str, str]) -> None:
for key, value in infos.items():
self._infos[cat_key][key] = value

def set_infos(self, infos: dict[str, dict[str, str]]) -> None:
def set_infos(self, infos: dict[str, dict[str, str]], warn: bool = True) -> None:
"""Set information to the :class:`Dataset <plaid.containers.dataset.Dataset>`, overwriting the existing one.

Args:
infos (dict[str,dict[str,str]]): Information to associate with this data set (Dataset).
warn (bool, optional): If True, warns when replacing existing infos. Defaults to True.

Raises:
KeyError: Invalid category key format in provided infos.
Expand Down Expand Up @@ -963,7 +964,9 @@ def set_infos(self, infos: dict[str, dict[str, str]]) -> None:
f"{info_key=} not among authorized keys. Maybe you want to try among these keys {AUTHORIZED_INFO_KEYS[cat_key]}"
)

if len(self._infos) > 0:
# Check if there are any non-plaid infos being replaced
has_user_infos = any(key != "plaid" for key in self._infos.keys())
if has_user_infos and warn:
logger.warning("infos not empty, replacing it anyway")
self._infos = copy.deepcopy(infos)

Expand Down
14 changes: 14 additions & 0 deletions tests/bridges/test_huggingface_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,20 @@ def test_huggingface_dataset_to_plaid(self, hf_dataset):
ds, _ = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset)
self.assert_plaid_dataset(ds)

def test_huggingface_dataset_to_plaid_no_warning(self, hf_dataset, caplog):
    """Check that the Hugging Face -> PLAID conversion does not log the infos replacement warning."""
    import logging

    # Exact message emitted by ``Dataset.set_infos`` when it overwrites infos.
    replacement_warning = "infos not empty, replacing it anyway"

    # Capture everything logged at WARNING level or above during the conversion.
    with caplog.at_level(logging.WARNING):
        plaid_dataset, _unused = huggingface_bridge.huggingface_dataset_to_plaid(
            hf_dataset, verbose=False
        )

    # The conversion must stay silent about replacing infos ...
    assert replacement_warning not in caplog.text
    # ... and must still produce a valid dataset.
    self.assert_plaid_dataset(plaid_dataset)

def test_huggingface_dataset_to_plaid_with_ids_binary(self, hf_dataset):
huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset, ids=[0, 1])

Expand Down
21 changes: 21 additions & 0 deletions tests/containers/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,27 @@ def test_set_infos(self, dataset, infos):
{"legal": {"illegal_info_key": "PLAID2", "license": "BSD-3"}}
)

def test_set_infos_warn_parameter(self, dataset, infos, caplog):
    """Check that ``warn`` controls whether replacing infos logs a warning."""
    import logging

    # Exact message emitted by ``Dataset.set_infos`` when it overwrites infos.
    replacement_warning = "infos not empty, replacing it anyway"

    def log_text_after_set(new_infos, **set_kwargs):
        """Call ``set_infos`` while capturing WARNING logs and return the captured text."""
        with caplog.at_level(logging.WARNING):
            dataset.set_infos(new_infos, **set_kwargs)
        return caplog.text

    # Initial call: expected to stay silent (no user infos to replace yet).
    assert replacement_warning not in log_text_after_set(infos)

    # Replacing infos with the default ``warn=True`` must log the warning.
    caplog.clear()
    assert replacement_warning in log_text_after_set({"legal": {"owner": "Owner2"}})

    # Replacing infos with ``warn=False`` must stay silent.
    caplog.clear()
    assert replacement_warning not in log_text_after_set(
        {"legal": {"owner": "Owner3"}}, warn=False
    )

def test_get_infos(self, dataset):
assert dataset.get_infos()["plaid"]["version"] == str(
Version(plaid.__version__)
Expand Down
Loading