Skip to content

Commit

Permalink
fix: build fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
MLRichter committed Feb 9, 2022
1 parent 1f901f7 commit 4f9bf5e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 26 deletions.
28 changes: 2 additions & 26 deletions rfa_toolbox/utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,17 @@
from rfa_toolbox.graphs import EnrichedNetworkNode


def _remove_duplicates(nodes: List[EnrichedNetworkNode]) -> List[EnrichedNetworkNode]:
result = []
for node in nodes:
if node.is_in(result):
continue
else:
result.append(node)
return result


def obtain_all_nodes(
output_node: EnrichedNetworkNode, search_from_output: bool = False
) -> List[EnrichedNetworkNode]:
def obtain_all_nodes(output_node: EnrichedNetworkNode) -> List[EnrichedNetworkNode]:
"""Fetch all nodes from a single node of the compute graph.
Args:
output_node: output node of the graph
search_from_output: False by default. If True,
the nodes will be searched
using the BFS-Algorithm. If False,
the internal registry of the node will be used,
which may be dangerous if more than one
input-node exists.
Returns:
A List containing all EnrichedNetworkNodes.
"""
if search_from_output:
all_nodes = [output_node]
for pred in output_node.predecessors:
all_nodes.extend(obtain_all_nodes(pred, False))
return _remove_duplicates(all_nodes)
else:
return output_node.all_layers
return output_node.all_layers


def obtain_border_layers(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ def test_with_higher_degree_tensor(self, sequential_network):
assert len(r_max) == cardinality
assert len(r_min) == cardinality

def test_with_scalar_receptive_field_sizes_lower_bound(self, sequential_network):
r_min, r_max = input_resolution_range(sequential_network, lower_bound=True)
assert len(r_max) == 2
assert len(r_min) == 2
assert r_min == (11, 11)
assert r_max == (13, 13)

def test_with_scalar_receptive_field_sizes(self, sequential_network):
r_min, r_max = input_resolution_range(sequential_network)
assert len(r_max) == 2
Expand Down

0 comments on commit 4f9bf5e

Please sign in to comment.