From 4f01357b2f72686101e12978526abf31858cbf2d Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 17 Nov 2025 22:38:46 +0000 Subject: [PATCH 1/7] Accept arbitrary dataset args --- bergson/config.py | 7 ++++--- bergson/data.py | 7 ++++--- bergson/utils.py | 30 +++++++++++++++++++++++++++++ bergson/worker_utils.py | 2 +- examples/trainer_grad_collection.py | 4 +++- 5 files changed, 42 insertions(+), 8 deletions(-) diff --git a/bergson/config.py b/bergson/config.py index 497e56f2..1d562163 100644 --- a/bergson/config.py +++ b/bergson/config.py @@ -16,9 +16,6 @@ class DataConfig: subset: str | None = None """Subset of the dataset to use for building the index.""" - streaming: bool = False - """Whether to use streaming mode for the dataset.""" - prompt_column: str = "text" """Column in the dataset that contains the prompts.""" @@ -36,6 +33,10 @@ class DataConfig: truncation: bool = False """Whether to truncate long documents to fit the token budget.""" + data_args: str = "" + """Arguments to pass to the dataset constructor in the format + arg1=val1,arg2=val2.""" + @dataclass class AttentionConfig: diff --git a/bergson/data.py b/bergson/data.py index d5074f8f..af86399e 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -20,7 +20,7 @@ from numpy.typing import DTypeLike from .config import DataConfig -from .utils import assert_type +from .utils import assert_type, simple_parse_args_string def ceildiv(a: int, b: int) -> int: @@ -230,7 +230,7 @@ def load_data_string( data_str: str, split: str = "train", subset: str | None = None, - streaming: bool = False, + data_args: str = "", ) -> Dataset | IterableDataset: """Load a dataset from a string identifier or path.""" if data_str.endswith(".csv"): @@ -239,7 +239,8 @@ def load_data_string( ds = assert_type(Dataset, Dataset.from_json(data_str)) else: try: - ds = load_dataset(data_str, subset, split=split, streaming=streaming) + kwargs = simple_parse_args_string(data_args) + ds = load_dataset(data_str, subset, split=split, **kwargs) if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): raise NotImplementedError( diff --git a/bergson/utils.py b/bergson/utils.py index 7529e4bf..ac4ffa40 100644 --- a/bergson/utils.py +++ b/bergson/utils.py @@ -58,3 +58,33 @@ def create_projection_matrix( raise ValueError(f"Unknown projection type: {projection_type}") A /= A.norm(dim=1, keepdim=True) return A + + +def handle_arg_string(arg: str): + if arg.lower() == "true": + return True + elif arg.lower() == "false": + return False + elif arg.isnumeric(): + return int(arg) + try: + return float(arg) + except ValueError: + return arg + + +def simple_parse_args_string(args_string: str) -> dict[str, Any]: + """ + Parses something like + args1=val1,arg2=val2 + into a dictionary. + """ + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = { + kv[0]: handle_arg_string("=".join(kv[1:])) + for kv in [arg.split("=") for arg in arg_list] + } + return args_dict diff --git a/bergson/worker_utils.py b/bergson/worker_utils.py index 49848d0d..f58e60cf 100644 --- a/bergson/worker_utils.py +++ b/bergson/worker_utils.py @@ -168,7 +168,7 @@ def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset: data_str, cfg.data.subset, split=cfg.data.split, - streaming=cfg.data.streaming, + data_args=cfg.data.data_args, ) if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): diff --git a/examples/trainer_grad_collection.py b/examples/trainer_grad_collection.py index 990c3ec1..acdefd51 100644 --- a/examples/trainer_grad_collection.py +++ b/examples/trainer_grad_collection.py @@ -118,7 +118,9 @@ def main(args: IndexConfig): conversation_column=args.data.conversation_column, ) dataset = load_data_string( - args.data.dataset, args.data.split, streaming=args.data.streaming + args.data.dataset, + args.data.split, + data_args=args.data.data_args, ) dataset = dataset.map( tokenize, From 1e93b7631b2708d002ef2b2fce453a1a9f4acef5 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 19 Nov 2025 06:09:44 +0000 Subject: [PATCH 2/7] Add reword data script --- examples/semantic.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 examples/semantic.py diff --git a/examples/semantic.py b/examples/semantic.py new file mode 100644 index 00000000..add6ff54 --- /dev/null +++ b/examples/semantic.py @@ -0,0 +1,80 @@ +# python -m data.generate_facts --num_facts 1000 + +import torch +from datasets import Dataset, load_from_disk +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def reword(dataset, model_name: str, prompt_template: str): + device = "cuda:0" + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + ) + model.eval() + + def generate(text: str): + inputs = tokenizer(text, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=128, + do_sample=True, + temperature=0.7, + top_p=0.8, + min_p=0.0, + ) + + generated = tokenizer.decode(outputs[0], skip_special_tokens=True) + # Remove the prompt prefix. + return generated[len(text) :].strip() + + # Process each item in dataset + new_items = [] + for item in dataset: + fact = item["fact"] + + prompt = prompt_template.format(fact=fact) + output = generate(prompt) + + item["reworded"] = output + print(output) + + new_items.append(item) + + return Dataset.from_list(new_items) + + +def main(): + dataset = load_from_disk("data/facts_dataset.hf") + # model_name = "Meta-Llama/Meta-Llama-3-8B-Instruct" + model_name = "Qwen/Qwen3-8B-Base" + + prompt = ( + "Reword the following fact in a Shakespearean style, adding " + "flair and poetry.\n " + "Do not include other text in your response, " + "just the contents of the reworded fact.\n " + "Fact: {fact}\n " + "Your rewrite: (remember, no notes or explanations):" + ) + + reword(dataset, model_name, prompt).to_disk("data/facts_dataset_shakespeare.hf") + + prompt = ( + "Reword the following fact like it's coming from a pirate. Be creative!\n " + "Do not include any other text in your response, " + "just the contents of the reworded fact.\n " + "Fact: {fact}\n " + "Your rewrite: (remember, no notes or explanations):" + ) + + reword(dataset, model_name, prompt).to_disk("data/facts_dataset_pirate.hf") + + +if __name__ == "__main__": + main() From b7a8ffe7f25da9c65c9ae4101bf424d1a20b3882 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 19 Nov 2025 06:10:44 +0000 Subject: [PATCH 3/7] Use ds util --- bergson/__init__.py | 3 ++- bergson/worker_utils.py | 34 +++++----------------------------- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/bergson/__init__.py b/bergson/__init__.py index 34725170..8eafb8be 100644 --- a/bergson/__init__.py +++ b/bergson/__init__.py @@ -9,7 +9,7 @@ ReduceConfig, ScoreConfig, ) -from .data import load_gradients +from .data import load_gradient_dataset, load_gradients from .gradcheck import FiniteDiff from .gradients import GradientCollector, GradientProcessor from .query.attributor import Attributor @@ -19,6 +19,7 @@ __all__ = [ "collect_gradients", "load_gradients", + "load_gradient_dataset", "Attributor", "FaissConfig", "FiniteDiff", diff --git a/bergson/worker_utils.py b/bergson/worker_utils.py index f58e60cf..2f31396a 100644 --- a/bergson/worker_utils.py +++ b/bergson/worker_utils.py @@ -4,17 +4,14 @@ import torch from datasets import ( Dataset, - DatasetDict, IterableDataset, - IterableDatasetDict, - load_dataset, ) from peft import PeftConfig, PeftModel, get_peft_model_state_dict from torch.distributed.fsdp import fully_shard from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from bergson.config import DataConfig, IndexConfig -from bergson.data import tokenize +from bergson.data import load_data_string, tokenize from bergson.gradients import GradientProcessor from bergson.utils import assert_type, get_layer_list @@ -84,6 +81,7 @@ def setup_model_and_peft( try: peft_config = PeftConfig.from_pretrained(cfg.model) except ValueError: + print(f"PEFT config not found for model {cfg.model}") peft_config = None if peft_config is None: @@ -156,31 +154,9 @@ def estimate_advantage(ds: Dataset, cfg: DataConfig): def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset: """Handle data loading and preprocessing""" - - data_str = cfg.data.dataset - if data_str.endswith(".csv"): - ds = assert_type(Dataset, Dataset.from_csv(data_str)) - elif data_str.endswith(".json") or data_str.endswith(".jsonl"): - ds = assert_type(Dataset, Dataset.from_json(data_str)) - else: - try: - ds = load_dataset( - data_str, - cfg.data.subset, - split=cfg.data.split, - data_args=cfg.data.data_args, - ) - - if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): - raise NotImplementedError( - "DatasetDicts and IterableDatasetDicts are not supported." - ) - except ValueError as e: - # Automatically use load_from_disk if appropriate - if "load_from_disk" in str(e): - ds = Dataset.load_from_disk(data_str, keep_in_memory=False) - else: - raise e + ds = load_data_string( + cfg.data.dataset, cfg.data.split, cfg.data.subset, cfg.data.data_args + ) # In many cases the token_batch_size may be smaller than the max length allowed by # the model. If cfg.data.truncation is True, we use the tokenizer to truncate From 43a858ee0fcc7f7aa76250e616085bbcbec82d3b Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 21 Nov 2025 05:21:47 +0000 Subject: [PATCH 4/7] Add semantic experiment finetuning --- data/generate_facts.py | 17 +++ examples/semantic.py | 220 +++++++++++++++++++++++------ examples/train_lora.py | 309 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 506 insertions(+), 40 deletions(-) create mode 100644 data/generate_facts.py create mode 100644 examples/train_lora.py diff --git a/data/generate_facts.py b/data/generate_facts.py new file mode 100644 index 00000000..32e0e3a9 --- /dev/null +++ b/data/generate_facts.py @@ -0,0 +1,17 @@ +from argparse import ArgumentParser + +from datasets import Dataset + +from .dataset import fact_generator + +if __name__ == "__main__": + from argparse import ArgumentParser + + from datasets import Dataset + + parser = ArgumentParser() + parser.add_argument("--num_facts", type=int, default=1000) + args = parser.parse_args() + + dataset = fact_generator(args.num_facts) + Dataset.from_list(list(dataset)).save_to_disk("data/facts_dataset.hf") diff --git a/examples/semantic.py b/examples/semantic.py index add6ff54..da135ad3 100644 --- a/examples/semantic.py +++ b/examples/semantic.py @@ -1,14 +1,20 @@ -# python -m data.generate_facts --num_facts 1000 +from pathlib import Path +import subprocess import torch -from datasets import Dataset, load_from_disk +from tqdm import tqdm +from datasets import Dataset, load_from_disk, concatenate_datasets from transformers import AutoModelForCausalLM, AutoTokenizer -def reword(dataset, model_name: str, prompt_template: str): - device = "cuda:0" +def reword(dataset, model_name: str, prompt_template: str, batch_size: int = 8): + device = "cuda:3" tokenizer = AutoTokenizer.from_pretrained(model_name) + + # REQUIRED for batched generation with Llama/Qwen/Mistral tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, @@ -16,64 +22,198 @@ def reword(dataset, model_name: str, prompt_template: str): ) model.eval() - def generate(text: str): - inputs = tokenizer(text, return_tensors="pt").to(model.device) + new_facts = [] + new_reworded = [] + + # Convert dataset to list for easy slicing + # (Assuming the dataset is small enough to fit in RAM, which 1000 items is) + data_list = list(dataset) + + print(f"Starting generation with batch size: {batch_size}...") + + for i in tqdm(range(0, len(data_list), batch_size)): + # 1. Prepare the batch + batch_items = data_list[i : i + batch_size] + prompts = [prompt_template.format(fact=item["fact"]) for item in batch_items] + # 2. Tokenize (Batch mode) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + input_len = inputs.input_ids.shape[1] + + # 3. Generate with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=128, + pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.7, top_p=0.8, min_p=0.0, ) - generated = tokenizer.decode(outputs[0], skip_special_tokens=True) - # Remove the prompt prefix. - return generated[len(text) :].strip() + # 4. Slice output to remove prompt (all at once) + # With left-padding, the prompt is always the first 'input_len' tokens + generated_tokens = outputs[:, input_len:] + + # 5. Decode batch + decoded_batch = tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) - # Process each item in dataset - new_items = [] - for item in dataset: - fact = item["fact"] + # 6. Store results + for item, output_text in zip(batch_items, decoded_batch): + new_facts.append(item["fact"]) + new_reworded.append(output_text.strip()) - prompt = prompt_template.format(fact=fact) - output = generate(prompt) + # Reconstruct dataset + return Dataset.from_dict({"fact": new_facts, "reworded": new_reworded}) - item["reworded"] = output - print(output) - new_items.append(item) +def create_data(): + dataset = load_from_disk("data/facts_dataset.hf") - return Dataset.from_list(new_items) + for model_name in ["Qwen/Qwen3-8B-Base", "Meta-Llama/Meta-Llama-3-8B"]: + + # 1. Shakespeare + prompt_shake = ( + "Reword the following fact in a Shakespearean style, adding flair and poetry.\n" + "Do not include other text in your response, just the contents of the reworded fact.\n" + "Fact: {fact}\n" + "Your rewrite:" + ) + + ds_shake = reword(dataset, model_name, prompt_shake, batch_size=8) + ds_shake.save_to_disk( + f"data/facts_dataset_shakespeare-{model_name.split('/')[-1]}.hf" + ) + print("Shakespearean processing done.") + + # 2. Pirate + prompt_pirate = ( + "Reword the following fact like it's coming from a pirate. Be creative!\n" + "Do not include any other text in your response, just the contents of the reworded fact.\n" + "Fact: {fact}\n" + "Your rewrite:" + ) + + ds_pirate = reword(dataset, model_name, prompt_pirate, batch_size=8) + ds_pirate.save_to_disk( + f"data/facts_dataset_pirate-{model_name.split('/')[-1]}.hf" + ) + print("Pirate processing done.") + + +def create_index(dataset_name, analysis_model_name): + run_path = Path(f"runs/{dataset_name}") + cmd = [ + "bergson", + "build", + str(run_path), + "--model", + analysis_model_name, + "--dataset", + dataset_name, + "--drop_columns", + "False", + "--prompt_column", + "fact", + "--completion_column", + "reworded", + "--fsdp", + "--projection_dim", + "128", + "--skip_preconditioners", + ] + + print(" ".join(cmd)) + if not run_path.exists(): + result = subprocess.run(cmd, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + +def finetune(dataset_path, analysis_model_name, finetuned_model_path): + cmd = [ + "torchrun", + "--nproc_per_node=8", + "--master_port=29500", + "--standalone", + "examples/train_lora.py", + # "examples/finetune_sem.py", + "--dataset_name", + dataset_path, + "--finetuned_model_path", + finetuned_model_path, + "--model_name", + analysis_model_name, + "--prompt_column", + "fact", + "--completion_column", + "reworded", + ] + print(" ".join(cmd)) + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, # "Pipe" the output to us + stderr=subprocess.STDOUT, # Merge errors into the standard output stream + text=True, # Decode bytes to string automatically + bufsize=1 # Line buffering (updates every line) + ) as process: + # Iterate over the output line by line as it comes in + for line in process.stdout: # type: ignore + print(line.strip()) + + result = subprocess.run(cmd, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) def main(): - dataset = load_from_disk("data/facts_dataset.hf") - # model_name = "Meta-Llama/Meta-Llama-3-8B-Instruct" - model_name = "Qwen/Qwen3-8B-Base" - - prompt = ( - "Reword the following fact in a Shakespearean style, adding " - "flair and poetry.\n " - "Do not include other text in your response, " - "just the contents of the reworded fact.\n " - "Fact: {fact}\n " - "Your rewrite: (remember, no notes or explanations):" - ) + # create_data() + dataset_paths = [ + "data/facts_dataset_shakespeare-Qwen3-8B-Base.hf", + "data/facts_dataset_pirate-Qwen3-8B-Base.hf", + "data/facts_dataset_shakespeare-Meta-Llama-3-8B.hf", + "data/facts_dataset_pirate-Meta-Llama-3-8B.hf", + ] - reword(dataset, model_name, prompt).to_disk("data/facts_dataset_shakespeare.hf") + original = load_from_disk("data/facts_dataset.hf") - prompt = ( - "Reword the following fact like it's coming from a pirate. Be creative!\n " - "Do not include any other text in your response, " - "just the contents of the reworded fact.\n " - "Fact: {fact}\n " - "Your rewrite: (remember, no notes or explanations):" - ) + merged_datasets = [] + + for path in dataset_paths: + ds = load_from_disk(path) + + # Add back any dropped columns from original + for col in original.column_names: + if col not in ds.column_names: + # Align ds length with original by matching on "fact" + # Create a mapping from fact → row + orig_map = {row["fact"]: row for row in original} + + # Build list for restored column + restored_col = [orig_map[row["fact"]][col] for row in ds] + + ds = ds.add_column(col, restored_col) + + merged_datasets.append(ds) + + final_dataset = concatenate_datasets(merged_datasets) + final_dataset = final_dataset.shuffle(seed=42) + + final_dataset_path = "data/facts_dataset_reworded.hf" + final_dataset.save_to_disk(final_dataset_path) + print(f"Merged dataset saved to: {final_dataset_path}") + + analysis_model_name = "Qwen/Qwen3-4B" + + finetuned_model_path = f"finetuned-{final_dataset_path.split('/')[-1].split('.')[0]}-{analysis_model_name}" + # Finetune model on dataset + finetune(final_dataset_path, analysis_model_name, finetuned_model_path) - reword(dataset, model_name, prompt).to_disk("data/facts_dataset_pirate.hf") + # Build index with finetuned model + create_index(final_dataset_path, finetuned_model_path) if __name__ == "__main__": diff --git a/examples/train_lora.py b/examples/train_lora.py new file mode 100644 index 00000000..84d460d8 --- /dev/null +++ b/examples/train_lora.py @@ -0,0 +1,309 @@ +import json +import os +import sys + +import backoff +import torch +import torch.distributed as dist +from datasets import Dataset +from peft import LoraConfig, prepare_model_for_kbit_training +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from trl import SFTTrainer, SFTConfig + +from torch.utils.data import SequentialSampler +from datasets import load_dataset + +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, Field, field_validator +from bergson.config import IndexConfig, DataConfig +from bergson.worker_utils import setup_data_pipeline + + +class TrainingConfig(BaseModel): + class Config: + extra = "forbid" # Prevent extra fields not defined in the model + + # Required model and data paths + model: str = Field(..., description="Hugging Face model ID") + dataset: str = Field(..., description="Dataset") + split: str = Field(..., description="Split") + + prompt_column: str = Field("prompt", description="Prompt column") + completion_column: str = Field("completion", description="Completion column") + + # Training type configuration + loss: Literal["dpo", "orpo", "sft"] = Field( + ..., description="Loss function / training type" + ) + + # Output model + finetuned_model_id: Optional[str] = Field( + None, description="File ID of the finetuned model" + ) + + # Model configuration + max_seq_length: int = Field( + 2048, description="Maximum sequence length for training" + ) + load_in_4bit: bool = Field( + False, description="Whether to load model in 4-bit quantization" + ) + + # PEFT configuration + is_peft: bool = Field(True, description="Whether to use PEFT for training") + target_modules: Optional[List[str]] = Field( + default=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + description="Target modules for LoRA", + ) + lora_bias: Literal["all", "none"] = Field( + "none", description="Value for FastLanguageModel.get_peft_model(bias=?)" + ) + + # LoRA specific arguments + r: int = Field(16, description="LoRA attention dimension") + lora_alpha: int = Field(16, description="LoRA alpha parameter") + lora_dropout: float = Field(0.0, description="LoRA dropout rate") + use_rslora: bool = Field(True, description="Whether to use RSLoRA") + merge_before_push: bool = Field( + True, + description="Whether to merge model before pushing to Hub. Only merged models can be used as parent models for further finetunes. Only supported for bf16 models.", + ) + push_to_private: bool = Field(True, description="Whether to push to private Hub") + + # Training hyperparameters + epochs: int = Field(1, description="Number of training epochs") + max_steps: int = Field(-1, description="Maximum number of training steps") + per_device_train_batch_size: int = Field( + 2, description="Training batch size per device" + ) + gradient_accumulation_steps: int = Field( + 8, description="Number of gradient accumulation steps" + ) + warmup_steps: int = Field(5, description="Number of warmup steps") + learning_rate: Union[float, str] = Field( + 1e-4, description="Learning rate or string expression" + ) + logging_steps: int = Field(1, description="Number of steps between logging") + optim: str = Field("adamw_8bit", description="Optimizer to use for training") + weight_decay: float = Field(0.01, description="Weight decay rate") + lr_scheduler_type: str = Field("linear", description="Learning rate scheduler type") + seed: Optional[int] = Field(None, description="Random seed for reproducibility") + save_steps: int = Field(5000, description="Save checkpoint every X steps") + output_dir: str = Field( + "./tmp", description="Output directory for training checkpoints" + ) + + @field_validator("finetuned_model_id") + def validate_finetuned_model_id(cls, v): + # if v and model_exists(v): + # raise ValueError(f"Model {v} already exists") + if len(v.split("/")) != 2: + raise ValueError("Model ID must be in the format 'user/model'") + org, model = v.split("/") + if org in ["datasets", "models", "unsloth", "None"]: + raise ValueError( + f"You have set org={org}, but it must be an org you have access to" + ) + return v + + @field_validator("learning_rate", mode="before") + def validate_learning_rate(cls, v): + if isinstance(v, float) and v <= 0: + raise ValueError("Learning rate must be positive") + return v + + @field_validator("lora_dropout") + def validate_dropout(cls, v): + if not 0 <= v <= 1: + raise ValueError("Dropout rate must be between 0 and 1") + return v + + @field_validator("lr_scheduler_type") + def validate_scheduler(cls, v): + allowed_schedulers = [ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + ] + if v not in allowed_schedulers: + raise ValueError(f"Scheduler must be one of {allowed_schedulers}") + return v + + +# def process(df, prompt_column: str = "prompt", completion_column: str = "completion"): +# def format_chat_data(example): +# old_example = example +# example["prompt"] = [{"role": "user", "content": old_example[prompt_column]}] +# example["completion"] = [ +# {"role": "assistant", "content": old_example[completion_column]} +# ] +# return example + +# df = df.map(format_chat_data) +# return df + + +class NoShuffleSFTTrainer(SFTTrainer): + def _get_train_sampler(self, dataset): # <-- Add 'dataset' parameter + sampler = SequentialSampler(dataset) + + return sampler + + +def train(training_cfg: TrainingConfig, dataset: Dataset): + """Prepare lora model, call training function, and push to hub""" + + if rank := os.environ.get("LOCAL_RANK"): + rank = int(rank) + dist.init_process_group("nccl", device_id=torch.device(f"cuda:{rank}")) + else: + rank = 0 + + print("Creating new LoRA adapter") + target_modules = training_cfg.target_modules + model = AutoModelForCausalLM.from_pretrained( + training_cfg.model, + device_map={"": f"cuda:{rank}"}, + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ), + ) + tokenizer = AutoTokenizer.from_pretrained( + training_cfg.model, token=os.environ.get("HF_TOKEN"), max_length=2048 + ) + # Prepare for k-bit training + model = prepare_model_for_kbit_training(model) + + # 3. Define LoRA config + peft_config = LoraConfig( + r=training_cfg.r, + lora_alpha=training_cfg.lora_alpha, + target_modules=target_modules, + lora_dropout=training_cfg.lora_dropout, + use_rslora=training_cfg.use_rslora, + bias=training_cfg.lora_bias, + task_type="CAUSAL_LM", + ) + + # dataset = process( + # dataset, + # prompt_column=training_cfg.prompt_column, + # completion_column=training_cfg.completion_column, + # ) + if training_cfg.seed is not None: + dataset = dataset.shuffle(seed=training_cfg.seed) + + trainer = NoShuffleSFTTrainer( + model=model, + train_dataset=dataset, + args=SFTConfig( + completion_only_loss=True, + ddp_find_unused_parameters=False, + fp16=True, + gradient_accumulation_steps=training_cfg.gradient_accumulation_steps, + learning_rate=training_cfg.learning_rate, + logging_steps=1, + lr_scheduler_type=training_cfg.lr_scheduler_type, + max_length=training_cfg.max_seq_length, + max_steps=training_cfg.max_steps, + num_train_epochs=training_cfg.epochs, + label_names=["labels"], + optim=training_cfg.optim, + output_dir=training_cfg.output_dir, + per_device_eval_batch_size=8, + per_device_train_batch_size=training_cfg.per_device_train_batch_size, + report_to=None, + save_steps=training_cfg.save_steps, + warmup_steps=training_cfg.warmup_steps, + weight_decay=training_cfg.weight_decay, + ), + peft_config=peft_config, + callbacks=[], + ) + trainer.train() + + if rank == 0: + if training_cfg.finetuned_model_id is not None: + push_model(training_cfg, training_cfg.finetuned_model_id, model, tokenizer) + + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +@backoff.on_exception(backoff.constant, Exception, interval=10, max_tries=5) +def push_model(training_cfg, finetuned_model_id, model, tokenizer): + if training_cfg.merge_before_push: + model.push_to_hub_merged( + finetuned_model_id, + tokenizer, + save_method="merged_16bit", + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + else: + model.push_to_hub( + finetuned_model_id, + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + tokenizer.push_to_hub( + finetuned_model_id, + token=os.environ["HF_TOKEN"], + private=training_cfg.push_to_private, + ) + + +def main(): + from argparse import ArgumentParser + + parser = ArgumentParser() + # model_name = "Qwen/Qwen2.5-7B" + parser.add_argument("--finetuned_model_path", type=str, default="finetuned-model") + parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B") + parser.add_argument("--dataset_name", type=str, default="HuggingFaceH4/MATH-500") + parser.add_argument("--split", type=str, default="test") + parser.add_argument("--prompt_column", type=str, default="prompt") + parser.add_argument("--completion_column", type=str, default="completion") + + args = parser.parse_args() + + training_config = TrainingConfig( # type: ignore + finetuned_model_id=args.finetuned_model_path, # type: ignore + model=args.model_name, # type: ignore + dataset=args.dataset_name,# type: ignore + split=args.split, # type: ignore + loss="sft", # type: ignore + prompt_column=args.prompt_column, # type: ignore + completion_column=args.completion_column, # type: ignore + ) # type: ignore + + dataset = setup_data_pipeline( + IndexConfig( + run_path=f"runs/{args.finetuned_model_path}", + model=args.model_name, + data=DataConfig( + dataset=args.dataset_name, + split=args.split, + prompt_column=args.prompt_column, + completion_column=args.completion_column, + ), + ) + ) + train(training_config, dataset) + + +if __name__ == "__main__": + main() From 66a80de3c110defbe87bc4c6fec2c9267957cec3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 05:00:21 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/semantic.py | 18 +++++++++--------- examples/train_lora.py | 35 +++++++++++++++-------------------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/examples/semantic.py b/examples/semantic.py index da135ad3..52a23078 100644 --- a/examples/semantic.py +++ b/examples/semantic.py @@ -1,9 +1,9 @@ -from pathlib import Path import subprocess +from pathlib import Path import torch +from datasets import Dataset, concatenate_datasets, load_from_disk from tqdm import tqdm -from datasets import Dataset, load_from_disk, concatenate_datasets from transformers import AutoModelForCausalLM, AutoTokenizer @@ -154,16 +154,16 @@ def finetune(dataset_path, analysis_model_name, finetuned_model_path): ] print(" ".join(cmd)) with subprocess.Popen( - cmd, + cmd, stdout=subprocess.PIPE, # "Pipe" the output to us - stderr=subprocess.STDOUT, # Merge errors into the standard output stream - text=True, # Decode bytes to string automatically - bufsize=1 # Line buffering (updates every line) + stderr=subprocess.STDOUT, # Merge errors into the standard output stream + text=True, # Decode bytes to string automatically + bufsize=1, # Line buffering (updates every line) ) as process: # Iterate over the output line by line as it comes in - for line in process.stdout: # type: ignore - print(line.strip()) - + for line in process.stdout: # type: ignore + print(line.strip()) + result = subprocess.run(cmd, capture_output=True, text=True) print(result.stdout) print(result.stderr) diff --git a/examples/train_lora.py b/examples/train_lora.py index 84d460d8..43b9d316 100644 --- a/examples/train_lora.py +++ b/examples/train_lora.py @@ -1,22 +1,17 @@ -import json import os -import sys +from typing import List, Literal, Optional, Union import backoff import torch import torch.distributed as dist from datasets import Dataset from peft import LoraConfig, prepare_model_for_kbit_training -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from trl import SFTTrainer, SFTConfig - +from pydantic import BaseModel, Field, field_validator from torch.utils.data import SequentialSampler -from datasets import load_dataset - -from typing import List, Literal, Optional, Union +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from trl import SFTConfig, SFTTrainer -from pydantic import BaseModel, Field, field_validator -from bergson.config import IndexConfig, DataConfig +from bergson.config import DataConfig, IndexConfig from bergson.worker_utils import setup_data_pipeline @@ -277,18 +272,18 @@ def main(): parser.add_argument("--split", type=str, default="test") parser.add_argument("--prompt_column", type=str, default="prompt") parser.add_argument("--completion_column", type=str, default="completion") - + args = parser.parse_args() - training_config = TrainingConfig( # type: ignore - finetuned_model_id=args.finetuned_model_path, # type: ignore - model=args.model_name, # type: ignore - dataset=args.dataset_name,# type: ignore - split=args.split, # type: ignore - loss="sft", # type: ignore - prompt_column=args.prompt_column, # type: ignore - completion_column=args.completion_column, # type: ignore - ) # type: ignore + training_config = TrainingConfig( # type: ignore + finetuned_model_id=args.finetuned_model_path, # type: ignore + model=args.model_name, # type: ignore + dataset=args.dataset_name, # type: ignore + split=args.split, # type: ignore + loss="sft", # type: ignore + prompt_column=args.prompt_column, # type: ignore + completion_column=args.completion_column, # type: ignore + ) # type: ignore dataset = setup_data_pipeline( IndexConfig( From 234d0f231e9a32da0e5e3bb57367d3277ba5048a Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 25 Nov 2025 22:03:46 +0000 Subject: [PATCH 6/7] save progress --- .gitignore | 1 + examples/semantic.py | 86 +++++++++++++++------------ examples/semantics_experiment.py | 68 +++++++++++++++++++++ examples/slurm/data_parallel_score.sh | 2 +- examples/train_lora.py | 40 ++++++------- 5 files changed, 137 insertions(+), 60 deletions(-) create mode 100644 examples/semantics_experiment.py diff --git a/.gitignore b/.gitignore index 16c1f3ca..a71ff5d4 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,4 @@ prototype/ influence_results/ .idea/ uv.lock +data/*.hf diff --git a/examples/semantic.py b/examples/semantic.py index da135ad3..6969bcc3 100644 --- a/examples/semantic.py +++ b/examples/semantic.py @@ -1,9 +1,9 @@ -from pathlib import Path import subprocess +from pathlib import Path import torch +from datasets import Dataset, concatenate_datasets, load_from_disk from tqdm import tqdm -from datasets import Dataset, load_from_disk, concatenate_datasets from transformers import AutoModelForCausalLM, AutoTokenizer @@ -77,8 +77,10 @@ def create_data(): # 1. Shakespeare prompt_shake = ( - "Reword the following fact in a Shakespearean style, adding flair and poetry.\n" - "Do not include other text in your response, just the contents of the reworded fact.\n" + "Reword the following fact in a Shakespearean style, adding flair and " + "poetry.\n" + "Do not include other text in your response, just the contents of the " + "reworded fact.\n" "Fact: {fact}\n" "Your rewrite:" ) @@ -92,7 +94,8 @@ def create_data(): # 2. Pirate prompt_pirate = ( "Reword the following fact like it's coming from a pirate. Be creative!\n" - "Do not include any other text in your response, just the contents of the reworded fact.\n" + "Do not include any other text in your response, just the contents of the " + "reworded fact.\n" "Fact: {fact}\n" "Your rewrite:" ) @@ -109,7 +112,7 @@ def create_index(dataset_name, analysis_model_name): cmd = [ "bergson", "build", - str(run_path), + str(run_path / "index"), "--model", analysis_model_name, "--dataset", @@ -122,11 +125,12 @@ def create_index(dataset_name, analysis_model_name): "reworded", "--fsdp", "--projection_dim", - "128", + "16", "--skip_preconditioners", ] print(" ".join(cmd)) + exit() if not run_path.exists(): result = subprocess.run(cmd, capture_output=True, text=True) print(result.stdout) @@ -154,16 +158,16 @@ def finetune(dataset_path, analysis_model_name, finetuned_model_path): ] print(" ".join(cmd)) with subprocess.Popen( - cmd, + cmd, stdout=subprocess.PIPE, # "Pipe" the output to us - stderr=subprocess.STDOUT, # Merge errors into the standard output stream - text=True, # Decode bytes to string automatically - bufsize=1 # Line buffering (updates every line) + stderr=subprocess.STDOUT, # Merge errors into the standard output stream + text=True, # Decode bytes to string automatically + bufsize=1, # Line buffering (updates every line) ) as process: # Iterate over the output line by line as it comes in - for line in process.stdout: # type: ignore - print(line.strip()) - + for line in process.stdout: # type: ignore + print(line.strip()) + result = subprocess.run(cmd, capture_output=True, text=True) print(result.stdout) print(result.stderr) @@ -178,42 +182,48 @@ def main(): "data/facts_dataset_pirate-Meta-Llama-3-8B.hf", ] - original = load_from_disk("data/facts_dataset.hf") + final_dataset_path = "data/facts_dataset_reworded.hf" + # analysis_model_name = "Qwen/Qwen3-4B" + # finetuned_model_path = ( + # f"finetuned-{final_dataset_path.split('/')[-1].split('.')[0]}" + # f"-{analysis_model_name}" + # ) - merged_datasets = [] + if not Path(final_dataset_path).exists(): + original = load_from_disk("data/facts_dataset.hf") - for path in dataset_paths: - ds = load_from_disk(path) + merged_datasets = [] - # Add back any dropped columns from original - for col in original.column_names: - if col not in ds.column_names: - # Align ds length with original by matching on "fact" - # Create a mapping from fact → row - orig_map = {row["fact"]: row for row in original} + for path in dataset_paths: + ds = load_from_disk(path) - # Build list for restored column - restored_col = [orig_map[row["fact"]][col] for row in ds] + # Add back any dropped columns from original + for col in original.column_names: + if col not in ds.column_names: + # Align ds length with original by matching on "fact" + # Create a mapping from fact → row + orig_map = {row["fact"]: row for row in original} - ds = ds.add_column(col, restored_col) + # Build list for restored column + restored_col = [orig_map[row["fact"]][col] for row in ds] - merged_datasets.append(ds) + ds = ds.add_column(col, restored_col) - final_dataset = concatenate_datasets(merged_datasets) - final_dataset = final_dataset.shuffle(seed=42) + merged_datasets.append(ds) - final_dataset_path = "data/facts_dataset_reworded.hf" - final_dataset.save_to_disk(final_dataset_path) - print(f"Merged dataset saved to: {final_dataset_path}") + final_dataset = concatenate_datasets(merged_datasets) + final_dataset = final_dataset.shuffle(seed=42) - analysis_model_name = "Qwen/Qwen3-4B" + final_dataset.save_to_disk(final_dataset_path) + print(f"Merged dataset saved to: {final_dataset_path}") - finetuned_model_path = f"finetuned-{final_dataset_path.split('/')[-1].split('.')[0]}-{analysis_model_name}" - # Finetune model on dataset - finetune(final_dataset_path, analysis_model_name, finetuned_model_path) + # if not Path(finetuned_model_path).exists(): + # # Finetune model on dataset + # finetune(final_dataset_path, analysis_model_name, finetuned_model_path) # Build index with finetuned model - create_index(final_dataset_path, finetuned_model_path) + tmp_path = "tmp/checkpoint-282" + create_index(final_dataset_path, tmp_path) if __name__ == "__main__": diff --git a/examples/semantics_experiment.py b/examples/semantics_experiment.py new file mode 100644 index 00000000..b4b30db1 --- /dev/null +++ b/examples/semantics_experiment.py @@ -0,0 +1,68 @@ +import subprocess +from pathlib import Path + +import torch +from datasets import load_dataset + +from bergson import load_gradient_dataset + +dataset = load_dataset("HuggingFaceH4/MATH-500", split="test") + +# Build Bergson index +run_path = Path("runs/math-500/gemma") +cmd = [ + "bergson", + "build", + str(run_path), + "--model", + "google/gemma-3-4b-it", + "--dataset", + "HuggingFaceH4/MATH-500", + "--drop_columns", + "False", + "--split", + "test", + "--prompt_column", + "problem", + "--completion_column", + "answer", +] +print(" ".join(cmd)) + +if not run_path.exists(): + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + +# Check whether items with the same subject value have a greater cosine similarity score +# Than items from dissimilar subjects + +gradient_ds = load_gradient_dataset(run_path, structured=False) + +subjects = gradient_ds["subject"] + +# Compute cosine similarity between all items' gradients +gradients = torch.tensor(gradient_ds["gradients"], device="cuda") +gradients /= gradients.norm(dim=1, keepdim=True) +similarities = gradients @ gradients.T + + +# Check whether items with the same subject value have a greater cosine similarity score +# Than items from dissimilar subjects +intra_subject_similarities = [] +inter_subject_similarities = [] + +for i in range(len(gradients)): + for j in range(i + 1, len(gradients)): + if subjects[i] == subjects[j]: + intra_subject_similarities.append(similarities[i, j]) + else: + inter_subject_similarities.append(similarities[i, j]) + + +mean_intra_subject_similarity = torch.mean(torch.tensor(intra_subject_similarities)) +mean_inter_subject_similarity = torch.mean(torch.tensor(inter_subject_similarities)) +print(f"Intra-subject similarity mean: {mean_intra_subject_similarity}") +print(f"Inter-subject similarity mean: {mean_inter_subject_similarity}") + +breakpoint() diff --git a/examples/slurm/data_parallel_score.sh b/examples/slurm/data_parallel_score.sh index 4254d804..9b8a47f8 100644 --- a/examples/slurm/data_parallel_score.sh +++ b/examples/slurm/data_parallel_score.sh @@ -16,7 +16,7 @@ hf auth login --token NUM_NODES=64 RUN_NAME="bergson_score" -TOTAL_EXAMPLES=$(cat dataset_size.txt) +TOTAL_EXAMPLES=100_000_000 EXAMPLES_PER_NODE=$((TOTAL_EXAMPLES / NUM_NODES)) # Export variables for the worker script diff --git a/examples/train_lora.py b/examples/train_lora.py index 84d460d8..343c3594 100644 --- a/examples/train_lora.py +++ b/examples/train_lora.py @@ -1,22 +1,17 @@ -import json import os -import sys +from typing import List, Literal, Optional, Union import backoff import torch import torch.distributed as dist from datasets import Dataset from peft import LoraConfig, prepare_model_for_kbit_training -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from trl import SFTTrainer, SFTConfig - +from pydantic import BaseModel, Field, field_validator from torch.utils.data import SequentialSampler -from datasets import load_dataset - -from typing import List, Literal, Optional, Union +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from trl import SFTConfig, SFTTrainer -from pydantic import BaseModel, Field, field_validator -from bergson.config import IndexConfig, DataConfig +from bergson.config import DataConfig, IndexConfig from bergson.worker_utils import setup_data_pipeline @@ -75,7 +70,9 @@ class Config: use_rslora: bool = Field(True, description="Whether to use RSLoRA") merge_before_push: bool = Field( True, - description="Whether to merge model before pushing to Hub. Only merged models can be used as parent models for further finetunes. Only supported for bf16 models.", + # description="Whether to merge model before pushing to Hub. Only merged models + # can be used as parent models for further finetunes. Only supported for + # bf16 models.", ) push_to_private: bool = Field(True, description="Whether to push to private Hub") @@ -277,18 +274,19 @@ def main(): parser.add_argument("--split", type=str, default="test") parser.add_argument("--prompt_column", type=str, default="prompt") parser.add_argument("--completion_column", type=str, default="completion") - + args = parser.parse_args() - training_config = TrainingConfig( # type: ignore - finetuned_model_id=args.finetuned_model_path, # type: ignore - model=args.model_name, # type: ignore - dataset=args.dataset_name,# type: ignore - split=args.split, # type: ignore - loss="sft", # type: ignore - prompt_column=args.prompt_column, # type: ignore - completion_column=args.completion_column, # type: ignore - ) # type: ignore + training_config = TrainingConfig( # type: ignore + finetuned_model_id=args.finetuned_model_path, # type: ignore + model=args.model_name, # type: ignore + dataset=args.dataset_name, # type: ignore + split=args.split, # type: ignore + loss="sft", # type: ignore + prompt_column=args.prompt_column, # type: ignore + completion_column=args.completion_column, # type: ignore + merge_before_push=False, + ) # type: ignore dataset = setup_data_pipeline( IndexConfig( From 056d4dca159afd5017df367940a1dd8a7e9aa3ef Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Thu, 27 Nov 2025 00:55:00 +0000 Subject: [PATCH 7/7] Verify run path upfront --- bergson/__main__.py | 13 +++++++++++++ bergson/data.py | 4 ---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/bergson/__main__.py b/bergson/__main__.py index 7b4c55cc..af26dde7 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -1,4 +1,6 @@ +import shutil from dataclasses import dataclass +from pathlib import Path from typing import Optional, Union from simple_parsing import ArgumentParser, ConflictResolution @@ -21,6 +23,17 @@ def execute(self): if self.index_cfg.skip_index and self.index_cfg.skip_preconditioners: raise ValueError("Either skip_index or skip_preconditioners must be False") + # Require confirmation from the user to proceed if overwriting an existing index + index_path = Path(self.index_cfg.run_path) / "gradients.bin" + if not self.index_cfg.skip_index and index_path.exists(): + confirm = input( + f"File {index_path} already exists. Delete and proceed? (y/n): " + ) + if confirm.lower() != "y": + exit() + else: + shutil.rmtree(index_path.parent) + build(self.index_cfg) diff --git a/bergson/data.py b/bergson/data.py index af86399e..5aa51a79 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -182,10 +182,6 @@ def create_index( # Ensure the directory exists root.mkdir(parents=True, exist_ok=True) - # Ensure no existing file is overwritten - if grad_path.exists(): - raise FileExistsError(f"File {grad_path} already exists.") - # Allocate (extends file to right size without writing zeros byte-by-byte) nbytes = struct_dtype["itemsize"] * num_grads with open(grad_path, "wb") as f: