Skip to content

Commit

Permalink
build: combine whole brain graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Jan 18, 2024
1 parent 10021d4 commit aeffc17
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 82 deletions.
46 changes: 22 additions & 24 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@
Routines that extract the irreducible components of a graph.
--define what an irreducible is
leafs : set
Nodes with degreee 1.
junctions : set
Nodes with degree > 2.
edges : dict
Set of edges connecting nodes in leafs and junctions. The keys are
pairs of nodes connected by an edge and values are a dictionary of
attributes.
--define what a branch is
Terminology
------------
Leaf: a node with degree 1.
Junction: a node with degree > 2.
Irreducibles: the irreducibles of a graph G=(V,E) consists of (1) leaf nodes
V_l, (2) junction nodes, and (3)
junction nodes along
Branch: the sequence of nodes between two
"""

Expand Down Expand Up @@ -58,17 +60,20 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
all irreducibles from the graph of that type.
"""
# Initializations
# Build dense graph
dense_graph = swc_utils.to_graph(swc_dict)
if prune:
dense_graph = prune_short_branches(dense_graph, depth)

# Extract irreducibles
# Extract nodes
leafs, junctions = get_irreducible_nodes(dense_graph, swc_dict)
assert len(leafs) > 0, "Error: swc with no leaf nodes!"
root = None
if len(leafs) == 0:
return False, None

# Extract edges
edges = dict()
nbs = dict()
root = None
for (i, j) in nx.dfs_edges(dense_graph, source=sample(leafs, 1)[0]):
# Check if start of path is valid
if root is None:
Expand All @@ -84,8 +89,8 @@ def get_irreducibles(swc_dict, swc_id=None, prune=True, depth=16, smooth=True):
)
else:
edges[(root, j)] = attrs
nbs = append_value(nbs, root, j)
nbs = append_value(nbs, j, root)
nbs = utils.append_dict_value(nbs, root, j)
nbs = utils.append_dict_value(nbs, j, root)
root = None

# Output
Expand Down Expand Up @@ -196,6 +201,7 @@ def get_leafs(graph):

def __smooth_branch(swc_dict, attrs, edges, nbs, root, j):
attrs["xyz"] = geometry_utils.smooth_branch(np.array(attrs["xyz"]), s=10)
attrs["radius"] = np.array(attrs["radius"])
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0)
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1)
edges[(root, j)] = attrs
Expand All @@ -213,14 +219,6 @@ def upd_xyz(swc_dict, attrs, edges, nbs, i, start_or_end):
return swc_dict, edges


def append_value(my_dict, key, value):
if key in my_dict.keys():
my_dict[key].append(value)
else:
my_dict[key] = [value]
return my_dict


