Skip to content

Commit

Permalink
Refactor batch simulation parameters and backend
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <abdulsamadsid1@gmail.com>
  • Loading branch information
samadpls committed Oct 19, 2024
1 parent 08ce63b commit dd261d8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
4 changes: 2 additions & 2 deletions examples/howto/plot_batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def summary_func(results):
summary_func=summary_func)
simulation_results = batch_simulation.run(param_grid,
n_jobs=n_jobs,
combinations=False,
backend='multiprocessing')
combinations=True,
backend='loky')
# backend='dask' if installed
print("Simulation results:", simulation_results)
###############################################################################
Expand Down
30 changes: 16 additions & 14 deletions hnn_core/batch_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from joblib import Parallel, delayed, parallel_config

from .parallel_backends import JoblibBackend
from .network import Network
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
Expand Down Expand Up @@ -196,14 +197,14 @@ def run(self, param_grid, return_output=True,
param_combinations = self._generate_param_combinations(
param_grid, combinations)
total_sims = len(param_combinations)
num_sims_per_batch = max(total_sims // self.batch_size, 1)
num_sims_per_batch = max(total_sims // n_jobs, 1)
batch_size = min(self.batch_size, total_sims)

results = []
simulated_data = []
for i in range(batch_size):
start_idx = i * num_sims_per_batch
end_idx = start_idx + num_sims_per_batch
for i in range(0, total_sims, num_sims_per_batch):
start_idx = i
end_idx = min(i + num_sims_per_batch, total_sims)
if i == batch_size - 1:
end_idx = len(param_combinations)
batch_results = self.simulate_batch(
Expand Down Expand Up @@ -269,10 +270,10 @@ def simulate_batch(self, param_combinations, n_jobs=1,
with parallel_config(backend=backend):
res = Parallel(n_jobs=n_jobs, verbose=verbose)(
delayed(self._run_single_sim)(
params) for params in param_combinations)
params, n_jobs) for params in param_combinations)
return res

def _run_single_sim(self, param_values):
def _run_single_sim(self, param_values, n_jobs=1):
"""Run a single simulation.
Parameters
Expand All @@ -296,14 +297,15 @@ def _run_single_sim(self, param_values):
results = {'net': net, 'param_values': param_values}

if self.save_dpl:
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl
with JoblibBackend(n_jobs=n_jobs):
dpl = simulate_dipole(net,
tstop=self.tstop,
dt=self.dt,
n_trials=self.n_trials,
record_vsec=self.record_vsec,
record_isec=self.record_isec,
postproc=self.postproc)
results['dpl'] = dpl

if self.save_spiking:
results['spiking'] = {
Expand Down

0 comments on commit dd261d8

Please sign in to comment.