Skip to content

Commit

Permalink
Move custom head dict out of config (#700)
Browse files Browse the repository at this point in the history
To make the model_config serializable and prevent the error mentioned in
#680 move the costum_heads dictionary out of the config and make it a
separate attribute of the model class.
  • Loading branch information
hSterz authored Jun 20, 2024
1 parent 95d15e2 commit c243478
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/adapters/heads/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._convert_to_flex_head = True
if not hasattr(self.config, "custom_heads"):
self.config.custom_heads = {}
if not hasattr(self, "custom_heads"):
self.custom_heads = {}
self._active_heads = []

def head_type(head_type_str: str):
Expand Down Expand Up @@ -176,7 +176,7 @@ def add_prediction_head_from_config(
head_class = MODEL_HEAD_MAP[head_type]
head = head_class(self, head_name, **config)
self.add_prediction_head(head, overwrite_ok=overwrite_ok, set_active=set_active)
elif head_type in self.config.custom_heads:
elif head_type in self.custom_heads:
# we have to re-add the head type for custom heads
self.add_custom_head(head_type, head_name, overwrite_ok=overwrite_ok, **config)
else:
Expand All @@ -193,7 +193,7 @@ def get_prediction_heads_config(self):
return heads

def register_custom_head(self, identifier, head):
self.config.custom_heads[identifier] = head
self.custom_heads[identifier] = head

@property
def active_head(self) -> Union[str, List[str]]:
Expand Down Expand Up @@ -253,8 +253,8 @@ def set_active_adapters(
)

def add_custom_head(self, head_type, head_name, overwrite_ok=False, set_active=True, **kwargs):
if head_type in self.config.custom_heads:
head = self.config.custom_heads[head_type](self, head_name, **kwargs)
if head_type in self.custom_heads:
head = self.custom_heads[head_type](self, head_name, **kwargs)
# When a build-in head is added as a custom head it does not have the head_type property
if not hasattr(head.config, "head_type"):
head.config["head_type"] = head_type
Expand Down
3 changes: 2 additions & 1 deletion tests/test_adapter_custom_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def test_add_custom_head(self):
def test_save_load_custom_head(self):
model_name = "bert-base-uncased"
model_config = AutoConfig.from_pretrained(model_name)
model_config.custom_heads = {"tag": CustomHead}
model1 = AutoAdapterModel.from_pretrained(model_name, config=model_config)
model2 = AutoAdapterModel.from_pretrained(model_name, config=model_config)
model1.custom_heads = {"tag": CustomHead}
model2.custom_heads = {"tag": CustomHead}
config = {"num_labels": 3, "layers": 2, "activation_function": "tanh"}
model1.add_custom_head(head_type="tag", head_name="custom_head", **config)

Expand Down

0 comments on commit c243478

Please sign in to comment.