Skip to content

Commit

Permalink
Allow setting list of output embeddings
Browse files — browse the repository at this point in history
  • Loading branch information
calpt committed Sep 17, 2023
1 parent 87111aa commit 3d0d283
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/transformers/adapters/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def _init_head_modules(self):

# The following methods are required for handling LM heads

def get_output_embeddings(self):
def get_output_embeddings(self) -> Union[nn.Module, List[nn.Module]]:
# Only gets the output embeddings for the currently active head
embeddings = []
for head_name in self._active_heads:
Expand All @@ -523,12 +523,14 @@ def get_output_embeddings(self):
else:
return embeddings

def set_output_embeddings(self, new_embeddings: Union[nn.Module, List[nn.Module]]):
    """Set the output embeddings of the currently active heads.

    Accepts either a single module, which is shared by every active head,
    or a list of modules with exactly one entry per active head, matched
    positionally against ``self._active_heads``.

    Args:
        new_embeddings: a single ``nn.Module`` applied to all active heads,
            or a list of ``nn.Module`` with one entry per active head.

    Raises:
        ValueError: if a list is passed whose length does not match the
            number of active heads. Without this check, ``zip`` would
            silently truncate and leave some heads (or embeddings) unused.
    """
    # Broadcast a single module so every active head receives the same embeddings.
    if not isinstance(new_embeddings, list):
        new_embeddings = [new_embeddings] * len(self._active_heads)
    elif len(new_embeddings) != len(self._active_heads):
        raise ValueError(
            "Expected {} output embeddings, got {}.".format(
                len(self._active_heads), len(new_embeddings)
            )
        )
    for head_name, emb in zip(self._active_heads, new_embeddings):
        # NOTE(review): names in _active_heads may be absent from self.heads —
        # presumably stale/unregistered heads; they are skipped, not an error.
        if head_name in self.heads:
            head = self.heads[head_name]
            head.set_output_embeddings(emb)

def tie_weights(self):
"""
Expand Down

0 comments on commit 3d0d283

Please sign in to comment.