Skip to content

Commit

Permalink
Bug circular import (#280)
Browse files Browse the repository at this point in the history
* bug: circular arg fixed, class rename

* refactor: simplified evaluation and editted documentation

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
  • Loading branch information
anna-grim and anna-grim authored Nov 13, 2024
1 parent 3cec969 commit 8637c2e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 126 deletions.
157 changes: 54 additions & 103 deletions src/deep_neurographs/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
@author: Anna Grim
@email: anna.grim@alleninstitute.org
Evaluates performance of edge classifiation model.
Evaluates performance of proposal classifiation model.
"""

Expand All @@ -19,95 +19,63 @@
]


def init_stats():
def run_evaluation(fragments_graph, proposals, accepts):
"""
Initializes a dictionary that stores stats computes by routines in this
module.
Parameters
----------
None
Returns
-------
dict
Dictionary that stores stats computes by routines in this module.
"""
return dict([(metric, []) for metric in METRICS_LIST])


def run_evaluation(neurograph, accepts, proposals):
"""
Runs an evaluation on the accuracy of the predictions generated by an edge
Evaluates the accuracy of predictions made by a proposal
classication model.
Parameters
----------
neurographs : list[NeuroGraph]
Predicted neurographs.
accepts : list
fragments_graphs : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
proposals : list[frozenset]
Proposals classified by model.
accepts : list[frozenset]
Accepted proposals.
proposals : list
Proposals that were classified as either accept or reject.
Returns
-------
dict
Dictionary that stores the accuracy of the edge classification model
on all edges (i.e. "Overall"), simple edges, and complex edges. The
metrics contained in this dictionary are identical to "METRICS_LIST"].
Dictionary that stores statistics calculated for all proposal
predictions and separately for simple and complex proposals, as
specified in "METRICS_LIST".
"""
# Initializations
stats = {
"Overall": init_stats(),
"Simple": init_stats(),
"Complex": init_stats(),
}
stats = dict()
simple_proposals = fragments_graph.simple_proposals()
complex_proposals = fragments_graph.complex_proposals()

# Evaluation
overall_stats = get_stats(neurograph, proposals, accepts)

simple_stats = get_stats(
neurograph, neurograph.simple_proposals(), accepts
)

complex_stats = get_stats(
neurograph, neurograph.complex_proposals(), accepts
)

# Store results
for metric in METRICS_LIST:
stats["Overall"][metric].append(overall_stats[metric])
stats["Simple"][metric].append(simple_stats[metric])
stats["Complex"][metric].append(complex_stats[metric])
stats["Overall"] = get_stats(fragments_graph, proposals, accepts)
stats["Simple"] = get_stats(fragments_graph, simple_proposals, accepts)
stats["Complex"] = get_stats(fragments_graph, complex_proposals, accepts)
return stats


def get_stats(neurograph, proposals, accepts):
def get_stats(fragments_graph, proposals, accepts):
"""
Accuracy of the predictions generated by an edge classication model on a
given block and "edge_type" (e.g. overall, simple, or complex).
Computes statistics that reflect the accuracy of the predictions made by
a proposal classication model.
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
fragments_graph : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
proposals : list[frozenset]
List of proposals of a specified "proposal_type".
accepts : numpy.ndarray
Binary predictions of edges generated by classifcation model.
Accepted proposals.
Returns
-------
dict
Results of evaluation where the keys are identical to "METRICS_LIST".
"""
n_pos = len([e for e in proposals if e in neurograph.target_edges])
n_pos = len([e for e in proposals if e in fragments_graph.gt_accepts])
a_baseline = n_pos / (len(proposals) if len(proposals) > 0 else 1)
tp, fp, a, p, r, f1 = get_accuracy(neurograph, proposals, accepts)
tp, fp, a, p, r, f1 = get_accuracy(fragments_graph, proposals, accepts)
stats = {
"# splits fixed": tp,
"# merges created": fp,
Expand All @@ -120,80 +88,63 @@ def get_stats(neurograph, proposals, accepts):
return stats


def get_accuracy(neurograph, proposals, accepts):
def get_accuracy(fragments_graph, proposals, accepts):
"""
Computes the following metrics for a given set of predicted edges:
(1) true positives, (2) false positive, (3) precision, (4) recall, and
(5) f1-score.
Computes the following metrics for a given set of predicted proposals:
(1) true positives, (2) false positive, (3) accuracy, (4) precision,
(5) recall, and (6) f1-score.
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
fragments_graph : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
List of proposals of a specified "proposal_type".
accepts : list
Accepted proposals.
Returns
-------
float
Number of true positives.
float
Number of false positives.
float
Precision.
float
Recall.
float
F1-score.
float, float, float, float, float, float
Number true positives, number of false positives, accuracy, precision,
recall, and F1-score.
"""
tp, tn, fp, fn = get_accuracy_counts(neurograph, proposals, accepts)
tp, tn, fp, fn = get_detection_cnts(fragments_graph, proposals, accepts)
a = (tp + tn) / len(proposals) if len(proposals) else 1
p = 1 if tp + fp == 0 else tp / (tp + fp)
r = 1 if tp + fn == 0 else tp / (tp + fn)
f1 = (2 * r * p) / max(r + p, 1e-3)
return tp, fp, a, p, r, f1


def get_accuracy_counts(neurograph, proposals, accepts):
def get_detection_cnts(fragments_graph, proposals, accepts):
"""
Computes the following values: (1) true positives, (2) false positive, and
(3) false negatives.
Computes the following values: (1) true positives, (2) true negatives,
(3) false positive, and (4) false negatives.
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
fragments_graph : FragmentsGraph
Graph generated from fragments of a predicted segmentation.
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
List of proposals of a specified "proposal_type".
accepts : list
Accepted proposals.
Returns
-------
float
Number of true positives.
float
Number of false positives.
float
Number of false negatives.
float, float, float, float
Number of true positives, true negatives, false positives, and false
negatives.
"""
tp = 0
tn = 0
fp = 0
fn = 0
for edge in proposals:
if edge in neurograph.target_edges:
if edge in accepts:
tp += 1
else:
fn += 1
tp, tn, fp, fn = 0, 0, 0, 0
for p in proposals:
if p in fragments_graph.gt_accepts:
tp += 1 if p in accepts else 0
fn += 1 if p not in accepts else 0
else:
if edge in accepts:
fp += 1
else:
tn += 1
fp += 1 if p in accepts else 0
tn += 1 if p not in accepts else 0
return tp, tn, fp, fn
54 changes: 38 additions & 16 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def save_to_s3(self):
local_path = os.path.join(self.output_dir, filename)
s3_path = os.path.join(self.s3_dict["prefix"], filename)
util.write_to_s3(local_path, bucket_name, s3_path)
print("Results written to S3 prefix -->", prefix)
print("Results written to S3 prefix -->", self.s3_dict["prefix"])

# --- io ---
def save_connections(self, round_id=None):
Expand Down Expand Up @@ -529,9 +529,8 @@ def __init__(

# Model
self.model = ml_util.load_model(model_path)
if self.is_gnn:
if self.is_gnn and "cuda" in device:
self.model = self.model.to(self.device)
self.model.eval()

def run(self, neurograph, proposals):
"""
Expand Down Expand Up @@ -638,9 +637,9 @@ def predict(self, dataset):
Parameters
----------
data : ...
Dataset on which the model inference is to be run.
data : HeteroGeneousDataset
Dataset containing graph information, including feature matrices
and other relevant attributes needed for GNN input.
Returns
-------
dict
Expand All @@ -650,16 +649,7 @@ def predict(self, dataset):
"""
# Get predictions
if self.is_gnn:
with torch.no_grad():
# Get inputs
n = len(dataset.data["proposal"]["y"])
x, edge_index, edge_attr = gnn_util.get_inputs(
dataset.data, device=self.device
)

# Run model
preds = sigmoid(self.model(x, edge_index, edge_attr))
preds = toCPU(preds[0:n, 0])
preds = predict_with_gnn(self.model, dataset.data, self.device)
else:
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

Expand All @@ -669,6 +659,38 @@ def predict(self, dataset):


# --- Accepting Proposals ---
def predict_with_gnn(model, data, device=None):
"""
Generates predictions using a Graph Neural Network (GNN) on the given
dataset.
Parameters:
----------
model : torch.nn.Module
GNN model used to generate predictions. It should accept node
features, edge indices, and edge attributes as input and output
predictions.
data : dict
Dataset containing graph information, including feature matrices
and other relevant attributes needed for GNN input.
device : str, optional
The device (CPU or GPU) on which the prediction will be run. The
default is None.
Returns:
-------
torch.Tensor
A tensor of predictions, converted to CPU, for the 'proposal' entries
in the dataset. Only the relevant predictions for 'proposal' nodes are
returned.
"""
with torch.no_grad():
x, edge_index, edge_attr = gnn_util.get_inputs(data, device)
preds = sigmoid(model(x, edge_index, edge_attr))
return toCPU(preds[0:len(data["proposal"]["y"]), 0])


def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
"""
Determines which proposals to accept based on prediction scores and the
Expand Down
19 changes: 12 additions & 7 deletions src/deep_neurographs/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def load_model(path):
...
"""
return joblib.load(path) if ".joblib" in path else torch.load(path)
if ".joblib" in path:
model = joblib.load(path)
else:
model = torch.load(path)
model.eval()
return model


def save_model(path, model, model_type):
Expand Down Expand Up @@ -62,22 +67,22 @@ def save_model(path, model, model_type):

# --- dataset utils ---
def init_dataset(
neurograph,
fragments_graph,
features,
is_gnn=True,
is_multimodal=False,
computation_graph=None
):
"""
Initializes a dataset given features generated from some set of proposals
and neurograph.
and fragments_graph.
Parameters
----------
neurograph : NeuroGraph
fragments_graph : FragmentsGraph
Graph that "features" were generated from.
features : dict
Feaures generated from some set of proposals and "neurograph".
Feaures generated from some set of proposals and "fragments_graph".
model_type : str
Type of machine learning model used to perform inference.
computation_graph : networkx.Graph, optional
Expand All @@ -93,10 +98,10 @@ def init_dataset(
if is_gnn:
assert computation_graph is not None, "Must input computation graph!"
dataset = heterograph_datasets.init(
neurograph, features, computation_graph
fragments_graph, features, computation_graph
)
else:
dataset = datasets.init(neurograph, features)
dataset = datasets.init(fragments_graph, features)
return dataset


Expand Down

0 comments on commit 8637c2e

Please sign in to comment.