Skip to content

Commit

Permalink
Merge pull request #133 from neurolib-dev/fix/multimodel_fixes
Browse files Browse the repository at this point in the history
Small MultiModel fixes
  • Loading branch information
caglorithm authored Feb 10, 2021
2 parents 2b8f847 + f0ccf14 commit 3edcaa2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
12 changes: 6 additions & 6 deletions neurolib/models/multimodel/builder/base/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,10 @@ def update_params(self, params_dict):
local_delays = params_dict.pop(NODE_DELAYS, None)
if local_connectivity is not None and isinstance(local_connectivity, np.ndarray):
assert local_connectivity.shape == self.connectivity.shape
self.connectivity = local_connectivity
self.connectivity = local_connectivity.astype(np.floating)
if local_delays is not None and isinstance(local_delays, np.ndarray):
assert local_delays.shape == self.delays.shape
self.delays = local_delays
self.delays = local_delays.astype(np.floating)
super().update_params(params_dict)

def _sync(self):
Expand Down Expand Up @@ -626,8 +626,8 @@ def get_nested_params(self):
nested_dict = {self.label: {}}
for node in self:
nested_dict[self.label].update(node.get_nested_params())
nested_dict[NETWORK_CONNECTIVITY] = self.connectivity
nested_dict[NETWORK_DELAYS] = self.delays
nested_dict[self.label][NETWORK_CONNECTIVITY] = self.connectivity
nested_dict[self.label][NETWORK_DELAYS] = self.delays
return nested_dict

def init_network(self, **kwargs):
Expand Down Expand Up @@ -666,10 +666,10 @@ def update_params(self, params_dict):
self.nodes[node_index].update_params(node_params)
elif NETWORK_CONNECTIVITY == node_key:
assert node_params.shape == self.connectivity.shape
self.connectivity = node_params
self.connectivity = node_params.astype(np.floating)
elif NETWORK_DELAYS == node_key:
assert node_params.shape == self.delays.shape
self.delays = node_params
self.delays = node_params.astype(np.floating)
else:
logging.warning(f"Not sure what to do with {node_key}...")

Expand Down
9 changes: 8 additions & 1 deletion neurolib/models/multimodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ...utils.collections import dotdict, flat_dict_to_nested, flatten_nested_dict, star_dotdict
from ..model import Model
from .builder.base.constants import NETWORK_CONNECTIVITY, NETWORK_DELAYS
from .builder.base.network import Network, Node

# default run parameters for MultiModels
Expand Down Expand Up @@ -65,13 +66,19 @@ def _set_model_params(self):
Set all necessary model parameters.
"""
params = star_dotdict(flatten_nested_dict(self.model_instance.get_nested_params()))
# all matrices to floats
for k, v in params.items():
if isinstance(v, np.ndarray):
params[k] = v.astype(np.floating)
params.update(DEFAULT_RUN_PARAMS)
params["name"] = self.model_instance.label
params["description"] = self.model_instance.name
if isinstance(self.model_instance, Node):
params.update({"N": 1, "Cmat": np.zeros((1, 1))})
else:
params.update({"N": len(self.model_instance.nodes), "Cmat": self.model_instance.connectivity})
params.update(
{"N": len(self.model_instance.nodes), "Cmat": self.model_instance.connectivity.astype(np.floating)}
)
return params

def getMaxDelay(self):
Expand Down
1 change: 0 additions & 1 deletion tests/multimodel/base/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def test_update_params(self):
net.update_params(
{"connectivity": UPDATE_CONNECTIVITY, "delays": UPDATE_DELAYS, "test_node_0": {f"{EXC}_0": UPDATE_WITH}}
)
print(net.get_nested_params())
np.testing.assert_equal(net.connectivity, UPDATE_CONNECTIVITY)
np.testing.assert_equal(net.delays, UPDATE_DELAYS)
self.assertEqual(net[0][0].params["a"], UPDATE_WITH["a"])
Expand Down
2 changes: 1 addition & 1 deletion tests/multimodel/test_wilson_cowan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
SEED = 42
DURATION = 100.0
DT = 0.01
CORR_THRESHOLD = 0.9
CORR_THRESHOLD = 0.75
NEUROLIB_VARIABLES_TO_TEST = [("q_mean_EXC", "exc"), ("q_mean_INH", "inh")]

# dictionary as backend name: format in which the noise is passed
Expand Down

0 comments on commit 3edcaa2

Please sign in to comment.