2 changes: 2 additions & 0 deletions .gitignore
@@ -152,3 +152,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.vscode
4 changes: 2 additions & 2 deletions pyvene/models/configuration_intervenable_model.py
@@ -5,7 +5,7 @@
 from transformers import PreTrainedTokenizer, TensorType, is_torch_available
 from transformers.configuration_utils import PretrainedConfig
 
-from .interventions import VanillaIntervention
+from .interventions import VanillaIntervention, Intervention
 
 
 RepresentationConfig = namedtuple(
@@ -25,7 +25,7 @@ class IntervenableConfig(PretrainedConfig):
     def __init__(
         self,
         representations=[RepresentationConfig()],
-        intervention_types=VanillaIntervention,
+        intervention_types: type[Intervention] | list[type[Intervention]] = VanillaIntervention,
         mode="parallel",
         sorted_keys=None,
         model_type=None, # deprecating
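
The new annotation documents that `intervention_types` accepts either a single `Intervention` subclass, applied to every representation, or one subclass per representation. A minimal usage sketch, assuming the classes are re-exported at the package root as in pyvene's public API:

    import pyvene as pv

    # Option 1: one intervention type, shared across all representations.
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(0, "block_output"),  # layer, component
            pv.RepresentationConfig(1, "block_output"),
        ],
        intervention_types=pv.VanillaIntervention,
    )

    # Option 2: one intervention type per representation, matched by position.
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(0, "block_output"),
            pv.RepresentationConfig(1, "block_output"),
        ],
        intervention_types=[pv.VanillaIntervention, pv.VanillaIntervention],
    )
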
13 changes: 9 additions & 4 deletions pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -74,12 +74,14 @@
 
 def create_gpt2(name="gpt2", cache_dir=None):
     """Creates a GPT2 model, config, and tokenizer from the given name and revision"""
-    from transformers import GPT2Model, GPT2Tokenizer, GPT2Config
+    from transformers import GPT2Model, AutoTokenizer, GPT2Config
 
     config = GPT2Config.from_pretrained(name)
-    tokenizer = GPT2Tokenizer.from_pretrained(name)
+    tokenizer = AutoTokenizer.from_pretrained(name)
     gpt = GPT2Model.from_pretrained(name, config=config, cache_dir=cache_dir)
-    print("loaded model")
+    assert isinstance(gpt, GPT2Model)
+
+    print(f"loaded GPT2 model {name}")
     return config, tokenizer, gpt
 
 
@@ -93,7 +95,10 @@ def create_gpt2_lm(name="gpt2", config=None, cache_dir=None):
         gpt = GPT2LMHeadModel.from_pretrained(name, config=config, cache_dir=cache_dir)
     else:
         gpt = GPT2LMHeadModel(config=config)
-    print("loaded model")
+
+    assert isinstance(gpt, GPT2LMHeadModel)
+
+    print(f"loaded GPT2 model {name}")
     return config, tokenizer, gpt
 
 def create_gpt2_classifier(name="gpt2", config=None, cache_dir=None):
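
Switching to `AutoTokenizer` lets checkpoint variants passed as `name` (for example "distilgpt2") resolve to the appropriate tokenizer class, and the new `isinstance` assertions give type checkers a concrete class to narrow the `from_pretrained` return type to. A quick smoke test for the updated factory, assuming it is re-exported at the package root:

    from pyvene import create_gpt2_lm

    config, tokenizer, gpt2 = create_gpt2_lm("gpt2")  # prints: loaded GPT2 model gpt2
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = gpt2.generate(**inputs, max_new_tokens=5, do_sample=False)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
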
22 changes: 12 additions & 10 deletions pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py
@@ -19,16 +19,16 @@
     "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK),
     "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
     "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
-    "attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
-    "head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
+    "attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
     "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
     "attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
-    "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
-    "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
-    "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
-    "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
 }
 
 
@@ -67,11 +67,13 @@
 def create_gpt_neo(
     name="roneneldan/TinyStories-33M", cache_dir=None
 ):
-    """Creates a GPT2 model, config, and tokenizer from the given name and revision"""
+    """Creates a GPTNeo model, config, and tokenizer from the given name and revision"""
     from transformers import GPTNeoForCausalLM, GPT2Tokenizer, GPTNeoConfig
 
     config = GPTNeoConfig.from_pretrained(name)
     tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M") # not sure
     gpt_neo = GPTNeoForCausalLM.from_pretrained(name)
-    print("loaded model")
+    assert isinstance(gpt_neo, GPTNeoForCausalLM)
+
+    print(f"loaded GPTNeo model {name}")
     return config, tokenizer, gpt_neo
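
The added `.attention` segment in the mapping above matches how transformers nests GPT-Neo's attention: each block's `attn` is a `GPTNeoAttention` wrapper, and its `attention` submodule owns the q/k/v and output projections, so the previous paths pointed one level too high. A quick check against the TinyStories checkpoint named in this diff:

    from transformers import GPTNeoForCausalLM

    model = GPTNeoForCausalLM.from_pretrained("roneneldan/TinyStories-33M")
    block = model.transformer.h[0]
    print(type(block.attn).__name__)            # GPTNeoAttention (thin wrapper)
    print(type(block.attn.attention).__name__)  # GPTNeoSelfAttention
    assert hasattr(block.attn.attention, "q_proj")  # so "h[%s].attn.attention.q_proj" resolves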
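
The hard-coded gpt-neo-125M tokenizer, flagged "# not sure" above, should hold for the TinyStories checkpoints, which appear to use the GPT-Neo (GPT-2 BPE) vocabulary; a cheap sanity check against the loaded model's embedding table, assuming `create_gpt_neo` is re-exported at the package root:

    from pyvene import create_gpt_neo

    config, tokenizer, gpt_neo = create_gpt_neo("roneneldan/TinyStories-33M")
    # The tokenizer's vocabulary must fit inside the model's embedding table.
    assert len(tokenizer) <= gpt_neo.get_input_embeddings().num_embeddings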