Skip to content

Commit

Permalink
Merge pull request #47 from MLRichter/feature/rf_filtering
Browse files Browse the repository at this point in the history
Feature/rf filtering
  • Loading branch information
MLRichter authored Feb 9, 2022
2 parents 9038925 + 4f9bf5e commit 53c2886
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 30 deletions.
43 changes: 13 additions & 30 deletions rfa_toolbox/utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,17 @@
from rfa_toolbox.graphs import EnrichedNetworkNode


def _remove_duplicates(nodes: List[EnrichedNetworkNode]) -> List[EnrichedNetworkNode]:
result = []
for node in nodes:
if node.is_in(result):
continue
else:
result.append(node)
return result


def obtain_all_nodes(
output_node: EnrichedNetworkNode, search_from_output: bool = False
) -> List[EnrichedNetworkNode]:
def obtain_all_nodes(output_node: EnrichedNetworkNode) -> List[EnrichedNetworkNode]:
"""Fetch all nodes from a single node of the compute graph.
Args:
output_node: output node of the graph
search_from_output: False by default. If True,
the nodes will be searched
using the BFS-Algorithm. If False,
the internal registry of the node will be used,
which may be dangerous if more than one
input-node exists.
Returns:
A List containing all EnrichedNetworkNodes.
"""
if search_from_output:
all_nodes = [output_node]
for pred in output_node.predecessors:
all_nodes.extend(obtain_all_nodes(pred, False))
return _remove_duplicates(all_nodes)
else:
return output_node.all_layers
return output_node.all_layers


def obtain_border_layers(
Expand Down Expand Up @@ -147,9 +123,10 @@ def filters_non_infinite_rf_sizes(

def input_resolution_range(
graph: EnrichedNetworkNode,
filter_all_inf_rf: bool = False,
filter_all_inf_rf: bool = True,
filter_kernel_size_1: bool = False,
cardinality: int = 2,
lower_bound: bool = False,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""Obtain the smallest and largest feasible input resolution.
The smallest feasible input resolution is defined as the input smallest input
Expand All @@ -164,7 +141,7 @@ def input_resolution_range(
a to small receptive field size.
Args:
graph: The neural network
graph: The neural network
filter_all_inf_rf: filters ALL infinite receptive field sizes before
computing the result, this may come in handy
if you want to ignore the influence
Expand All @@ -176,7 +153,12 @@ def input_resolution_range(
misleading, since these types of modules are not realy
build to extract features from the image.
This functionality is disabled by default.
cardinality: The tensor shape, which is 2D by default.
cardinality: The tensor shape, which is 2D by default.
lower_bound: Disabled by default. If disabled, returns the lowest
resolution which utilizes the entire
network receptive field expansion.
If enabled it returns the lowest resolution exspected
to yield no unproductive, weighted layers.
Returns:
Smallest and largest feasible input resolution.
Expand Down Expand Up @@ -223,9 +205,10 @@ def find_max(
if not second_largest:
return max(rf_no_tuples)
else:
rf_no_tuples = set(rf_no_tuples)
rf_no_tuples.remove(max(rf_no_tuples))
return max(rf_no_tuples)

r_max = tuple(find_max(rf_max, i) for i in range(cardinality))
r_min = tuple(find_max(rf_min, i) for i in range(cardinality))
r_min = tuple(find_max(rf_min, i, lower_bound) for i in range(cardinality))
return r_min, r_max
7 changes: 7 additions & 0 deletions tests/test_graph/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ def test_with_higher_degree_tensor(self, sequential_network):
assert len(r_max) == cardinality
assert len(r_min) == cardinality

def test_with_scalar_receptive_field_sizes_lower_bound(self, sequential_network):
r_min, r_max = input_resolution_range(sequential_network, lower_bound=True)
assert len(r_max) == 2
assert len(r_min) == 2
assert r_min == (11, 11)
assert r_max == (13, 13)

def test_with_scalar_receptive_field_sizes(self, sequential_network):
r_min, r_max = input_resolution_range(sequential_network)
assert len(r_max) == 2
Expand Down

0 comments on commit 53c2886

Please sign in to comment.