Upd notebook (#121)
* feat: gnn inference and documentation

* bug: inference

---------

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
anna-grim and anna-grim authored Apr 14, 2024
1 parent 8b113ca commit 8690b36
Showing 4 changed files with 59 additions and 34 deletions.
6 changes: 5 additions & 1 deletion src/deep_neurographs/machine_learning/graph_models.py
@@ -42,12 +42,16 @@ class MLP(torch.nn.Module):
     def __init__(self, input_channels):
         super().__init__()
         self.linear1 = Linear(input_channels, input_channels // 2)
-        self.linear2 = Linear(input_channels // 2, 1)
+        self.linear2 = Linear(input_channels // 2, input_channels // 2)
+        self.linear3 = Linear(input_channels // 2, 1)
         self.ELU = ELU()
 
     def forward(self, x, edge_index):
         x = self.linear1(x)
         x = self.ELU(x)
         x = F.dropout(x, p=0.25)
         x = self.linear2(x)
+        x = self.ELU(x)
+        x = F.dropout(x, p=0.25)
+        x = self.linear3(x)
         return x
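
For reference, a minimal sketch of exercising the updated three-layer MLP head; the tensor sizes and the None edge_index are illustrative only, while the class name, constructor, and module path come from the diff above.

import torch

from deep_neurographs.machine_learning.graph_models import MLP

# 16 proposals with 32-dimensional feature vectors (illustrative sizes).
x = torch.randn(16, 32)
model = MLP(input_channels=32)

# MLP.forward accepts edge_index but does not use it, so None suffices here.
logits = model(x, edge_index=None)
print(logits.shape)  # torch.Size([16, 1])
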
47 changes: 22 additions & 25 deletions src/deep_neurographs/machine_learning/inference.py
@@ -131,7 +131,7 @@ def run_without_seeds(
 
         # Report progress
         if i > progress_cnt * chunk_size and progress_bar:
-            progress_cnt, t1 = report_progress(
+            progress_cnt, t1 = utils.report_progress(
                 i, n_batches, chunk_size, progress_cnt, t0, t1
             )
             t0, t1 = utils.init_timers()
@@ -270,44 +270,41 @@ def run_model(dataset, model, model_type):
             hat_y_i = np.array(hat_y_i)
             hat_y.extend(hat_y_i.tolist())
     else:
         data = dataset["dataset"]
         hat_y = model.predict_proba(data["inputs"])[:, 1]
     return np.array(hat_y)
 
 
 def run_graph_model(graph_data, model):
     # Run model
     model.eval()
     x, edge_index = toGPU(graph_data.data)
-    hat_y = model(x, edge_index)
+    with torch.no_grad():
+        hat_y = sigmoid(model(x, edge_index))
 
     # Reformat pred
     idx = graph_data.n_proposals
     hat_y = ml_utils.toCPU(hat_y[0:idx, 0])
-    return ml_utils.sigmoid(hat_y)
-
-
-# Utils
-def report_progress(current, total, chunk_size, cnt, t0, t1):
-    eta = get_eta(current, total, chunk_size, t1)
-    runtime = get_runtime(current, total, chunk_size, t0, t1)
-    utils.progress_bar(current, total, eta=eta, runtime=runtime)
-    return cnt + 1, time()
-
-
-def get_eta(current, total, chunk_size, t0, return_str=True):
-    chunk_runtime = time() - t0
-    remaining = total - current
-    eta = remaining * (chunk_runtime / chunk_size)
-    t, unit = utils.time_writer(eta)
-    return f"{round(t, 4)} {unit}" if return_str else eta
-
-
-def get_runtime(current, total, chunk_size, t0, t1):
-    eta = get_eta(current, total, chunk_size, t1, return_str=False)
-    total_runtime = time() - t0 + eta
-    t, unit = utils.time_writer(total_runtime)
-    return f"{round(t, 4)} {unit}"
+    return hat_y
+
+
+def toGPU(graph_data):
+    """
+    Moves "graph_data" from CPU to GPU.
+    Parameters
+    ----------
+    graph_data : GraphDataset
+        Dataset to be moved to GPU.
+    Returns
+    -------
+    x : torch.Tensor
+        Matrix of node feature vectors.
+    edge_idx : torch.Tensor
+        Tensor containing edges in graph.
+    """
+    x = graph_data.x.to("cuda:0", dtype=torch.float32)
+    edge_index = graph_data.edge_index.to("cuda:0")
+    return x, edge_index
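
A hedged sketch of the inference pattern run_graph_model now follows: forward pass with gradient tracking disabled, sigmoid on the logits, then only the first n_proposals rows are kept. The standalone function name and the .cpu().numpy() call are illustrative stand-ins for the repository's ml_utils.toCPU.

import torch

def predict_proposal_probs(model, x, edge_index, n_proposals):
    # Forward pass with gradients disabled, as in run_graph_model.
    model.eval()
    with torch.no_grad():
        hat_y = torch.sigmoid(model(x, edge_index))
    # Keep only the rows corresponding to proposals and move them off the GPU.
    return hat_y[:n_proposals, 0].cpu().numpy()
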
14 changes: 6 additions & 8 deletions src/deep_neurographs/reconstruction.py
@@ -31,14 +31,6 @@ def get_accepted_propoals_blocks(
     accepts = dict()
     for block_id in blocks:
         # Get accepts
-        preds = threshold_preds(
-            preds,
-            idx_to_edge,
-            low_threshold,
-            valid_idxs=block_to_idxs[block_id],
-        )
-
-        # Refine accepts wrt structure
         if structure_aware:
             graph = neurographs[block_id].copy()
             accepts[block_id] = get_structure_aware_accepts(
@@ -49,6 +41,12 @@
                 low_threshold=low_threshold,
             )
         else:
+            preds = threshold_preds(
+                preds,
+                idx_to_edge,
+                low_threshold,
+                valid_idxs=block_to_idxs[block_id],
+            )
             accepts[block_id] = preds.keys()
     return accepts
 
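
In the updated function, threshold_preds is called only in the non-structure-aware branch. As a purely hypothetical illustration of the kind of filtering a call like threshold_preds(preds, idx_to_edge, low_threshold, valid_idxs=...) suggests (the body below is invented, not the repository's implementation):

def threshold_preds_sketch(preds, idx_to_edge, low_threshold, valid_idxs):
    # Hypothetical: keep predictions over the threshold for valid proposal indices.
    return {
        idx_to_edge[i]: p
        for i, p in preds.items()
        if i in valid_idxs and p > low_threshold
    }
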
26 changes: 26 additions & 0 deletions src/deep_neurographs/utils.py
@@ -620,6 +620,32 @@ def time_writer(t, unit="seconds"):
         t, unit = time_writer(t, unit=unit)
     return t, unit
 
+
+def report_progress(current, total, chunk_size, cnt, t0, t1):
+    eta = get_eta(current, total, chunk_size, t1)
+    runtime = get_runtime(current, total, chunk_size, t0, t1)
+    utils.progress_bar(current, total, eta=eta, runtime=runtime)
+    return cnt + 1, time()
+
+
+def get_eta(current, total, chunk_size, t0, return_str=True):
+    chunk_runtime = time() - t0
+    remaining = total - current
+    eta = remaining * (chunk_runtime / chunk_size)
+    t, unit = utils.time_writer(eta)
+    return f"{round(t, 4)} {unit}" if return_str else eta
+
+
+def get_runtime(current, total, chunk_size, t0, t1):
+    eta = get_eta(current, total, chunk_size, t1, return_str=False)
+    total_runtime = time() - t0 + eta
+    t, unit = utils.time_writer(total_runtime)
+    return f"{round(t, 4)} {unit}"
+
+
+def toGPU(graph_data):
+    x = graph_data.x.to("cuda:0", dtype=torch.float32)
+    edge_index = graph_data.edge_index.to("cuda:0")
+    return x, edge_index
+
 
 # --- miscellaneous ---
 def get_img_bbox(origin, shape):
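
The ETA arithmetic these helpers implement: if the most recent chunk of chunk_size items took chunk_runtime seconds, the remaining items are expected to take remaining * (chunk_runtime / chunk_size) seconds, and get_runtime adds that estimate to the time already elapsed. A self-contained restatement of the formula (the function name is invented):

from time import time

def eta_seconds(current, total, chunk_size, t_chunk_start):
    # Per-item rate from the most recent chunk, scaled by the items left.
    chunk_runtime = time() - t_chunk_start
    return (total - current) * (chunk_runtime / chunk_size)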
