Fix iter_layers & add tests
`iter_layers` and `iter_attentions` should use `multigetattr` and `multihasattr` rather than plain `getattr` and `hasattr`, so that dotted attribute paths in the adapter interface resolve correctly.
lenglaender committed Aug 29, 2024
1 parent 9cef6c6 commit 2171409
Showing 3 changed files with 69 additions and 21 deletions.
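
For context: an `AdapterModelInterface` locates a model's submodules via attribute paths that may contain dots (e.g. `model_layers="encoder.layer"` and `layer_self_attn="attention.self"` in the BERT interface added below). Plain `getattr`/`hasattr` treat such a string as a single attribute name, so the lookup silently fails, while `multigetattr`/`multihasattr` walk the path one segment at a time. A minimal illustration with toy objects (not the library code):

from types import SimpleNamespace

# Toy stand-in for a transformer layer whose self-attention sits behind a nested
# attribute path, as in BERT where the interface uses layer_self_attn="attention.self".
layer = SimpleNamespace(attention=SimpleNamespace())
layer.attention.self = "self-attention module"

# Plain hasattr/getattr treat the whole dotted string as one attribute name:
print(hasattr(layer, "attention.self"))        # False -> iter_attentions would silently skip the layer
print(getattr(layer, "attention.self", None))  # None

# Resolving the path means walking it segment by segment, which is what the
# multigetattr/multihasattr helpers in adapters.utils do (see the utils.py diff below).
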
22 changes: 15 additions & 7 deletions src/adapters/model_mixin.py
@@ -29,7 +29,15 @@
 from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
 from .methods.prompt_tuning import PromptTuningLayer
 from .methods.reft import init_reft
-from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, multigetattr, patch_forward
+from .utils import (
+    EMBEDDING_FILE,
+    TOKENIZER_PATH,
+    get_adapter_config_hash,
+    inherit_doc,
+    multigetattr,
+    multihasattr,
+    patch_forward,
+)
 from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config


@@ -1469,18 +1477,18 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr
     # Adapter Interface Methods

     def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
-        for i, layer in enumerate(getattr(self, self.adapter_interface.model_layers)):
+        for i, layer in enumerate(multigetattr(self, self.adapter_interface.model_layers)):
             yield i, layer

     def get_layer(self, idx: int) -> nn.Module:
-        return getattr(self, self.adapter_interface.model_layers)[idx]
+        return multigetattr(self, self.adapter_interface.model_layers)[idx]

     def iter_attentions(self) -> Iterable[Tuple[int, Literal["self", "cross"], nn.Module]]:
         for i, layer in self.iter_layers():
-            if hasattr(layer, self.adapter_interface.layer_self_attn or ""):
-                yield i, "self", getattr(layer, self.adapter_interface.layer_self_attn)
-            if hasattr(layer, self.adapter_interface.layer_cross_attn or ""):
-                yield i, "cross", getattr(layer, self.adapter_interface.layer_cross_attn)
+            if multihasattr(layer, self.adapter_interface.layer_self_attn or ""):
+                yield i, "self", multigetattr(layer, self.adapter_interface.layer_self_attn)
+            if multihasattr(layer, self.adapter_interface.layer_cross_attn or ""):
+                yield i, "cross", multigetattr(layer, self.adapter_interface.layer_cross_attn)

     def iter_layer_ffns(self) -> Iterable[Tuple[int, Literal["intermediate", "output"], nn.Module]]:
         for i, layer in self.iter_layers():
10 changes: 10 additions & 0 deletions src/adapters/utils.py
@@ -180,6 +180,16 @@ def multigetattr(o: object, name: str, default=None) -> Optional[object]:
     return o


+def multihasattr(o: object, name: str) -> bool:
+    parts = name.split(".")
+    for n in parts:
+        if hasattr(o, n):
+            o = getattr(o, n)
+        else:
+            return False
+    return True
+
+
 def multisetattr(o: object, name: str, value: object):
     parts = name.split(".")
     for n in parts[:-1]:
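
A quick usage sketch of the new helper alongside `multigetattr` (the nested toy object is hypothetical; both functions live in `adapters.utils` as shown above):

from types import SimpleNamespace

from adapters.utils import multigetattr, multihasattr

# Hypothetical nesting that mirrors a BERT layer's attribute layout.
layer = SimpleNamespace(attention=SimpleNamespace(output=SimpleNamespace(dense="output projection")))

print(multihasattr(layer, "attention.output.dense"))  # True
print(multihasattr(layer, "crossattention.output"))   # False, the path does not exist
print(multigetattr(layer, "attention.output.dense"))  # output projection
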
58 changes: 44 additions & 14 deletions tests/test_custom_interface_compat.py
@@ -2,20 +2,24 @@
 import tempfile
 import unittest

+from parameterized import parameterized
+
 import torch

 import adapters
 from adapters import AdapterModelInterface, AutoAdapterModel
 from adapters.utils import WEIGHTS_NAME
-from transformers import AutoModelForCausalLM, LlamaConfig
+from transformers import AutoModel, AutoModelForCausalLM, BertConfig, LlamaConfig
 from transformers.testing_utils import require_torch, torch_device

 from .test_adapter import ids_tensor, make_config


 @require_torch
 class CustomInterfaceCompatTest(unittest.TestCase):
-    config = make_config(
+    # This test is to check if the custom interface produces the same results as the AdapterModel implementation.
+
+    llama_config = make_config(
         LlamaConfig,
         hidden_size=32,
         num_hidden_layers=5,
@@ -24,7 +24,16 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         hidden_act="gelu",
         pad_token_id=0,
     )
-    adapter_interface = AdapterModelInterface(
+    bert_config = make_config(
+        BertConfig,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        pad_token_id=0,
+    )
+    llama_adapter_interface = AdapterModelInterface(
         adapter_types=["lora", "reft"],
         model_embeddings="embed_tokens",
         model_layers="layers",
Expand All @@ -36,18 +49,30 @@ class CustomInterfaceCompatTest(unittest.TestCase):
layer_intermediate_proj="mlp.up_proj",
layer_output_proj="mlp.down_proj",
)
bert_adapter_interface = AdapterModelInterface(
adapter_types=["lora", "reft"],
model_embeddings="embeddings",
model_layers="encoder.layer",
layer_self_attn="attention.self",
layer_cross_attn=None,
attn_k_proj="key",
attn_q_proj="query",
attn_v_proj="value",
layer_intermediate_proj="intermediate.dense",
layer_output_proj="output.dense",
)

def create_twin_models(self):
model1 = AutoModelForCausalLM.from_config(self.config())
adapters.init(model1, interface=self.adapter_interface)
def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
model1 = hf_auto_model_class.from_config(config())
adapters.init(model1, interface=adapter_interface)
model1.eval()
# create a twin initialized with the same random weights
model2 = AutoAdapterModel.from_pretrained(None, config=self.config(), state_dict=model1.state_dict())
model2 = AutoAdapterModel.from_pretrained(None, config=config(), state_dict=model1.state_dict())
model2.eval()
return model1, model2

def run_load_test(self, adapter_config):
custom_model, auto_model = self.create_twin_models()
def run_load_test(self, adapter_config, config, adapter_interface, hf_auto_model_class):
custom_model, auto_model = self.create_twin_models(config, adapter_interface, hf_auto_model_class)

name = "dummy_adapter"
custom_model.add_adapter(name, config=adapter_config)
@@ -79,8 +104,13 @@ def run_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

-    def test_load_lora(self):
-        self.run_load_test(adapters.LoRAConfig())
-
-    def test_load_reft(self):
-        self.run_load_test(adapters.LoReftConfig())
+    @parameterized.expand(
+        [
+            ("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
+            ("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_adapter_interface, AutoModel),
+            ("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
+            ("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_adapter_interface, AutoModel),
+        ]
+    )
+    def test_load_adapter(self, name, adapter_config, config, adapter_interface, hf_auto_model_class):
+        self.run_load_test(adapter_config, config, adapter_interface, hf_auto_model_class)
