1 change: 1 addition & 0 deletions .gitignore
@@ -187,3 +187,4 @@ prototype/
 influence_results/
 .idea/
 uv.lock
+data/*.hf
3 changes: 2 additions & 1 deletion bergson/__init__.py
@@ -9,7 +9,7 @@
     ReduceConfig,
     ScoreConfig,
 )
-from .data import load_gradients
+from .data import load_gradient_dataset, load_gradients
 from .gradcheck import FiniteDiff
 from .gradients import GradientCollector, GradientProcessor
 from .query.attributor import Attributor
@@ -19,6 +19,7 @@
 __all__ = [
     "collect_gradients",
     "load_gradients",
+    "load_gradient_dataset",
     "Attributor",
     "FaissConfig",
     "FiniteDiff",
13 changes: 13 additions & 0 deletions bergson/__main__.py
@@ -1,4 +1,6 @@
+import shutil
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional, Union
 
 from simple_parsing import ArgumentParser, ConflictResolution
@@ -21,6 +23,17 @@ def execute(self):
         if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners:
             raise ValueError("Either skip_index or skip_preconditioners must be False")
 
+        # Require confirmation from the user to proceed if overwriting an existing index
+        index_path = Path(self.index_cfg.run_path) / "gradients.bin"
+        if not self.index_cfg.skip_index and index_path.exists():
+            confirm = input(
+                f"File {index_path} already exists. Delete and proceed? (y/n): "
+            )
+            if confirm.lower() != "y":
+                exit()
+            else:
+                shutil.rmtree(index_path.parent)
+
         build(self.index_cfg)
 
 
7 changes: 4 additions & 3 deletions bergson/config.py
@@ -16,9 +16,6 @@ class DataConfig:
     subset: str | None = None
     """Subset of the dataset to use for building the index."""
 
-    streaming: bool = False
-    """Whether to use streaming mode for the dataset."""
-
     prompt_column: str = "text"
     """Column in the dataset that contains the prompts."""
 
@@ -36,6 +33,10 @@ class DataConfig:
     truncation: bool = False
     """Whether to truncate long documents to fit the token budget."""
 
+    data_args: str = ""
+    """Arguments to pass to the dataset constructor in the format
+    arg1=val1,arg2=val2."""
+
 
 @dataclass
 class AttentionConfig:
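The new field replaces the removed streaming flag: any load_dataset keyword can now be supplied through a single string. A minimal sketch of a config that sets it, with hypothetical values, assuming dataset is the only required DataConfig field:

from bergson.config import DataConfig

# Hypothetical dataset name and options; cfg.data_args is later parsed into
# {"streaming": True, "trust_remote_code": True} and forwarded to load_dataset.
cfg = DataConfig(
    dataset="imdb",
    data_args="streaming=True,trust_remote_code=True",
)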
11 changes: 4 additions & 7 deletions bergson/data.py
@@ -20,7 +20,7 @@
 from numpy.typing import DTypeLike
 
 from .config import DataConfig
-from .utils import assert_type
+from .utils import assert_type, simple_parse_args_string
 
 
 def ceildiv(a: int, b: int) -> int:
@@ -182,10 +182,6 @@ def create_index(
     # Ensure the directory exists
     root.mkdir(parents=True, exist_ok=True)
 
-    # Ensure no existing file is overwritten
-    if grad_path.exists():
-        raise FileExistsError(f"File {grad_path} already exists.")
-
     # Allocate (extends file to right size without writing zeros byte-by-byte)
     nbytes = struct_dtype["itemsize"] * num_grads
     with open(grad_path, "wb") as f:
@@ -230,7 +226,7 @@ def load_data_string(
     data_str: str,
     split: str = "train",
     subset: str | None = None,
-    streaming: bool = False,
+    data_args: str = "",
 ) -> Dataset | IterableDataset:
     """Load a dataset from a string identifier or path."""
     if data_str.endswith(".csv"):
@@ -239,7 +235,8 @@
         ds = assert_type(Dataset, Dataset.from_json(data_str))
     else:
         try:
-            ds = load_dataset(data_str, subset, split=split, streaming=streaming)
+            kwargs = simple_parse_args_string(data_args)
+            ds = load_dataset(data_str, subset, split=split, **kwargs)
 
             if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
                 raise NotImplementedError(
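Downstream, load_data_string simply parses the data_args string and forwards the result to load_dataset as keyword arguments. A minimal sketch of the intended call, with a hypothetical dataset and option:

from bergson.data import load_data_string

# Roughly equivalent to load_dataset("allenai/c4", "en", split="train", streaming=True)
ds = load_data_string(
    "allenai/c4",
    split="train",
    subset="en",
    data_args="streaming=True",
)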
30 changes: 30 additions & 0 deletions bergson/utils.py
@@ -58,3 +58,33 @@ def create_projection_matrix(
         raise ValueError(f"Unknown projection type: {projection_type}")
     A /= A.norm(dim=1, keepdim=True)
     return A
+
+
+def handle_arg_string(arg: str):
+    if arg.lower() == "true":
+        return True
+    elif arg.lower() == "false":
+        return False
+    elif arg.isnumeric():
+        return int(arg)
+    try:
+        return float(arg)
+    except ValueError:
+        return arg
+
+
+def simple_parse_args_string(args_string: str) -> dict[str, Any]:
+    """
+    Parses something like
+    args1=val1,arg2=val2
+    into a dictionary.
+    """
+    args_string = args_string.strip()
+    if not args_string:
+        return {}
+    arg_list = [arg for arg in args_string.split(",") if arg]
+    args_dict = {
+        kv[0]: handle_arg_string("=".join(kv[1:]))
+        for kv in [arg.split("=") for arg in arg_list]
+    }
+    return args_dict
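The parser applies light type coercion per value: "true"/"false" become booleans, integer-looking strings become ints, other numbers become floats, and everything else stays a string. A short sketch with hypothetical inputs:

from bergson.utils import simple_parse_args_string

print(simple_parse_args_string("streaming=True,num_proc=4,revision=main"))
# {'streaming': True, 'num_proc': 4, 'revision': 'main'}

print(simple_parse_args_string(""))
# {} (an empty data_args string adds no kwargs)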
34 changes: 5 additions & 29 deletions bergson/worker_utils.py
@@ -4,17 +4,14 @@
 import torch
 from datasets import (
     Dataset,
-    DatasetDict,
     IterableDataset,
-    IterableDatasetDict,
-    load_dataset,
 )
 from peft import PeftConfig, PeftModel, get_peft_model_state_dict
 from torch.distributed.fsdp import fully_shard
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 from bergson.config import DataConfig, IndexConfig
-from bergson.data import tokenize
+from bergson.data import load_data_string, tokenize
 from bergson.gradients import GradientProcessor
 from bergson.utils import assert_type, get_layer_list
 
@@ -84,6 +81,7 @@ def setup_model_and_peft(
     try:
         peft_config = PeftConfig.from_pretrained(cfg.model)
     except ValueError:
+        print(f"PEFT config not found for model {cfg.model}")
         peft_config = None
 
     if peft_config is None:
@@ -156,31 +154,9 @@ def estimate_advantage(ds: Dataset, cfg: DataConfig):
 
 def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset:
     """Handle data loading and preprocessing"""
-
-    data_str = cfg.data.dataset
-    if data_str.endswith(".csv"):
-        ds = assert_type(Dataset, Dataset.from_csv(data_str))
-    elif data_str.endswith(".json") or data_str.endswith(".jsonl"):
-        ds = assert_type(Dataset, Dataset.from_json(data_str))
-    else:
-        try:
-            ds = load_dataset(
-                data_str,
-                cfg.data.subset,
-                split=cfg.data.split,
-                streaming=cfg.data.streaming,
-            )
-
-            if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
-                raise NotImplementedError(
-                    "DatasetDicts and IterableDatasetDicts are not supported."
-                )
-        except ValueError as e:
-            # Automatically use load_from_disk if appropriate
-            if "load_from_disk" in str(e):
-                ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
-            else:
-                raise e
+    ds = load_data_string(
+        cfg.data.dataset, cfg.data.split, cfg.data.subset, cfg.data.data_args
+    )
 
     # In many cases the token_batch_size may be smaller than the max length allowed by
     # the model. If cfg.data.truncation is True, we use the tokenizer to truncate
17 changes: 17 additions & 0 deletions data/generate_facts.py
@@ -0,0 +1,17 @@
+from argparse import ArgumentParser
+
+from datasets import Dataset
+
+from .dataset import fact_generator
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    from datasets import Dataset
+
+    parser = ArgumentParser()
+    parser.add_argument("--num_facts", type=int, default=1000)
+    args = parser.parse_args()
+
+    dataset = fact_generator(args.num_facts)
+    Dataset.from_list(list(dataset)).save_to_disk("data/facts_dataset.hf")
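The script writes its output under data/, which the new data/*.hf gitignore entry keeps out of version control. A small sketch of loading the result back, assuming the script was run from the repository root as python -m data.generate_facts so the relative import resolves:

from datasets import Dataset

# Path written by generate_facts.py above.
ds = Dataset.load_from_disk("data/facts_dataset.hf")
print(len(ds), ds[0])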