1
1
import math
2
2
import time
3
3
from typing import Optional , Union , Sequence
4
- from collections import defaultdict
5
4
import networkx as nx
6
5
from ..metrics import metrics
7
6
8
7
# Get all the functions in the metrics module
9
- _metric_functions = [func for func in dir (metrics ) if callable (getattr (metrics , func )) and not func .startswith ("__" )]
8
+ _metric_functions = [
9
+ func
10
+ for func in dir (metrics )
11
+ if callable (getattr (metrics , func )) and not func .startswith ("__" )
12
+ ]
10
13
11
14
# Generate the DEFAULT_WEIGHTS dictionary
12
15
DEFAULT_WEIGHTS = {func : 1 for func in _metric_functions }
13
16
14
17
# Generate the METRICS dictionary
15
18
METRICS = {func : {"func" : getattr (metrics , func )} for func in _metric_functions }
16
19
20
+
17
21
class MetricsSuite :
18
22
"""A suite for calculating several metrics for graph drawing aesthetics, as well as methods for combining these into a single cost function.
19
23
Takes as an argument a path to a GML or GraphML file, or a NetworkX Graph object. Also takes as an argument a dictionary of metric:weight key/values.
@@ -38,7 +42,7 @@ def __init__(
38
42
# Dictionary mapping metric names to their functions, values, and weights
39
43
self .metrics = METRICS .copy ()
40
44
for k in self .metrics .keys ():
41
- self .metrics [k ].update ({"weight" :0 , "value" : None , "is_calculated" : False })
45
+ self .metrics [k ].update ({"weight" : 0 , "value" : None , "is_calculated" : False })
42
46
43
47
# Check all metrics given are valid and assign weights
44
48
self .initial_weights = self .set_weights (metric_weights )
@@ -62,19 +66,21 @@ def __init__(
62
66
raise TypeError (
63
67
f"'graph' must be a string representing a path to a GML or GraphML file, or a NetworkX Graph object, not { type (graph )} "
64
68
)
65
-
69
+
66
70
if sym_tolerance < 0 :
67
- raise ValueError (f "sym_tolerance must be positive." )
71
+ raise ValueError ("sym_tolerance must be positive." )
68
72
69
73
self .sym_tolerance = sym_tolerance
70
74
71
75
if sym_threshold < 0 :
72
- raise ValueError (f "sym_threshold must be positive." )
76
+ raise ValueError ("sym_threshold must be positive." )
73
77
74
78
self .sym_threshold = sym_threshold
75
79
76
80
def set_weights (self , metric_weights : Sequence [float ]):
77
- metrics_to_remove = [metric for metric , weight in metric_weights .items () if weight <= 0 ]
81
+ metrics_to_remove = [
82
+ metric for metric , weight in metric_weights .items () if weight <= 0
83
+ ]
78
84
79
85
if any (metric_weights [metric ] < 0 for metric in metric_weights ):
80
86
raise ValueError ("Metric weights must be positive." )
@@ -85,8 +91,10 @@ def set_weights(self, metric_weights: Sequence[float]):
85
91
for metric in metric_weights :
86
92
self .metrics [metric ]["weight" ] = metric_weights [metric ]
87
93
88
- return {metric : weight for metric , weight in metric_weights .items () if weight > 0 }
89
-
94
+ return {
95
+ metric : weight for metric , weight in metric_weights .items () if weight > 0
96
+ }
97
+
90
98
def weighted_prod (self ):
91
99
"""Returns the weighted product of all metrics. Should NOT be used as a cost function - may be useful for comparing graphs."""
92
100
return math .prod (
@@ -114,17 +122,19 @@ def load_graph_test(self, nxg=nx.sedgewick_maze_graph):
114
122
115
123
nx .set_node_attributes (G , pos )
116
124
return G
117
-
125
+
118
126
def reset_metrics (self ):
119
127
for metric in self .metrics :
120
128
self .metrics [metric ]["value" ] = None
121
129
self .metrics [metric ]["is_calculated" ] = False
122
-
130
+
123
131
def calculate_metric (self , metric : str = None ):
124
132
"""Calculate the value of the given metric by calling the associated function."""
125
133
if metric is None :
126
- raise ValueError ("No metric provided. Did you mean to call calculate_metrics()?" )
127
-
134
+ raise ValueError (
135
+ "No metric provided. Did you mean to call calculate_metrics()?"
136
+ )
137
+
128
138
if not self .metrics [metric ]["is_calculated" ]:
129
139
self .metrics [metric ]["value" ] = self .metrics [metric ]["func" ](self ._graph )
130
140
self .metrics [metric ]["is_calculated" ] = True
@@ -141,7 +151,9 @@ def calculate_metrics(self):
141
151
self .calculate_metric (metric )
142
152
n_metrics += 1
143
153
end_time = time .perf_counter ()
144
- print (f"Calculated { n_metrics } metrics in { end_time - start_time :0.3f} seconds." )
154
+ print (
155
+ f"Calculated { n_metrics } metrics in { end_time - start_time :0.3f} seconds."
156
+ )
145
157
146
158
def combine_metrics (self ):
147
159
"""Combine several metrics based on the given multiple criteria decision analysis technique."""
@@ -172,4 +184,4 @@ def metric_table(self):
172
184
for k , v in self .metrics .items ():
173
185
metrics [k ] = v ["value" ]
174
186
metrics ["Combined" ] = combined
175
- return metrics
187
+ return metrics
0 commit comments