From 3d0d283bc4bc4fb6bcf513545f7e05e25df2b6a3 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:43:40 +0200 Subject: [PATCH] Allow setting list of output embeddings --- src/transformers/adapters/heads/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/adapters/heads/base.py b/src/transformers/adapters/heads/base.py index 338f00d6f..0c9caef43 100644 --- a/src/transformers/adapters/heads/base.py +++ b/src/transformers/adapters/heads/base.py @@ -507,7 +507,7 @@ def _init_head_modules(self): # The following methods are required for handling LM heads - def get_output_embeddings(self): + def get_output_embeddings(self) -> Union[nn.Module, List[nn.Module]]: # Only gets the output embeddings for the currently active head embeddings = [] for head_name in self._active_heads: @@ -523,12 +523,14 @@ def get_output_embeddings(self): else: return embeddings - def set_output_embeddings(self, new_embeddings): + def set_output_embeddings(self, new_embeddings: Union[nn.Module, List[nn.Module]]): # Only sets the output embeddings for the currently active head - for head_name in self._active_heads: + if not isinstance(new_embeddings, list): + new_embeddings = [new_embeddings] * len(self._active_heads) + for head_name, emb in zip(self._active_heads, new_embeddings): if head_name in self.heads: head = self.heads[head_name] - head.set_output_embeddings(new_embeddings) + head.set_output_embeddings(emb) def tie_weights(self): """