Skip to content

Commit

Permalink
cleaning transfer script and example. Also extend functionality of pl…
Browse files Browse the repository at this point in the history
…ot script
  • Loading branch information
robban80 committed Oct 3, 2024
1 parent 87debb5 commit 155ba3e
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 216 deletions.
8 changes: 4 additions & 4 deletions examples/simple_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def transfer_all(source_path, new_celltype_path):
try:
celltype = d.split('-')[1]
except:
print()
print(f'celltype can not be extracted from the model name of: {d}')
print('in order to work with batch transfer, model names have to be in the format:')
print('region-type-additional_info, e.g. str-dspn-...')
print(f'\nmodel: {d}')
print('\tcelltype can not be extracted from the model name')
print('\tin order to work with batch transfer, model names have to be in the format:')
print('\n\tregion-type-additional_info\n\te.g. str-dspn-...')
print('--> skipping')
continue
destination = os.path.join(new_celltype_path, celltype, d)
Expand Down
209 changes: 34 additions & 175 deletions examples/transfer_models_from_bluepyopt.ipynb

Large diffs are not rendered by default.

144 changes: 107 additions & 37 deletions examples/verify_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
python simulate.py -p ../../Alex_model_repo/models/optim/HBP-2022Q2/str-dspn-e150602_c1_D1-mWT-0728MSN01-v20220620/ -o ../data/neurons/striatum/test/str-dspn-e150602_c1_D1-mWT-0728MSN01-v20220620/ -s 0 -i 0
'''

def simulate_org_model(model_path, pid, return_tv=True, plot=False, print_psection=False):
def simulate_org_model(model_path, pid, return_tv=True, plot=False, print_psection=False, current_amplitude=None):

# change directory
orgdir = os.getcwd()
Expand Down Expand Up @@ -59,15 +59,19 @@ def simulate_org_model(model_path, pid, return_tv=True, plot=False, print_psecti
# import config files
with open('../config/parameters.json') as fp:
parameters = json.load(fp)
with open('../config/protocols.json') as fp:
protocols = json.load(fp, object_pairs_hook=OrderedDict)

h.v_init = parameters[1]['value']
h.celsius = parameters[0]['value']

proto = list(p for p in protocols if p.startswith('IDthresh_'))[0]
stim0 = protocols[proto]['stimuli'][0]['amp']
stim1 = protocols[proto]['stimuli'][1]['amp']
if current_amplitude:
stim0 = current_amplitude # nA
stim1 = 0
else:
with open('../config/protocols.json') as fp:
protocols = json.load(fp, object_pairs_hook=OrderedDict)

proto = list(p for p in protocols if p.startswith('IDthresh_'))[0]
stim0 = protocols[proto]['stimuli'][0]['amp']
stim1 = protocols[proto]['stimuli'][1]['amp']

simtot = 1500 # hardcoding simlen

Expand Down Expand Up @@ -102,60 +106,88 @@ def simulate_org_model(model_path, pid, return_tv=True, plot=False, print_psecti
if plot:
import matplotlib.pyplot as plt
plt.plot(time, vm)
plt.xlim([0,3000]) # remove this lim? (added for comparison with original model)
plt.show()

if not return_tv:
return 0,0

return np.array(time), np.array(vm)

def simulate_snudda(ref_model_path, transfered_model_path, pid=0, ref_tv=[]):
def simulate_snudda( transfered_model_path,
ref_model_path=None,
hashkey=None,
mkey=None,
pid=0,
ref_tv=[],
sim_len=1.5,
print_psection=False,
current_amplitude=None):

print('2.1 simulating transfered model in snudda...\n')
# get hash key corresponding to model index (pid)
with open(f'{transfered_model_path}temp/parameters_hash_id.json', 'r') as h:
hash2id = json.load(h)
id2hash = {int(item):key for key,item in hash2id.items()} # reverse key:item
hashkey = id2hash[pid]
morph_path = get_morphology_source_file(ref_model_path)
with open(f'{transfered_model_path}morphology/morphology_hash_filename.json', 'r') as h:
mhash2name = json.load(h)
name2mhash = {name:mkey for mkey,name in mhash2name.items()} # reverse key:item
morph = os.path.splitext(os.path.basename(morph_path))[0]
mkey = name2mhash[f'{morph}-var0.swc'] # var0 and the original morphology are identical
if not hashkey:
# get hash key corresponding to model index (pid)
with open(f'{transfered_model_path}temp/parameters_hash_id.json', 'r') as h:
hash2id = json.load(h)
id2hash = {int(item):key for key,item in hash2id.items()} # reverse key:item
hashkey = id2hash[pid]
if not mkey:
morph_path = get_morphology_source_file(transfered_model_path)
with open(f'{transfered_model_path}morphology/morphology_hash_filename.json', 'r') as h:
mhash2name = json.load(h)
name2mhash = {name:mkey for mkey,name in mhash2name.items()} # reverse key:item
morph = os.path.splitext(os.path.basename(morph_path))[0]
mkey = name2mhash[f'{morph}-var0.swc'] # var0 and the original morphology are identical

# model setup ----------------
from snudda import Snudda
network_path = "snudda"
ss = Snudda(network_path=network_path)
ss.init_tiny( neuron_paths=[transfered_model_path],
neuron_names=["random"],
neuron_names=["Cell"],
number_of_neurons=[1],
morphology_key=[mkey],
parameter_key=[hashkey])

ss.create_network()
# current mag and delay from file
with open(f'{ref_model_path}config/protocols.json') as fp:
protocols = json.load(fp, object_pairs_hook=OrderedDict)
proto = list(p for p in protocols if p.startswith('IDthresh_'))[0]
stim0 = protocols[proto]['stimuli'][0]['amp'] * 1e-9
stim1 = protocols[proto]['stimuli'][1]['amp'] * 1e-9

print(stim0, stim1)
# current amplitude and delay from file
if current_amplitude:
stim1=0 * 1e-9
stim0=current_amplitude * 1e-9 # nA to A
elif os.path.isfile(f'{ref_model_path}config/protocols.json'):
with open(f'{ref_model_path}config/protocols.json') as fp:
protocols = json.load(fp, object_pairs_hook=OrderedDict)
proto = list(p for p in protocols if p.startswith('IDthresh_'))[0]
stim0 = protocols[proto]['stimuli'][0]['amp'] * 1e-9
stim1 = protocols[proto]['stimuli'][1]['amp'] * 1e-9
else:
print('NOT ABLE TO OPEN PROTOCOL FILE. Using a standard stimuli of 300 pA')
import time
time.sleep(5)
stim1=0 * 1e-9
stim0=0.3 * 1e-9 # using 300 pA

simulation_config = {"current_injection_info" : {"0": {"time": [0, 0.7, 0.7, 1.5],
simulation_config = {"current_injection_info" : {"0": {"time": [0, 0.7, 0.7, sim_len],
"current": [stim1, stim1, stim1+stim0, stim1+stim0]}}}
# simulate
ss.simulate(simulation_config=simulation_config, verbose=False, time=1.5)
#
sim = ss.simulate(simulation_config=simulation_config, verbose=False, time=sim_len)

if print_psection:
h = sim.sim.neuron.h
for sec in h.allsec():
print(h.psection(sec=sec))


# plot
from snudda.utils import SnuddaLoadSimulation
sls = SnuddaLoadSimulation(network_path=network_path)
time = sls.get_time()
neuron_id = 0
voltage = sls.get_data(neuron_id=neuron_id, data_type="voltage")[0][neuron_id]

print('3. Comparing models...')
vm = voltage.T[0]*1000
t = time*1000

print('3. Comparing models...')
if len(ref_tv):
if all(vm == ref_tv[1]):
print('\t-> models gives identical results')
Expand All @@ -176,8 +208,7 @@ def simulate_snudda(ref_model_path, transfered_model_path, pid=0, ref_tv=[]):
plt.savefig('org_and_snudda.png')
plt.show()




def upgrade_parameters_to_v2(model_path, hashkey, kid=0):
# open transfered parameter file. This file contains many families of models
with open(f'{model_path}parameters.json', 'r') as h:
Expand Down Expand Up @@ -298,7 +329,42 @@ def simulate_transfered_model(ref_model_path, transfered_model_path, pid, upgrad
plt.savefig('org_and_my.png')
plt.show()


def print_hashkey2id(model_path):
import sys
with open(f'{model_path}/temp/parameters_hash_id.json') as fp:
h2id = json.load(fp)
print(h2id)
sys.exit()

def get_amplitude(current_amplitude):
if len(current_amplitude.split()) > 1:
raise ValueError(f'current_amplitude must either be given as a number (in nA)\nor a string with units (A/nA/pA), without space. E.g: 100pA\n\tnot {current_amplitude}')
elif 'nA' in current_amplitude:
print(f"nano ampere... {current_amplitude.split('nA')[0]}")
return float(current_amplitude.split('nA')[0])
elif 'pA' in current_amplitude:
return float(current_amplitude.split('pA')[0]) * 1e-3
elif 'A' in current_amplitude:
return float(current_amplitude.split('A')[0]) * 1e9
else:
raise ValueError(f'current_amplitude must either be given as a number (in nA)\nor a string with units (A/nA/pA), without space. E.g: 100pA\n\tnot {current_amplitude}')


def main_compare(ref_model, trans_model, mid=0, print_psection=False, current_amplitude=False):
# can be used to compare models with a single command
amp = current_amplitude
if current_amplitude:
if type(current_amplitude)==str:
# try to extract value from string
amp = get_amplitude(current_amplitude)

ref_tv = simulate_org_model( ref_model,
mid,
print_psection=print_psection,
return_tv=True,
current_amplitude=amp)

simulate_snudda(trans_model, ref_model_path=ref_model, pid=mid, ref_tv=ref_tv, print_psection=print_psection, current_amplitude=amp)

if __name__ == '__main__':
import argparse
Expand All @@ -310,13 +376,17 @@ def simulate_transfered_model(ref_model_path, transfered_model_path, pid, upgrad
parser.add_argument('-v','--plot', help='plot voltage of reference model by itself? (default False)', default=0)
parser.add_argument('-r','--return_tv', help='return time and voltage? Needed for comparison (default True)', default=1)
parser.add_argument('-u','--upgrade', help='upgrade params--must be done if not done before (default False)', default=0)
parser.add_argument( '--print_hashkeys', help='print all hashkeys:id combinations in the transfered param file and exit', action="store_true", default=False)

args = vars(parser.parse_args())

if args['print_hashkey']:
print_hashkey2id(args['out'])

if not args['psprint']:
pps = False
else:
pps = True


t,v = simulate_org_model( args['path'],
int(args['mid']),
Expand Down

0 comments on commit 155ba3e

Please sign in to comment.