Commit

chore: fixing docs sae table for deepseek SAE
chanind committed Feb 9, 2025
1 parent 220471b commit 28b226d
Showing 2 changed files with 77 additions and 31 deletions.
69 changes: 39 additions & 30 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -582,36 +582,19 @@ def get_dictionary_learning_config_1(
}


-def deepseek_r1_sae_loader(
-    release: str,
-    sae_id: str,
-    device: str = "cpu",
-    force_download: bool = False,
-    cfg_overrides: Optional[Dict[str, Any]] = None,
-) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
-    # Get repo and file info from pretrained directory
-    sae_directory = get_pretrained_saes_directory()
-    repo_id = sae_directory[release].repo_id
-    filename = sae_directory[release].saes_map[sae_id]
-
-    # Download weights
-    sae_path = hf_hub_download(
-        repo_id=repo_id,
-        filename=filename,
-        force_download=force_download,
-    )
-
-    # Load state dict
-    state_dict_loaded = torch.load(sae_path, map_location=device)
+def get_deepseek_r1_config(
+    repo_id: str,
+    folder_name: str,
+    options: SAEConfigLoadOptions,
+) -> dict[str, Any]:
+    """Get config for DeepSeek R1 SAEs."""

    # Extract layer from filename (l19 in this case)
-    match = re.search(r"l(\d+)", filename)
+    match = re.search(r"l(\d+)", folder_name)
    if match is None:
-        raise ValueError(f"Could not find layer number in filename: {filename}")
+        raise ValueError(f"Could not find layer number in filename: {folder_name}")
    layer = int(match.group(1))

-    # Create config
-    cfg_dict = {
+    return {
        "architecture": "standard",
        "d_in": 4096, # LLaMA 8B hidden size
        "d_sae": 4096 * 16, # Expansion factor 16
@@ -627,11 +610,39 @@ def deepseek_r1_sae_loader
"sae_lens_training_version": None,
"activation_fn_str": "relu",
"normalize_activations": "none",
"device": device,
"device": options.device,
"apply_b_dec_to_input": False,
"finetuning_scaling_factor": False,
}


def deepseek_r1_sae_loader(
release: str,
sae_id: str,
device: str = "cpu",
force_download: bool = False,
cfg_overrides: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], Optional[torch.Tensor]]:
"""Load a DeepSeek R1 SAE."""
# Get repo and file info from pretrained directory
sae_directory = get_pretrained_saes_directory()
repo_id = sae_directory[release].repo_id
filename = sae_directory[release].saes_map[sae_id]

# Download weights
sae_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
force_download=force_download,
)

# Load state dict
state_dict_loaded = torch.load(sae_path, map_location=device)

# Create config
options = SAEConfigLoadOptions(device=device, force_download=force_download)
cfg_dict = get_deepseek_r1_config(repo_id, filename, options)

# Convert weights
state_dict = {
"W_enc": state_dict_loaded["encoder.weight"].T,
@@ -743,7 +754,5 @@ def dictionary_learning_sae_loader_1(
"gemma_2": get_gemma_2_config,
"llama_scope": get_llama_scope_config,
"dictionary_learning_1": get_dictionary_learning_config_1,
"deepseek_r1": lambda _repo_id,
_folder_name,
_options: {}, # Config built in loader
"deepseek_r1": get_deepseek_r1_config,
}
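
For orientation (not part of the diff): a minimal sketch of how the refactored loader is typically exercised end to end. The release and SAE id strings below are illustrative placeholders rather than names confirmed by this commit, and the three-tuple return of SAE.from_pretrained (SAE, config dict, optional sparsity tensor) is assumed from the loader signature shown above.

from sae_lens.sae import SAE

# Placeholder release / sae_id: substitute the actual DeepSeek R1 entries from
# the pretrained SAE directory. During loading, deepseek_r1_sae_loader now
# builds its config via get_deepseek_r1_config instead of inlining it.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="deepseek-r1-distill-llama-8b",  # hypothetical release name
    sae_id="blocks.19.hook_resid_post",  # hypothetical SAE id
    device="cpu",
)
print(cfg_dict["hook_name"])  # expected: blocks.19.hook_resid_post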
39 changes: 38 additions & 1 deletion tests/toolkit/test_pretrained_sae_loaders.py
@@ -1,5 +1,9 @@
from sae_lens.sae import SAE
-from sae_lens.toolkit.pretrained_sae_loaders import SAEConfigLoadOptions, get_sae_config
+from sae_lens.toolkit.pretrained_sae_loaders import (
+    SAEConfigLoadOptions,
+    get_deepseek_r1_config,
+    get_sae_config,
+)


def test_get_sae_config_sae_lens():
@@ -178,3 +182,36 @@ def test_get_sae_config_matches_from_pretrained():
    )

    assert direct_sae_cfg == from_pretrained_cfg_dict
+
+
+def test_get_deepseek_r1_config():
+    """Test that the DeepSeek R1 config is generated correctly."""
+    options = SAEConfigLoadOptions(device="cpu")
+    cfg = get_deepseek_r1_config(
+        repo_id="some/repo",
+        folder_name="DeepSeek-R1-Distill-Llama-8B-SAE-l19.pt",
+        options=options,
+    )
+
+    expected_cfg = {
+        "architecture": "standard",
+        "d_in": 4096, # LLaMA 8B hidden size
+        "d_sae": 4096 * 16, # Expansion factor 16
+        "dtype": "bfloat16",
+        "context_size": 1024,
+        "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+        "hook_name": "blocks.19.hook_resid_post",
+        "hook_layer": 19,
+        "hook_head_index": None,
+        "prepend_bos": True,
+        "dataset_path": "lmsys/lmsys-chat-1m",
+        "dataset_trust_remote_code": True,
+        "sae_lens_training_version": None,
+        "activation_fn_str": "relu",
+        "normalize_activations": "none",
+        "device": "cpu",
+        "apply_b_dec_to_input": False,
+        "finetuning_scaling_factor": False,
+    }
+
+    assert cfg == expected_cfg
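
The added test only builds a config dict (no download or model load), so it can be run in isolation with a standard pytest invocation:

pytest tests/toolkit/test_pretrained_sae_loaders.py::test_get_deepseek_r1_config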
