Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Commit

Permalink
Merge pull request #23 from brain-score/cornets-hooks
Browse files Browse the repository at this point in the history
integrate hooks into CORnets; make hooks generic by accepting model input
  • Loading branch information
mschrimpf authored Jul 16, 2019
2 parents 224d4d0 + 8dea21f commit 4c80070
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
4 changes: 0 additions & 4 deletions candidate_models/base_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ def __init__(self):
multiplier=multiplier: TFSlimModel.init(
identifier, preprocessing_type='inception', image_size=image_size, net_name=net_name,
model_ctr_kwargs={'depth_multiplier': multiplier})
# CORnets
for cornet_type in ['Z', 'R', 'R2', 'S']:
identifier = f"CORnet-{cornet_type}"
_key_functions[identifier] = lambda identifier=identifier: cornet(identifier)

# instantiate models with LazyLoad wrapper
for identifier, function in _key_functions.items():
Expand Down
52 changes: 40 additions & 12 deletions candidate_models/model_commitments/cornets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from brainscore.model_interface import BrainModel
from brainscore.utils import LazyLoad
from candidate_models.base_models import cornet
from candidate_models.model_commitments.ml_pool import Hooks
from candidate_models.utils import UniqueKeyDict
from model_tools.brain_transformation.behavior import BehaviorArbiter, LogitsBehavior, ProbabilitiesMapping
from result_caching import store

Expand Down Expand Up @@ -263,15 +265,41 @@ def cornet_r2_brainmodel():
})


cornet_brain_pool = {
'CORnet-Z': LazyLoad(cornet_z_brainmodel),
'CORnet-S': LazyLoad(cornet_s_brainmodel),
'CORnet-S101010': LazyLoad(cornet_s101010_brainmodel),
'CORnet-S222': LazyLoad(cornet_s222_brainmodel),
'CORnet-S444': LazyLoad(cornet_s444_brainmodel),
'CORnet-S484': LazyLoad(cornet_s484_brainmodel),
'CORnet-S10rep': LazyLoad(cornet_s10rep_brainmodel),
'CORnet-R': LazyLoad(cornet_r_brainmodel),
'CORnet-R10rep': LazyLoad(cornet_r10rep_brainmodel),
'CORnet-R2': LazyLoad(cornet_r2_brainmodel),
}
class CORnetBrainPool(UniqueKeyDict):
def __init__(self):
super(CORnetBrainPool, self).__init__()

model_pool = {
'CORnet-Z': LazyLoad(cornet_z_brainmodel),
'CORnet-S': LazyLoad(cornet_s_brainmodel),
'CORnet-S101010': LazyLoad(cornet_s101010_brainmodel),
'CORnet-S222': LazyLoad(cornet_s222_brainmodel),
'CORnet-S444': LazyLoad(cornet_s444_brainmodel),
'CORnet-S484': LazyLoad(cornet_s484_brainmodel),
'CORnet-S10rep': LazyLoad(cornet_s10rep_brainmodel),
'CORnet-R': LazyLoad(cornet_r_brainmodel),
'CORnet-R10rep': LazyLoad(cornet_r10rep_brainmodel),
'CORnet-R2': LazyLoad(cornet_r2_brainmodel),
}

self._accessed_brain_models = []

for basemodel_identifier, brain_model in model_pool.items():
activations_model = LazyLoad(lambda brain_model=brain_model: brain_model.activations_model)
for identifier, activations_model in Hooks().iterate_hooks(basemodel_identifier, activations_model):
def load(basemodel_identifier=basemodel_identifier, identifier=identifier, brain_model=brain_model):
# only update when actually required, otherwise we'd change the activations_model
# of one brain_model at all times
if basemodel_identifier in self._accessed_brain_models:
raise ValueError(f"{identifier}'s brain-model {basemodel_identifier} has already been accessed "
f"in this session. To avoid clashes in the hooks, "
f"please run {identifier} in a separate session.")
self._accessed_brain_models.append(basemodel_identifier)
# upon accessing the `activations_model`, the Hook will automatically
# attach to the `brain_model.activations_model`.
return brain_model

self[identifier] = LazyLoad(load)


cornet_brain_pool = CORnetBrainPool()
15 changes: 7 additions & 8 deletions candidate_models/model_commitments/ml_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@ def __init__(self):
"degrees": lambda activations_model: PixelsToDegrees.hook(
activations_model, target_pixels=activations_model.image_size)}

def iterate_hooks(self, basemodel_identifier):
for hook_identifiers in itertools.chain(
*[itertools.combinations(self.activation_hooks, n) for n in range(len(self.activation_hooks) + 1)]):
def iterate_hooks(self, basemodel_identifier, activations_model):
for hook_identifiers in itertools.chain.from_iterable(
itertools.combinations(self.activation_hooks, n) for n in range(len(self.activation_hooks) + 1)):
hook_identifiers = list(sorted(hook_identifiers))
identifier = basemodel_identifier
if len(hook_identifiers) > 0:
identifier += self.HOOK_SEPARATOR + "-".join(hook_identifiers)

# enforce early parameter binding: https://stackoverflow.com/a/3431699/2225200
def load(basemodel_identifier=basemodel_identifier, identifier=identifier,
def load(identifier=identifier, activations_model=activations_model,
hook_identifiers=hook_identifiers):
activations_model = base_model_pool[basemodel_identifier]
activations_model.identifier = identifier # since inputs are different, also change identifier
for hook in hook_identifiers:
self.activation_hooks[hook](activations_model)
Expand Down Expand Up @@ -217,7 +216,7 @@ def __init__(self):
continue
layers = model_layers[basemodel_identifier]

for identifier, activations_model in Hooks().iterate_hooks(basemodel_identifier):
for identifier, activations_model in Hooks().iterate_hooks(basemodel_identifier, activations_model):
self[identifier] = {'model': activations_model, 'layers': layers}


Expand All @@ -235,13 +234,13 @@ class MLBrainPool(UniqueKeyDict):
def __init__(self):
super(MLBrainPool, self).__init__()

for basemodel_identifier in base_model_pool:
for basemodel_identifier, activations_model in base_model_pool.items():
if basemodel_identifier not in model_layers:
warnings.warn(f"{basemodel_identifier} not found in model_layers")
continue
layers = model_layers[basemodel_identifier]

for identifier, activations_model in Hooks().iterate_hooks(basemodel_identifier):
for identifier, activations_model in Hooks().iterate_hooks(basemodel_identifier, activations_model):
if identifier in self: # already pre-defined
continue

Expand Down

0 comments on commit 4c80070

Please sign in to comment.