diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py
index 3a8743a3f2..2a716e2acc 100644
--- a/src/adapters/methods/prefix_tuning.py
+++ b/src/adapters/methods/prefix_tuning.py
@@ -21,19 +21,20 @@ def __init__(
         n_heads: int,
         input_size: int,
         config: PrefixTuningConfig,
+        n_embd_per_head: Optional[int] = None,
     ):
         super().__init__()
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.input_size = input_size
-        self.n_embd_per_head = self.input_size // self.n_heads
+        self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
         self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
         self.control_trans = nn.Sequential(
             nn.Linear(self.input_size, self.config.bottleneck_size),
             Activation_Function_Class(self.config.non_linearity.lower()),
-            nn.Linear(self.config.bottleneck_size, self.n_layers * 2 * self.input_size),
+            nn.Linear(self.config.bottleneck_size, self.n_layers * 2 * self.n_heads * self.n_embd_per_head),
         )
         self.dropout = nn.Dropout(self.config.dropout)
 
@@ -70,15 +71,18 @@ def __init__(
         n_heads: int,
         input_size: int,
         config: PrefixTuningConfig,
+        n_embd_per_head: Optional[int] = None,
     ):
         super().__init__()
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.input_size = input_size
-        self.n_embd_per_head = self.input_size // self.n_heads
+        self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
-        self.control_trans = nn.Parameter(torch.randn(self.config.prefix_length * self.n_layers * 2 * self.input_size))
+        self.control_trans = nn.Parameter(
+            torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
+        )
 
         self.dropout = nn.Dropout(self.config.dropout)
 
@@ -174,6 +178,7 @@ def confirm_prefix(self, prefix_name: str) -> bool:
                 "n_layers": location_config["count"],
                 "n_heads": location_config["n_heads"],
                 "input_size": location_config["input_size"],
+                "n_embd_per_head": location_config["n_embd_per_head"],
             }
         prefix_tuning = PrefixTuningGroup(module_configs, prefix_tuning_config)
         prefix_tuning.train(self.training)  # make sure training mode is consistent
@@ -319,6 +324,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
             self.location_key,
             n_heads=self.model_config.num_attention_heads,
             input_size=self.model_config.hidden_size,
+            n_embd_per_head=getattr(self.model_config, "d_kv", None),  # this is currently specific to T5-3B
         )
         self.prefixes[adapter_name] = prefix_id
 
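Note (not part of the patch): a minimal sketch of why the per-head dimension must be passed in explicitly rather than derived as input_size // n_heads. It assumes T5-3B-style hyperparameters (hidden_size=1024, 32 attention heads, d_kv=128) and an illustrative layer count; the variable names mirror the constructor arguments above, but the numbers are examples only.

# Illustration only: the prefix states fed into attention must reshape to
# (..., n_heads, n_embd_per_head), so the projection width has to use the
# model's actual per-head key/value size, not hidden_size // n_heads.
hidden_size = 1024  # assumed T5-3B model_config.hidden_size
n_heads = 32        # assumed T5-3B model_config.num_attention_heads
d_kv = 128          # assumed T5-3B model_config.d_kv (per-head key/value size)

# Mirrors the fallback in the patch: use d_kv when available, else infer from hidden_size.
n_embd_per_head = d_kv or hidden_size // n_heads

n_layers = 24  # illustrative value

# Old projection width per prefix token vs. the width the attention layers actually expect:
old_width = n_layers * 2 * hidden_size                 # 49_152
new_width = n_layers * 2 * n_heads * n_embd_per_head   # 196_608

# For T5-3B, n_heads * d_kv (4096) != hidden_size (1024), so only the new width
# matches the per-layer key/value shapes.
assert new_width == n_layers * 2 * n_heads * d_kv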