Skip to content

Commit

Permalink
Fix Prefix-Tuning for T5 models where d_kv != d_model / num_heads
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Dec 18, 2023
1 parent c921726 commit a48eb7f
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/adapters/methods/prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ def __init__(
n_heads: int,
input_size: int,
config: PrefixTuningConfig,
n_embd_per_head: Optional[int] = None,
):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.input_size = input_size
self.n_embd_per_head = self.input_size // self.n_heads
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
self.control_trans = nn.Sequential(
nn.Linear(self.input_size, self.config.bottleneck_size),
Activation_Function_Class(self.config.non_linearity.lower()),
nn.Linear(self.config.bottleneck_size, self.n_layers * 2 * self.input_size),
nn.Linear(self.config.bottleneck_size, self.n_layers * 2 * self.n_heads * self.n_embd_per_head),
)
self.dropout = nn.Dropout(self.config.dropout)

Expand Down Expand Up @@ -70,15 +71,18 @@ def __init__(
n_heads: int,
input_size: int,
config: PrefixTuningConfig,
n_embd_per_head: Optional[int] = None,
):
super().__init__()
self.n_layers = n_layers
self.n_heads = n_heads
self.input_size = input_size
self.n_embd_per_head = self.input_size // self.n_heads
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

self.control_trans = nn.Parameter(torch.randn(self.config.prefix_length * self.n_layers * 2 * self.input_size))
self.control_trans = nn.Parameter(
torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
)

self.dropout = nn.Dropout(self.config.dropout)

Expand Down Expand Up @@ -174,6 +178,7 @@ def confirm_prefix(self, prefix_name: str) -> bool:
"n_layers": location_config["count"],
"n_heads": location_config["n_heads"],
"input_size": location_config["input_size"],
"n_embd_per_head": location_config["n_embd_per_head"],
}
prefix_tuning = PrefixTuningGroup(module_configs, prefix_tuning_config)
prefix_tuning.train(self.training) # make sure training mode is consistent
Expand Down Expand Up @@ -319,6 +324,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
self.location_key,
n_heads=self.model_config.num_attention_heads,
input_size=self.model_config.hidden_size,
n_embd_per_head=getattr(self.model_config, "d_kv", None), # this is currently specific to T5-3B
)
self.prefixes[adapter_name] = prefix_id

Expand Down

0 comments on commit a48eb7f

Please sign in to comment.