From 3321f64e0a78c6c92a466f04a0f0428faeb6699c Mon Sep 17 00:00:00 2001 From: narumi Date: Sat, 9 Nov 2024 16:37:08 +0800 Subject: [PATCH] fix flatten sep --- src/mlconfig/conf.py | 2 +- tests/test_conf.py | 59 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/mlconfig/conf.py b/src/mlconfig/conf.py index db52fd7..400ee33 100644 --- a/src/mlconfig/conf.py +++ b/src/mlconfig/conf.py @@ -100,7 +100,7 @@ def flatten(data: dict[str, Any], prefix: str | None = None, sep: str = ".") -> key = prefix + sep + key if isinstance(value, dict): - d.update(flatten(value, prefix=key)) + d.update(flatten(value, prefix=key, sep=sep)) continue d[key] = value diff --git a/tests/test_conf.py b/tests/test_conf.py index 28000e6..486f51f 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -69,3 +69,62 @@ def test_getcls(conf): ) def test_flatten(test_input: dict, expected: dict) -> None: assert flatten(test_input) == expected + + +def test_getcls_valid(conf): + assert getcls(conf["a"]) == Point + assert getcls(conf["op"]) == add + + +def test_getcls_invalid_key_type(conf): + _key = "name" + conf_invalid = conf.copy() + conf_invalid["a"][_key] = 123 # Invalid key type + with pytest.raises(ValueError, match="key 123 must be a string"): + getcls(conf_invalid["a"]) + + +def test_getcls_key_not_found(conf): + _key = "name" + conf_invalid = conf.copy() + conf_invalid["a"][_key] = "NonExistentKey" # Key not in registry + with pytest.raises(ValueError, match="key NonExistentKey not found in registry"): + getcls(conf_invalid["a"]) + + +def test_register_duplicate(): + with pytest.raises(ValueError, match="duplicate name Point found"): + @register(name="Point") + class PointDuplicate: + def __init__(self, x, y): + self.x = x + self.y = y + +def test_instantiate_with_kwargs(conf, obj): + a = instantiate(conf.a, y=10) + assert a.y == 10 + assert a.x == obj["x1"] + +def test_instantiate_with_args(conf, obj): + conf_with_args = {"name": "Point"} + a = instantiate(conf_with_args, 5, 6) + assert a.x == 5 + assert a.y == 6 + +def test_flatten_with_empty_dict(): + assert flatten({}) == {} + +def test_flatten_with_nested_dict(): + nested_dict = {"a": {"b": {"c": "d"}}} + expected = {"a.b.c": "d"} + assert flatten(nested_dict) == expected + +def test_flatten_with_prefix(): + nested_dict = {"a": {"b": "c"}} + expected = {"prefix.a.b": "c"} + assert flatten(nested_dict, prefix="prefix") == expected + +def test_flatten_with_custom_separator(): + nested_dict = {"a": {"b": "c"}} + expected = {"a/b": "c"} + assert flatten(nested_dict, sep="/") == expected