Skip to content

Commit

Permalink
fix: properly parse CLI dict options as json (#423)
Browse files Browse the repository at this point in the history
* fix: properly parse CLI dict options as json

* Update config.py
  • Loading branch information
chanind authored Feb 11, 2025
1 parent 9af8bc0 commit a5ac0f0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
24 changes: 21 additions & 3 deletions sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, cast

import simple_parsing
import torch
import wandb
from datasets import (
Expand All @@ -30,6 +31,23 @@
HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


# calling this "json_dict" so error messages will reference "json_dict" being invalid
def json_dict(s: str) -> Any:
res = json.loads(s)
if res is not None and not isinstance(res, dict):
raise ValueError(f"Expected a dictionary, got {type(res)}")
return res


def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: ignore
"""
Helper to wrap simple_parsing.helpers.dict_field so we can load JSON fields from the command line.
"""
if default is None:
return simple_parsing.helpers.field(default=None, type=json_dict, **kwargs)
return simple_parsing.helpers.dict_field(default, type=json_dict, **kwargs)


@dataclass
class LanguageModelSAERunnerConfig:
"""
Expand Down Expand Up @@ -146,7 +164,7 @@ class LanguageModelSAERunnerConfig:
None # defaults to 4 if d_sae and expansion_factor is None
)
activation_fn: str = None # relu, tanh-relu, topk. Default is relu. # type: ignore
activation_fn_kwargs: dict[str, Any] = None # for topk # type: ignore
activation_fn_kwargs: dict[str, int] = dict_field(default=None) # for topk
normalize_sae_decoder: bool = True
noise_scale: float = 0.0
from_pretrained_path: Optional[str] = None
Expand Down Expand Up @@ -238,8 +256,8 @@ class LanguageModelSAERunnerConfig:
n_checkpoints: int = 0
checkpoint_path: str = "checkpoints"
verbose: bool = True
model_kwargs: dict[str, Any] = field(default_factory=dict)
model_from_pretrained_kwargs: dict[str, Any] | None = None
model_kwargs: dict[str, Any] = dict_field(default={})
model_from_pretrained_kwargs: dict[str, Any] | None = dict_field(default=None)
sae_lens_version: str = field(default_factory=lambda: __version__)
sae_lens_training_version: str = field(default_factory=lambda: __version__)
exclude_special_tokens: bool | list[int] = False
Expand Down
2 changes: 1 addition & 1 deletion sae_lens/sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def save_checkpoint(
def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
if len(args) == 0:
args = ["--help"]
parser = ArgumentParser()
parser = ArgumentParser(exit_on_error=False)
parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
return parser.parse_args(args).cfg

Expand Down
39 changes: 39 additions & 0 deletions tests/training/test_sae_training_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import json
import os
from pathlib import Path
Expand Down Expand Up @@ -139,6 +140,44 @@ def test_parse_cfg_args_override():
assert cfg.dataset_path == "my/dataset"


def test_parse_cfg_args_dict_args():
# Test that we can pass dict args as json strings
args = [
"--model_kwargs",
'{"foo": "bar", "baz": 123}',
"--model_from_pretrained_kwargs",
'{"center_writing_weights": false}',
"--activation_fn_kwargs",
'{"k": 100}',
]
cfg = _parse_cfg_args(args)

assert cfg.model_kwargs == {"foo": "bar", "baz": 123}
assert cfg.model_from_pretrained_kwargs == {"center_writing_weights": False}
assert cfg.activation_fn_kwargs == {"k": 100}


def test_parse_cfg_args_invalid_json():
args = ["--model_kwargs", "{invalid json"]
with pytest.raises(argparse.ArgumentError, match="invalid json_dict value"):
_parse_cfg_args(args)


def test_parse_cfg_args_invalid_dict_type():
# Test that we reject non-dict values for dict fields
args = ["--model_kwargs", "[1, 2, 3]"] # Array instead of dict
with pytest.raises(argparse.ArgumentError, match="invalid json_dict value"):
_parse_cfg_args(args)

args = ["--model_from_pretrained_kwargs", '"not_a_dict"'] # String instead of dict
with pytest.raises(argparse.ArgumentError, match="invalid json_dict value"):
_parse_cfg_args(args)

args = ["--activation_fn_kwargs", "123"] # Number instead of dict
with pytest.raises(argparse.ArgumentError, match="invalid json_dict value"):
_parse_cfg_args(args)


def test_parse_cfg_args_expansion_factor():
# Test that we can't set both d_sae and expansion_factor
args = ["--d_sae", "1024", "--expansion_factor", "8"]
Expand Down

0 comments on commit a5ac0f0

Please sign in to comment.