Skip to content

Commit

Permalink
Merge pull request #131 from narumiruna/tests
Browse files Browse the repository at this point in the history
FIX: fix flatten spe
  • Loading branch information
narumiruna authored Nov 9, 2024
2 parents 586e5ac + 3321f64 commit e41d472
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/mlconfig/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def flatten(data: dict[str, Any], prefix: str | None = None, sep: str = ".") ->
key = prefix + sep + key

if isinstance(value, dict):
d.update(flatten(value, prefix=key))
d.update(flatten(value, prefix=key, sep=sep))
continue

d[key] = value
Expand Down
59 changes: 59 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,62 @@ def test_getcls(conf):
)
def test_flatten(test_input: dict, expected: dict) -> None:
assert flatten(test_input) == expected


def test_getcls_valid(conf):
assert getcls(conf["a"]) == Point
assert getcls(conf["op"]) == add


def test_getcls_invalid_key_type(conf):
_key = "name"
conf_invalid = conf.copy()
conf_invalid["a"][_key] = 123 # Invalid key type
with pytest.raises(ValueError, match="key 123 must be a string"):
getcls(conf_invalid["a"])


def test_getcls_key_not_found(conf):
_key = "name"
conf_invalid = conf.copy()
conf_invalid["a"][_key] = "NonExistentKey" # Key not in registry
with pytest.raises(ValueError, match="key NonExistentKey not found in registry"):
getcls(conf_invalid["a"])


def test_register_duplicate():
with pytest.raises(ValueError, match="duplicate name Point found"):
@register(name="Point")
class PointDuplicate:
def __init__(self, x, y):
self.x = x
self.y = y

def test_instantiate_with_kwargs(conf, obj):
a = instantiate(conf.a, y=10)
assert a.y == 10
assert a.x == obj["x1"]

def test_instantiate_with_args(conf, obj):
conf_with_args = {"name": "Point"}
a = instantiate(conf_with_args, 5, 6)
assert a.x == 5
assert a.y == 6

def test_flatten_with_empty_dict():
assert flatten({}) == {}

def test_flatten_with_nested_dict():
nested_dict = {"a": {"b": {"c": "d"}}}
expected = {"a.b.c": "d"}
assert flatten(nested_dict) == expected

def test_flatten_with_prefix():
nested_dict = {"a": {"b": "c"}}
expected = {"prefix.a.b": "c"}
assert flatten(nested_dict, prefix="prefix") == expected

def test_flatten_with_custom_separator():
nested_dict = {"a": {"b": "c"}}
expected = {"a/b": "c"}
assert flatten(nested_dict, sep="/") == expected

0 comments on commit e41d472

Please sign in to comment.