Skip to content

Commit def7f3e

Browse files
authored
Merge pull request #372 from AllenInstitute/update/workshop-2024
Update/workshop 2024
2 parents b3662cb + 862fd7f commit def7f3e

29 files changed

+520
-96
lines changed

bmtk/analyzer/ecp.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import h5py
33
import matplotlib.pyplot as plt
44
import numpy as np
5+
from decimal import Decimal
56

67
from bmtk.utils.sonata.config import SonataConfig
78
from bmtk.simulator.utils import simulation_reports
8-
# from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
9+
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
10+
import matplotlib.font_manager as fm
911

1012

1113
def _get_ecp_path(ecp_path=None, config=None, report_name=None):
@@ -55,30 +57,54 @@ def plot_ecp(config_file=None, report_name=None, ecp_path=None, title=None, show
5557
channels = ecp_h5['/ecp/channel_id'][()]
5658
fig, axes = plt.subplots(len(channels), 1)
5759
fig.text(0.04, 0.5, 'channel id', va='center', rotation='vertical')
60+
v_min, v_max = ecp_h5['/ecp/data'][()].min(), ecp_h5['/ecp/data'][()].max()
61+
# print(v_max - v_min)
62+
# exit()
63+
5864
for idx, channel in enumerate(channels):
5965
data = ecp_h5['/ecp/data'][:, idx]
66+
# print(channel, np.min(data), np.max(data))
6067
axes[idx].plot(time_traces, data)
6168
axes[idx].spines["top"].set_visible(False)
6269
axes[idx].spines["right"].set_visible(False)
6370
axes[idx].set_yticks([])
6471
axes[idx].set_ylabel(channel)
72+
axes[idx].set_ylim([v_min, v_max])
6573

6674
if idx+1 != len(channels):
6775
axes[idx].spines["bottom"].set_visible(False)
6876
axes[idx].set_xticks([])
6977
else:
7078
axes[idx].set_xlabel('timestamps (ms)')
71-
# scalebar = AnchoredSizeBar(axes[idx].transData,
72-
# 2.0, '1 mV', 1,
73-
# pad=0,
74-
# borderpad=0,
75-
# # color='b',
76-
# frameon=True,
77-
# # size_vertical=1.001,
78-
# # fontproperties=fontprops
79-
# )
80-
#
81-
# axes[idx].add_artist(scalebar)
79+
80+
81+
if idx == 0:
82+
scale_bar_size = (v_max-v_min)/2.0
83+
scale_bar_label = f'{scale_bar_size:.2E}'
84+
# print(scale_bar_label)
85+
# exit()
86+
fontprops = fm.FontProperties(size='x-small')
87+
88+
scalebar = AnchoredSizeBar(
89+
axes[idx].transData,
90+
size=scale_bar_size,
91+
label=scale_bar_label,
92+
loc='upper right',
93+
pad=0.1,
94+
borderpad=0.5,
95+
sep=5,
96+
# color='b',
97+
frameon=False,
98+
size_vertical=scale_bar_size,
99+
# size_vertical=1.001,
100+
fontproperties=fontprops
101+
)
102+
axes[idx].add_artist(scalebar)
103+
104+
# label = scalebar.txt_label
105+
# label.set_rotation(270.0)
106+
# label.set_verticalalignment('bottom')
107+
# label.set_horizontalalignment('left')
82108

83109
if title:
84110
fig.set_title(title)

bmtk/simulator/bionet/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2121
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2222
#
23-
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing
23+
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing, \
24+
spikes_generator
2425
from bmtk.simulator.bionet.config import Config
2526
from bmtk.simulator.bionet.bionetwork import BioNetwork
2627
from bmtk.simulator.bionet.biosimulator import BioSimulator

bmtk/simulator/bionet/biocell.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from bmtk.simulator.bionet.morphology import Morphology
2828
import six
2929

30+
import neuron
3031
from neuron import h
3132

3233
pc = h.ParallelContext() # object to access MPI methods
@@ -74,9 +75,6 @@ class BioCell(Cell):
7475
def __init__(self, node, population_name, bionetwork):
7576
super(BioCell, self).__init__(node=node, population_name=population_name, network=bionetwork)
7677

77-
# Set up netcon object that can be used to detect and communicate cell spikes.
78-
self.set_spike_detector(bionetwork.spike_threshold)
79-
8078
# Determine number of segments and store a list of all sections.
8179
self._secs = []
8280
self._secs_by_id = []
@@ -105,6 +103,10 @@ def __init__(self, node, population_name, bionetwork):
105103
self._seg_coords = None
106104
self.build_morphology()
107105

106+
# Set up netcon object that can be used to detect and communicate cell spikes.
107+
self.set_spike_detector(bionetwork.spike_threshold)
108+
109+
108110
def build_morphology(self):
109111
morph_base = Morphology.load(hobj=self.hobj, morphology_file=self.morphology_file, cache_seg_props=True)
110112

@@ -126,6 +128,10 @@ def morphology(self):
126128
"""The actual Morphology object instanstiation"""
127129
return self._morphology
128130

131+
@property
132+
def soma(self):
133+
return self.morphology.soma
134+
129135
@property
130136
def seg_coords(self):
131137
"""Coordinates for segments/sections of the morphology, need to make public for ecp, xstim, and other
@@ -144,7 +150,7 @@ def seg_coords(self):
144150
return self.morphology.seg_coords
145151

146152
def set_spike_detector(self, spike_threshold):
147-
nc = h.NetCon(self.hobj.soma[0](0.5)._ref_v, None, sec=self.hobj.soma[0]) # attach spike detector to cell
153+
nc = h.NetCon(self.soma[0](0.5)._ref_v, None, sec=self.soma[0])
148154
nc.threshold = spike_threshold
149155
pc.cell(self.gid, nc) # associate gid with spike detector
150156

@@ -437,18 +443,41 @@ def __init__(self, node, population_name, bionetwork):
437443
self._vecstim = h.VecStim()
438444
self._vecstim.play(self._spike_trains)
439445

440-
self._precell_filter = bionetwork.spont_syns_filter
446+
self._precell_filter = bionetwork.spont_syns_filter_pre
447+
self._postcell_filter = bionetwork.spont_syns_filter_post
441448
assert(isinstance(self._precell_filter, dict))
442449

443-
def _matches_filter(self, src_node):
450+
def _matches_filter(self, src_node, trg_node=None):
444451
"""Check to see if the presynaptic cell matches the criteria specified"""
445452
for k, v in self._precell_filter.items():
453+
# Some key may not show up as node_variable
454+
if k == 'population' and k not in src_node:
455+
key_val = src_node.population_name
456+
else:
457+
key_val = src_node[k]
458+
459+
if isinstance(v, (list, tuple)):
460+
if key_val not in v:
461+
return False
462+
else:
463+
if key_val != v:
464+
return False
465+
466+
trg_node = trg_node or self
467+
for k, v in self._postcell_filter.items():
468+
# Some key may not show up as node_variable
469+
if k == 'population' and k not in trg_node:
470+
key_val = trg_node._node.population_name
471+
else:
472+
key_val = trg_node[k]
473+
446474
if isinstance(v, (list, tuple)):
447-
if src_node[k] not in v:
475+
if key_val not in v:
448476
return False
449477
else:
450-
if src_node[k] != v:
478+
if key_val != v:
451479
return False
480+
452481
return True
453482

454483
def _set_connections(self, edge_prop, src_node, syn_weight, stim=None):

bmtk/simulator/bionet/bionetwork.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def __init__(self):
7777
self._gid_pool = GidPool()
7878

7979
self.has_spont_syns = False
80-
self.spont_syns_filter = None
80+
self.spont_syns_filter_pre = None
81+
self.spont_syns_filter_post = None
8182
self.spont_syns_times = None
8283

8384
@property
@@ -88,7 +89,7 @@ def gid_pool(self):
8889
def py_function_caches(self):
8990
return nrn
9091

91-
def set_spont_syn_activity(self, precell_filter, timestamps):
92+
def set_spont_syn_activity(self, precell_filter, postcell_filter, timestamps):
9293
self._model_type_map = {
9394
'biophysical': BioCellSpontSyn,
9495
'point_process': PointProcessCellSpontSyns,
@@ -98,7 +99,8 @@ def set_spont_syn_activity(self, precell_filter, timestamps):
9899
}
99100

100101
self.has_spont_syns = True
101-
self.spont_syns_filter = precell_filter
102+
self.spont_syns_filter_pre = precell_filter
103+
self.spont_syns_filter_post = postcell_filter
102104
self.spont_syns_times = timestamps
103105

104106
def get_node_id(self, population, node_id):
@@ -134,12 +136,12 @@ def add_nodes(self, node_population):
134136
self._gid_pool.add_pool(node_population.name, node_population.n_nodes())
135137
super(BioNetwork, self).add_nodes(node_population)
136138

137-
def get_virtual_cells(self, population, node_id, spike_trains):
139+
def get_virtual_cells(self, population, node_id, spike_trains, spikes_generator=None, sim=None):
138140
if node_id in self._virtual_nodes[population]:
139141
return self._virtual_nodes[population][node_id]
140142
else:
141143
node = self.get_node_id(population, node_id)
142-
virt_cell = VirtualCell(node, population, spike_trains)
144+
virt_cell = VirtualCell(node, population, spike_trains, spikes_generator, sim)
143145
self._virtual_nodes[population][node_id] = virt_cell
144146
return virt_cell
145147

@@ -151,7 +153,7 @@ def get_disconnected_cell(self, population, node_id, spike_trains):
151153
virt_cell = self._disconnected_source_cells[population][node_id]
152154
else:
153155
node = self.get_node_id(population, node_id)
154-
virt_cell = VirtualCell(node, population, spike_trains)
156+
virt_cell = VirtualCell(node, population, spike_trains, self)
155157
self._disconnected_source_cells[population][node_id] = virt_cell
156158

157159
return virt_cell
@@ -369,7 +371,7 @@ def find_edges(self, source_nodes=None, target_nodes=None):
369371

370372
return selected_edges
371373

372-
def add_spike_trains(self, spike_trains, node_set):
374+
def add_spike_trains(self, spike_trains, node_set, spikes_generator=None, sim=None):
373375
self._init_connections()
374376

375377
src_nodes = [node_pop for node_pop in self.node_populations if node_pop.name in node_set.population_names()]
@@ -379,7 +381,7 @@ def add_spike_trains(self, spike_trains, node_set):
379381
if edge_pop.virtual_connections:
380382
for trg_nid, trg_cell in self._rank_node_ids[edge_pop.target_nodes].items():
381383
for edge in edge_pop.get_target(trg_nid):
382-
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains)
384+
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains, spikes_generator, sim)
383385
trg_cell.set_syn_connection(edge, src_cell, src_cell)
384386

385387
elif edge_pop.mixed_connections:

bmtk/simulator/bionet/biosimulator.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ def from_config(cls, config, network, set_recordings=True):
326326
if sim_input.input_type == 'syn_activity':
327327
network.set_spont_syn_activity(
328328
precell_filter=sim_input.params['precell_filter'],
329+
postcell_filter=sim_input.params.get('postcell_filter', {}),
329330
timestamps=sim_input.params['timestamps']
330331
)
331332

@@ -346,13 +347,30 @@ def from_config(cls, config, network, set_recordings=True):
346347

347348
# TODO: Need to create a gid selector
348349
for sim_input in inputs.from_config(config):
349-
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata']:
350+
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata', 'h5']:
350351
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
351352
path = sim_input.params['input_file']
352353
spikes = SpikeTrains.load(path=path, file_type=sim_input.module, **sim_input.params)
353354
# node_set_opts = sim_input.params.get('node_set', 'all')
354355
node_set = network.get_node_set(sim_input.node_set)
355-
network.add_spike_trains(spikes, node_set)
356+
network.add_spike_trains(
357+
spike_trains=spikes,
358+
node_set=node_set,
359+
spikes_generator=None,
360+
sim=sim
361+
)
362+
363+
elif sim_input.input_type == 'spikes' and sim_input.module == 'function':
364+
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
365+
# path = sim_input.params.get['input_file']
366+
spikes_generator = sim_input.params['spikes_function']
367+
node_set = network.get_node_set(sim_input.node_set)
368+
network.add_spike_trains(
369+
spike_trains=None,
370+
node_set=node_set,
371+
spikes_generator=spikes_generator,
372+
sim=sim
373+
)
356374

357375
elif sim_input.module == 'IClamp':
358376
sim.add_mod(mods.IClampMod(input_type=sim_input.input_type, **sim_input.params))

bmtk/simulator/bionet/cell.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def get_connection_info(self):
106106

107107
def set_syn_connections(self, edge_prop, src_node, stim=None):
108108
raise NotImplementedError
109+
110+
def get_section(self, sec_name, sec_index):
111+
raise NotImplementedError
109112

113+
def __contains__(self, node_prop):
114+
return node_prop in self._node
115+
110116
def __getitem__(self, node_prop):
111117
return self._node[node_prop]

bmtk/simulator/bionet/config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def create_output_dir(self):
3939
io.setup_output_dir(self.output_dir, self.log_file)
4040

4141
def load_nrn_modules(self):
42-
nrn.load_neuron_modules(self.mechanisms_dir, self.templates_dir)
42+
nrn.load_neuron_modules(
43+
mechanisms_dir=self.mechanisms_dir,
44+
templates_dir=self.templates_dir,
45+
default_templates=self.use_default_templates,
46+
use_old_import3d=self.use_old_import3d
47+
)
4348

4449
def build_env(self):
4550
self.io = io
@@ -52,3 +57,8 @@ def build_env(self):
5257

5358
pc.barrier()
5459
self.load_nrn_modules()
60+
61+
def _set_class_props(self):
62+
super(Config, self)._set_class_props()
63+
self.use_old_import3d = self.run.get('use_old_import3d', False)
64+
self.use_default_templates = self.run.get('use_old_import3d', True)

bmtk/simulator/bionet/default_setters/cell_models.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import numpy as np
2525
from neuron import h
26+
import inspect
2627
try:
2728
from sklearn.decomposition import PCA
2829
except Exception as e:
@@ -41,19 +42,46 @@
4142
"""
4243

4344
def loadHOC(cell, template_name, dynamics_params):
44-
# Get template to instantiate
45-
template_call = getattr(h, template_name)
46-
if dynamics_params is not None and 'params' in dynamics_params:
47-
template_params = dynamics_params['params']
48-
if isinstance(template_params, list):
49-
# pass in a list of parameters
50-
hobj = template_call(*template_params)
45+
"""A Generic function for creating a cell object from a NEURON HOC Template (eg. a *.hoc file with
46+
`begintemplate template_name` in header). It essentially tries to guess the correct parameters that need to be
47+
called so may not work the majority of the times and require to be overloaded.
48+
49+
:param cell: A SONATA node object, can be used as a dict to get individual properties of current cell.
50+
:param template_name: name of HOCTemplate as stored in "model_template" attribute (hoc:<template_name>).
51+
:param dynamics_params: Dictionary containing contents of cell['dynamics_params'] as loaded from a json file or hdf5.
52+
If cell does not have "dynamics_params" attributes then will be set to None.
53+
"""
54+
try:
55+
# Get template to instantiate
56+
template_call = getattr(h, template_name)
57+
except AttributeError as ae:
58+
io.log_error(
59+
f'loadHOC was unable to load in Neuron HOC Template "{template_name}, '
60+
'Make sure appropiate .hoc file is stored in templates_dir.'
61+
)
62+
raise ae
63+
64+
try:
65+
if dynamics_params is not None and 'params' in dynamics_params:
66+
template_params = dynamics_params['params']
67+
if isinstance(template_params, list):
68+
# pass in a list of parameters
69+
hobj = template_call(*template_params)
70+
else:
71+
# only a single parameter
72+
hobj = template_call(template_params)
73+
elif cell.morphology_file is not None:
74+
# instantiate template with no parameters
75+
hobj = template_call(cell.morphology_file)
5176
else:
52-
# only a single parameter
53-
hobj = template_call(template_params)
54-
else:
55-
# instantiate template with no parameters
56-
hobj = template_call()
77+
hobj = template_call()
78+
except RuntimeError as rte:
79+
io.log_error(
80+
f'bmtk.simualtor.bionet.default_setters.cell_models.loadHOC function failed to load HOC template "{template_call}". '
81+
'If Hoc Templates requires special call to be initialized consider using `bmtk.simulator.bionet.add_cell_model()` '
82+
'to overwrite this function.'
83+
)
84+
raise rte
5785

5886
# TODO: All "all" section if it doesn't exist
5987
# hobj.all = h.SectionList()

0 commit comments

Comments
 (0)