Skip to content

Commit

Permalink
Merge branch 'main' into docs
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Feb 14, 2025
2 parents 02d0708 + 39df826 commit 8e68eb3
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 10 deletions.
51 changes: 49 additions & 2 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,52 @@ def test_list_groups(proj_path, runner):
assert result.exit_code == 0


if __name__ == "__main__":
test_list_groups(None, None)
def test_list_multi_nested_groups(proj_path, runner):
proj = zntrack.Project()

with proj:
zntrack.examples.ParamsToOuts(params=15)
zntrack.examples.ParamsToOuts(params=15)

with proj.group("dynamics"):
zntrack.examples.ParamsToOuts(params=15)
zntrack.examples.ParamsToOuts(params=15)

with proj.group("dynamics", "400K"):
zntrack.examples.ParamsToOuts(params=15)
zntrack.examples.ParamsToOuts(params=15)

with proj.group("dynamics", "400K", "B"):
zntrack.examples.ParamsToOuts(params=15)
zntrack.examples.ParamsToOuts(params=15)

proj.build()

true_groups = {
"dynamics": [
{
"400K": [
{
"B": [
"ParamsToOuts -> dynamics_400K_B_ParamsToOuts",
"ParamsToOuts_1 -> dynamics_400K_B_ParamsToOuts_1",
]
},
"ParamsToOuts -> dynamics_400K_ParamsToOuts",
"ParamsToOuts_1 -> dynamics_400K_ParamsToOuts_1",
]
},
"ParamsToOuts -> dynamics_ParamsToOuts",
"ParamsToOuts_1 -> dynamics_ParamsToOuts_1",
],
"nodes": ["ParamsToOuts", "ParamsToOuts_1"],
}

groups, _ = utils.cli.get_groups(remote=proj_path, rev=None)
assert groups == true_groups

result = runner.invoke(app, ["list", proj_path.as_posix()])
assert result.exit_code == 0
groups = yaml.safe_load(result.stdout)
assert groups == true_groups
assert result.exit_code == 0
1 change: 1 addition & 0 deletions tests/unit_tests/test_node_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def test_grouped_duplicate_named_node(proj_path):
with project.group("grp1"):
MyNode(name="A")


@pytest.mark.parametrize("char", ["@:", "#", "$", ":", "/", "\\", ".", ";", ","])
def test_forbidden_node_names(proj_path, char):
"""Test that nodes with forbidden names cannot be created"""
Expand Down
2 changes: 2 additions & 0 deletions zntrack/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

log = logging.getLogger(__name__)


def _name_setter(self, attr_name: str, value: str) -> None:
"""Check if the node name is valid."""

Expand All @@ -43,6 +44,7 @@ def _name_setter(self, attr_name: str, value: str) -> None:

self.__dict__[attr_name] = value


def _name_getter(self, attr_name: str) -> str:
"""Retrieve the name of a node based on the current graph context.
Expand Down
29 changes: 21 additions & 8 deletions zntrack/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_groups(remote, rev) -> Tuple[dict, list]:
Returns:
-------
groups : dict
a nested dictionary with the group names as keys and the nodes in each group as
A nested dictionary with the group names as keys and the nodes in each group as
values. Contains "short-name -> long-name" if inside a group.
node_names: list
A list of all node names in the project.
Expand All @@ -34,19 +34,31 @@ def get_groups(remote, rev) -> Tuple[dict, list]:
node_names = []

def add_to_group(groups, grp_names, node_name):
"""Recursively add node_name into the correct nested group structure."""
if not grp_names:
return

current_group = grp_names[0]

# If this is the last level, add the node directly
if len(grp_names) == 1:
if grp_names[0] not in groups:
groups[grp_names[0]] = []
groups[grp_names[0]].append(node_name)
if current_group not in groups:
groups[current_group] = []
groups[current_group].append(node_name)
else:
if grp_names[0] not in groups:
groups[grp_names[0]] = [{}]
add_to_group(groups[grp_names[0]][0], grp_names[1:], node_name)
# Ensure the current group contains a dictionary inside a list
if current_group not in groups:
groups[current_group] = [{}]
elif not isinstance(groups[current_group][0], dict):
groups[current_group].insert(0, {})

add_to_group(groups[current_group][0], grp_names[1:], node_name)

for node_name, node_config in config.items():
nwd = pathlib.Path(node_config["nwd"]["value"])
grp_names = nwd.parent.as_posix().split("/")[1:]
if len(grp_names) == 0:

if not grp_names:
node_names.append(node_name)
grp_names = ["nodes"]
else:
Expand All @@ -55,6 +67,7 @@ def add_to_group(groups, grp_names, node_name):

node_names.append(f"{'_'.join(grp_names)}_{node_name}")
node_name = f"{node_name} -> {node_names[-1]}"

add_to_group(true_groups, grp_names, node_name)

return true_groups, node_names

0 comments on commit 8e68eb3

Please sign in to comment.