def upd_branch_endpoint(edges, key, old_xyz, new_xyz):
if all(edges[key]["xyz"][0] == old_xyz):
edges[key]["xyz"][0] = new_xyz
Expand Down
31 changes: 4 additions & 27 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,6 @@ def build_neurograph(
cnt, t1 = report_progress(i, n_components, chunk_size, cnt, t0, t1)
i += 1
print("\n" + f"add_irreducibles(): {time() - t0} seconds")

"""
t0 = time()
start_ids = get_start_ids(swc_dicts)
with ThreadPoolExecutor() as executor:
futures = {
executor.submit(
neurograph.add_immutables, irreducibles[key], swc_dicts[key], key, start_ids[key]): key for key in swc_dicts.keys()
}
wait(futures)
print(f" --> asynchronous - add_irreducibles(): {time() - t0} seconds")
"""
return neurograph


Expand Down Expand Up @@ -290,9 +278,10 @@ def get_irreducibles(
progress_cnt = 1
for i, process in enumerate(as_completed(processes)):
process_id, result = process.result()
irreducibles[process_id] = result
n_nodes += len(result["leafs"]) + len(result["junctions"])
n_edges += len(result["edges"])
if process_id:
irreducibles[process_id] = result
n_nodes += len(result["leafs"]) + len(result["junctions"])
n_edges += len(result["edges"])
if i > progress_cnt * chunk_size:
progress_cnt, t1 = report_progress(
i, n_components, chunk_size, progress_cnt, t0, t1
Expand All @@ -302,15 +291,6 @@ def get_irreducibles(
return irreducibles, n_nodes, n_edges


def get_start_ids(swc_dicts):
start_id = 0
start_ids = dict()
for key in swc_dicts.keys():
start_ids[key] = start_id
start_id += len(swc_dicts[key]["id"])
return start_ids


# -- Utils --
def get_paths(swc_dir):
paths = []
Expand All @@ -320,11 +300,8 @@ def get_paths(swc_dir):


def report_progress(current, total, chunk_size, cnt, t0, t1):
# Compute
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)

# Write results
utils.progress_bar(current, total, eta=eta, runtime=runtime)
return cnt + 1, time()

Expand Down
34 changes: 13 additions & 21 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def init_densegraph(self):
self.densegraph = DenseGraph(self.swc_paths)

# --- Add nodes or edges ---
def add_immutables(self, irreducibles, swc_id, start_id=None):
def add_immutables(self, irreducibles, swc_id):
# Nodes
node_ids = dict()
cur_id = start_id if start_id else len(self.nodes)
cur_id = len(self.nodes)
node_ids, cur_id = self.__add_nodes(
irreducibles, "leafs", node_ids, cur_id, swc_id
)
Expand All @@ -92,27 +92,19 @@ def add_immutables(self, irreducibles, swc_id, start_id=None):
)

# Add edges
"""
edges = irreducibles["edges"]
for i, j in edges.keys():
# Get edge
edge = (node_ids[i], node_ids[j])
xyz = np.array(edges[(i, j)]["xyz"])
radii = np.array(edges[(i, j)]["radius"])
# Add edge
self.immutable_edges.add(frozenset(edge))
for edge, values in irreducibles["edges"].items():
i, j = edge
self.immutable_edges.add(frozenset((node_ids[i], node_ids[j])))
self.add_edge(
node_ids[i], node_ids[j], xyz=xyz, radius=radii, swc_id=swc_id
node_ids[i],
node_ids[j],
radius=values["radius"],
xyz=values["xyz"],
swc_id=swc_id
)
xyz_to_edge = dict((tuple(xyz), edge) for xyz in xyz)
check_xyz = set(xyz_to_edge.keys())
collisions = check_xyz.intersection(set(self.xyz_to_edge.keys()))
if len(collisions) > 0:
for xyz in collisions:
del xyz_to_edge[xyz]
self.xyz_to_edge.update(xyz_to_edge)
"""
for xyz in values["xyz"][::2]:
self.xyz_to_edge[tuple(xyz)] = (i, j)
self.xyz_to_edge[tuple(values["xyz"][-1])] = (i, j)

def __add_nodes(self, nodes, key, node_ids, cur_id, swc_id):
for i in nodes[key].keys():
Expand Down
17 changes: 8 additions & 9 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,17 @@ def fast_parse(contents):
contents, offset = get_contents(contents)
min_id = np.inf
swc_dict = {
"id": np.zeros((len(contents)), dtype=int),
"radius": np.zeros((len(contents)), dtype=float),
"pid": np.zeros((len(contents)), dtype=int),
"xyz": [],
"id": np.zeros((len(contents)), dtype=np.int32),
"radius": np.zeros((len(contents)), dtype=np.float32),
"pid": np.zeros((len(contents)), dtype=np.int32),
"xyz": np.zeros((len(contents), 3), dtype=np.int32),
}
for i, line in enumerate(contents):
parts = line.split()
xyz = read_xyz(parts[2:5], offset=offset)
swc_dict["id"][i] = int(parts[0])
swc_dict["radius"][i] = float(parts[-2])
swc_dict["pid"][i] = int(parts[-1])
swc_dict["xyz"].append(xyz)
swc_dict["id"][i] = parts[0]
swc_dict["radius"][i] = parts[-2]
swc_dict["pid"][i] = parts[-1]
swc_dict["xyz"][i] = read_xyz(parts[2:5], offset=offset)

# Reindex from zero
min_id = np.min(swc_dict["id"])
Expand Down
28 changes: 27 additions & 1 deletion src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ def remove_key(my_dict, key):
return my_dict


def append_dict_value(my_dict, key, value):
"""
Appends "value" to the list stored at "key".
Parameters
----------
my_dict : dict
Dictionary to be queried.
key : hashable data type
Key to be query.
value : list item type
Value to append to list stored at "key".
Returns
-------
my_dict : dict
Updated dictionary.
"""
if key in my_dict.keys():
my_dict[key].append(value)
else:
my_dict[key] = [value]
return my_dict


# --- os utils ---
def mkdir(path, delete=False):
"""
Expand Down Expand Up @@ -500,7 +526,7 @@ def progress_bar(current, total, bar_length=50, eta=None, runtime=None):
bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]"
eta = f"Time Remaining: {eta}" if eta else ""
runtime = f"Estimated Total Runtime: {runtime}" if runtime else ""
print(f"\r{bar} {n_completed} | {eta} | {runtime}", end="", flush=True)
print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True)


def xor(a, b):
Expand Down

0 comments on commit aeffc17

Please sign in to comment.