Skip to content

Commit

Permalink
clean traces
Browse files Browse the repository at this point in the history
  • Loading branch information
tathey1 committed Jan 11, 2024
1 parent f7704a7 commit 5fa1c2f
Show file tree
Hide file tree
Showing 2 changed files with 332 additions and 0 deletions.
2 changes: 2 additions & 0 deletions brainlit/utils/Neuron_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,8 @@ def _read_swc(
branch = int(name[idx])
except ValueError:
pass
elif len(line) == 0:
pass
elif line[0] != "#":
in_header = False
header_length += 1
Expand Down
330 changes: 330 additions & 0 deletions experiments/sriram/clean-traces.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,330 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
"from cloudvolume import CloudVolume\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"from brainlit.algorithms.trace_analysis.fit_spline import GeometricGraph\n",
"from brainlit.utils.Neuron_trace import NeuronTrace\n",
"from scipy.interpolate import splev\n",
"import numpy as np\n",
"import networkx as nx\n",
"from scipy.spatial import cKDTree\n",
"\n",
"%matplotlib qt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## convert to SWC"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"def skel_to_graph(skel):\n",
" G = nx.Graph()\n",
" for v_n, vertex in enumerate(skel.vertices):\n",
" G.add_node(v_n, loc=vertex)\n",
"\n",
" edges = [(e[0], e[1]) for e in skel.edges]\n",
" G.add_edges_from(edges)\n",
"\n",
" return G\n",
"\n",
"\n",
"def smooth_graph(G):\n",
" new_locs = {}\n",
" av_count = 0\n",
"\n",
" all_locs = []\n",
" for node in G.nodes:\n",
" all_locs.append(G.nodes[node][\"loc\"])\n",
" all_locs = np.array(all_locs)\n",
" kdt = cKDTree(all_locs)\n",
" # kdt.query_ball_point()\n",
"\n",
" for node in G.nodes:\n",
" deg = G.degree(node)\n",
" if deg == 2:\n",
" # nbrs = kdt.query_ball_point(G.nodes[node]['loc'], r=10000)\n",
" # nbrs = G.neighbors(node)\n",
" nbrs = nx.dfs_tree(G, source=node, depth_limit=20)\n",
" locs = [G.nodes[n][\"loc\"] for n in nbrs]\n",
" dists = [\n",
" np.linalg.norm(np.subtract(loc, G.nodes[node][\"loc\"])) for loc in locs\n",
" ]\n",
" weights = [1 if dist < 25000 else 0 for dist in dists]\n",
" # locs += [G.nodes[node]['loc']]\n",
" locs = np.array(locs)\n",
" # weights += [1]\n",
" if np.sum(weights) > 1:\n",
" av_count += 1\n",
"\n",
" new_loc = np.average(locs, axis=0, weights=weights)\n",
" new_locs[node] = new_loc\n",
"\n",
" for node in new_locs.keys():\n",
" G.nodes[node][\"loc\"] = new_locs[node]\n",
"\n",
" print(f\"{av_count} averaged nodes\")\n",
" return G\n",
"\n",
"\n",
"def graph_to_vertices(G):\n",
" vertices = []\n",
" for node in G.nodes:\n",
" vertices.append(G.nodes[node][\"loc\"])\n",
" return np.array(vertices)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 averaged nodes\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'loc'</span>: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">([</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6000.2</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6000.4</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6000.2</span><span style=\"font-weight: bold\">])}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'loc'\u001b[0m: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m6000.2\u001b[0m, \u001b[1;36m6000.4\u001b[0m, \u001b[1;36m6000.2\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"G = nx.Graph()\n",
"G.add_node(0, loc=[0, 0, 0])\n",
"G.add_node(1, loc=[1, 2, 1])\n",
"G.add_node(2, loc=[10000, 10000, 10000])\n",
"G.add_node(3, loc=[10000, 10000, 10000])\n",
"G.add_node(4, loc=[10000, 10000, 10000])\n",
"G.add_edge(0, 1)\n",
"G.add_edge(1, 2)\n",
"G.add_edge(2, 3)\n",
"G.add_edge(3, 4)\n",
"\n",
"G = smooth_graph(G)\n",
"G.nodes[1]"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"4216 averaged nodes\n",
"1\n",
"9498 averaged nodes\n",
"2\n",
"9705 averaged nodes\n",
"3\n",
"10974 averaged nodes\n",
"4\n",
"14124 averaged nodes\n",
"5\n",
"6094 averaged nodes\n",
"6\n",
"9523 averaged nodes\n",
"7\n",
"7 invalid for 220-p29-brain1\n",
"0\n",
"5411 averaged nodes\n",
"1\n",
"6990 averaged nodes\n",
"2\n",
"7081 averaged nodes\n",
"3\n",
"6953 averaged nodes\n",
"4\n",
"6568 averaged nodes\n",
"5\n",
"7419 averaged nodes\n",
"6\n",
"6685 averaged nodes\n",
"7\n",
"6720 averaged nodes\n",
"8\n",
"5809 averaged nodes\n",
"9\n",
"9 invalid for 220-p29-brain2\n",
"0\n",
"846 averaged nodes\n",
"1\n",
"3469 averaged nodes\n",
"2\n",
"2219 averaged nodes\n",
"3\n",
"2578 averaged nodes\n",
"4\n",
"1875 averaged nodes\n",
"5\n",
"5 invalid for adipo-brain1-im3\n"
]
}
],
"source": [
"dir = Path(\"/Users/thomasathey/Documents/mimlab/mouselight/kolodkin/sriram/misc\")\n",
"subdirs = [\"220-p29-brain1\", \"220-p29-brain2\", \"adipo-brain1-im3\"]\n",
"\n",
"for subdir in subdirs:\n",
" trace_dir = dir / subdir / \"traces\"\n",
" vol = CloudVolume(\"precomputed://file://\" + str(trace_dir))\n",
" for skel_id in range(10):\n",
" print(skel_id)\n",
" try:\n",
" skel = vol.skeleton.get(skel_id)\n",
" G = skel_to_graph(skel)\n",
" G = smooth_graph(G)\n",
" vertices = graph_to_vertices(G)\n",
" skel.vertices = vertices\n",
" skel.vertex_types = skel.radii\n",
" txt = skel.to_swc()\n",
" with open(dir / subdir / f\"{skel_id}_smoothed.swc\", \"w\") as f:\n",
" f.write(txt)\n",
" except:\n",
" print(f\"{skel_id} invalid for {subdir}\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/thomasathey/Documents/mimlab/mouselight/brainlit_parent/brainlit/brainlit/utils/Neuron_trace.py:743: UserWarning: No offset information found in: /Users/thomasathey/Documents/mimlab/mouselight/kolodkin/sriram/misc/220-p29-brain1/1_smoothed.swc\n",
" warnings.warn(\"No offset information found in: \" + path)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded segment 1\n",
"Fitting spline tree\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/thomasathey/Documents/mimlab/mouselight/brainlit_parent/brainlit/brainlit/algorithms/trace_analysis/fit_spline.py:234: UserWarning: There are still duplicate locations after removing connected duplicates.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"plotting\n"
]
}
],
"source": [
"dir = Path(\"/Users/thomasathey/Documents/mimlab/mouselight/kolodkin/sriram/misc\")\n",
"subdirs = [\"220-p29-brain1\", \"220-p29-brain2\", \"adipo-brain1-im3\"]\n",
"\n",
"subdir_choice = subdirs[0]\n",
"skel_id_choice = 1\n",
"\n",
"for subdir in subdirs:\n",
" if subdir != subdir_choice:\n",
" continue\n",
" trace_dir = dir / subdir\n",
" for skel_id in range(10):\n",
" if skel_id != skel_id_choice:\n",
" continue\n",
"\n",
" swc_path = trace_dir / f\"{skel_id}_smoothed.swc\"\n",
" swc_trace = NeuronTrace(path=str(swc_path))\n",
" df_swc_offset_neuron = swc_trace.get_df()\n",
"\n",
" print(\"Loaded segment {}\".format(skel_id))\n",
" G = GeometricGraph(df=df_swc_offset_neuron, remove_duplicates=True)\n",
" print(f\"Fitting spline tree\")\n",
" spline_tree = G.fit_spline_tree_invariant()\n",
" print(\"plotting\")\n",
" ax = plt.figure().add_subplot(projection=\"3d\")\n",
"\n",
" for j, node in enumerate(spline_tree.nodes):\n",
" spline = spline_tree.nodes[node]\n",
" tck, u_um = spline[\"spline\"]\n",
" y = splev(u_um, tck)\n",
"\n",
" ax.plot(y[0], y[1], y[2], c=\"blue\", linewidth=0.5)\n",
"\n",
" # ax.set_axis_off()\n",
" plt.show()\n",
" break\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "docs_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 5fa1c2f

Please sign in to comment.