3 changes: 2 additions & 1 deletion bergson/__init__.py
@@ -11,14 +11,15 @@
)
from .data import load_gradients
from .gradcheck import FiniteDiff
from .gradients import GradientCollector, GradientProcessor
from .gradients import GradientCollector, GradientProcessor, fit_normalizers
from .query.attributor import Attributor
from .query.faiss_index import FaissConfig
from .score.scorer import Scorer

__all__ = [
"collect_gradients",
"load_gradients",
"fit_normalizers",
"Attributor",
"FaissConfig",
"FiniteDiff",
10 changes: 2 additions & 8 deletions bergson/__main__.py
@@ -86,16 +86,10 @@ def execute(self):
self.command.execute()


def get_parser():
"""Get the argument parser. Used for documentation generation."""
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
parser.add_arguments(Main, dest="prog")
return parser


def main(args: Optional[list[str]] = None):
"""Parse CLI arguments and dispatch to the selected subcommand."""
parser = get_parser()
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
parser.add_arguments(Main, dest="prog")
prog: Main = parser.parse_args(args=args).prog
prog.execute()

2 changes: 1 addition & 1 deletion bergson/build.py
@@ -56,7 +56,7 @@ def build_worker(
)

model, target_modules = setup_model_and_peft(cfg, rank)
processor = create_processor(cfg, rank)
processor = create_processor(model, ds, cfg, rank, target_modules)

attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

3 changes: 3 additions & 0 deletions bergson/config.py
@@ -88,6 +88,9 @@ class IndexConfig:
processor_path: str = ""
"""Path to a precomputed processor."""

normalizer: Literal["adafactor", "adam", "none"] = "none"
"""Type of normalizer to use for the gradients."""

skip_preconditioners: bool = False
"""Whether to skip computing preconditioners for the gradients."""

299 changes: 197 additions & 102 deletions bergson/gradients.py
@@ -1,17 +1,23 @@
import json
import random
from abc import ABC, abstractmethod
from contextlib import ContextDecorator
from dataclasses import asdict, astuple, dataclass, field
from pathlib import Path
from typing import Callable, Literal, Mapping

import torch
import torch.distributed as dist
import torch.nn as nn
from datasets import Dataset
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from tqdm.auto import tqdm
from transformers import PreTrainedModel
from transformers.pytorch_utils import Conv1D as HFConv1D

from .config import AttentionConfig
from .data import pad_and_tensor
from .math import reshape_to_nearest_square
from .utils import assert_type, create_projection_matrix

@@ -64,108 +70,6 @@ def state_dict(self) -> dict[str, str | Tensor]:
}


@dataclass
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.
"""

row: Tensor # shape [O]
col: Tensor # shape [I]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-30,
) -> Tensor:
"""
Normalize the gradient in place using the factored second moments, adding a
small epsilon to the row and column sums first.

Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper,
which recommends 1e-30; that is also the default used here.
"""
# We follow the Adafactor implementation in the tensor2tensor repo, which is
# different from the paper and from the PyTorch implementation. First add eps
# to ensure these second moments are sufficiently far from zero. Then we don't
# need to worry about numerical stability anywhere else, and we don't need to
# materialize the outer product at any point.
r, c = self.row.add(eps), self.col.add(eps)

# This is the denominator for V, the rank-one matrix of second moment estimates:
# V = torch.outer(r, c) / denom
# V_ij = r_i * c_j / denom
# But we want to (implicitly) take the Hadamard product with the elementwise
# reciprocal square root of V:
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
denom = r.mean()

# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
# by diag(a) and right-multiplying by diag(b). In this case we can represent
# the elementwise reciprocal square root of V as ab^T where:
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
a = denom.sqrt() * r.rsqrt_() # shape [O]
b = c.rsqrt_()

# Implicitly do the Hadamard product
grad *= a[:, None] # [N, O] * [O] → [N, O]
grad *= b[None, :]
return grad

def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.
"""

avg_sq: Tensor

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-8,
) -> Tensor:
"""Normalize the gradients by the square root of the second moments."""
# Adam-style epsilon is added outside the square root
denom = self.avg_sq.sqrt()
return grad.div_(denom.add_(eps))

def to_adafactor(self) -> AdafactorNormalizer:
"""
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
self.avg_sq.ndim == 2
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"

# Compute row and column means
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
)


@dataclass
class GradientProcessor:
"""Configuration for processing and compressing gradients."""
@@ -626,3 +530,194 @@ def __exit__(self, exc_type, exc, tb):
self._bwd_hooks.clear()

return False


@dataclass
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.
"""

row: Tensor # shape [O]
col: Tensor # shape [I]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-30,
) -> Tensor:
"""
Normalize the gradient in place using the factored second moments, adding a
small epsilon to the row and column sums first.

Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper,
which recommends 1e-30; that is also the default used here.
"""
# We follow the Adafactor implementation in the tensor2tensor repo, which is
# different from the paper and from the PyTorch implementation. First add eps
# to ensure these second moments are sufficiently far from zero. Then we don't
# need to worry about numerical stability anywhere else, and we don't need to
# materialize the outer product at any point.
r, c = self.row.add(eps), self.col.add(eps)

# This is the denominator for V, the rank-one matrix of second moment estimates:
# V = torch.outer(r, c) / denom
# V_ij = r_i * c_j / denom
# But we want to (implicitly) take the Hadamard product with the elementwise
# reciprocal square root of V:
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
denom = r.mean()

# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
# by diag(a) and right-multiplying by diag(b). In this case we can represent
# the elementwise reciprocal square root of V as ab^T where:
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
a = denom.sqrt() * r.rsqrt_() # shape [O]
b = c.rsqrt_()

# Implicitly do the Hadamard product
grad *= a[:, None] # [N, O] * [O] → [N, O]
grad *= b[None, :]
return grad

def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)
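
As a side note, the implicit rank-one Hadamard trick in `normalize_` above can be
checked against a materialized version with a small standalone sketch; the shapes,
epsilon, and random inputs below are arbitrary stand-ins, not values from the diff:

import torch

O, I = 4, 6
row, col = torch.rand(O), torch.rand(I)  # stand-ins for accumulated row/col second moments
grad = torch.randn(O, I)                 # stand-in for one parameter gradient
eps = 1e-30

# Explicit version: materialize the rank-one second-moment estimate V and take
# its elementwise reciprocal square root.
r, c = row + eps, col + eps
V = torch.outer(r, c) / r.mean()
explicit = grad * V.rsqrt()

# Implicit version, as in AdafactorNormalizer.normalize_: never materialize V.
a = r.mean().sqrt() * r.rsqrt()
b = c.rsqrt()
implicit = grad * a[:, None] * b[None, :]

torch.testing.assert_close(explicit, implicit)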


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.
"""

avg_sq: Tensor

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-8,
) -> Tensor:
"""Normalize the gradients by the square root of the second moments."""
# Adam-style epsilon is added outside the square root
denom = self.avg_sq.sqrt()
return grad.div_(denom.add_(eps))

def to_adafactor(self) -> AdafactorNormalizer:
"""
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
self.avg_sq.ndim == 2
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"

# Compute row and column means
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
)
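
For reference, the I-divergence statement in `to_adafactor` is the factored
second-moment result from the Adafactor paper: for a nonnegative matrix
$V \in \mathbb{R}_{\ge 0}^{O \times I}$, the rank-one matrix minimizing
$d(V, \hat{V}) = \sum_{ij} \bigl( V_{ij} \log \tfrac{V_{ij}}{\hat{V}_{ij}} - V_{ij} + \hat{V}_{ij} \bigr)$
is

$$\hat{V} = \frac{(V \mathbf{1})(\mathbf{1}^\top V)}{\mathbf{1}^\top V \mathbf{1}},$$

and taking row/column means (as the code does) rather than sums yields the same
$\hat{V}$, since the $1/I$, $1/O$, and $1/(OI)$ factors cancel between the
numerator and the denominator.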


def fit_normalizers(
model: PreTrainedModel,
data: Dataset,
batches: list[list[int]],
*,
kind: Literal["adafactor", "adam"] = "adafactor",
target_modules: set[str] | None = None,
) -> dict[str, Normalizer]:
"""
Estimate the second moments of the model's gradients using a subset of the dataset.
"""
normalizers: dict[str, Normalizer] = {}
rank = dist.get_rank() if dist.is_initialized() else 0

# Just to make the pbar more accurate
rng = random.Random(0)
rng.shuffle(batches)

def adafactor_update(name: str, g: torch.Tensor):
# We follow the tensor2tensor implementation of Adafactor, which
# takes the mean rather than summing over the rows and columns.
sq = g.float().square_().sum(0)
# row: mean over columns, shape [O]
row_acc = sq.mean(dim=1)
# col: mean over rows, shape [I]
col_acc = sq.mean(dim=0)

if (normalizer := normalizers.get(name)) is None:
# initialize accumulators at zero
normalizers[name] = normalizer = AdafactorNormalizer(
torch.zeros_like(row_acc),
torch.zeros_like(col_acc),
)
else:
assert isinstance(normalizer, AdafactorNormalizer)

# in‐place accumulate
normalizer.row.add_(row_acc)
normalizer.col.add_(col_acc)

def adam_update(name: str, g: torch.Tensor):
sq = g.float().square_().sum(0)

# initialize accumulators at zero
if (normalizer := normalizers.get(name)) is None:
normalizers[name] = normalizer = AdamNormalizer(torch.zeros_like(sq))
else:
assert isinstance(normalizer, AdamNormalizer)

# in‐place accumulate
normalizer.avg_sq.add_(sq)

callback = adafactor_update if kind == "adafactor" else adam_update

for indices in tqdm(batches, disable=rank != 0, desc="Estimating normalizers"):
batch = data[indices]

with GradientCollector(
model.base_model,
closure=callback,
target_modules=target_modules,
):
x, y = pad_and_tensor(
batch["input_ids"], # type: ignore
labels=batch.get("labels", None), # type: ignore
device=model.device,
)
model(x, labels=y).loss.backward()
model.zero_grad()

# Divide by the number of documents processed and average across all ranks
for normalizer in normalizers.values():
if isinstance(normalizer, AdamNormalizer):
normalizer.avg_sq.div_(len(data))

if dist.is_initialized():
dist.all_reduce(normalizer.avg_sq, op=dist.ReduceOp.AVG)

elif isinstance(normalizer, AdafactorNormalizer):
normalizer.row.div_(len(data))
normalizer.col.div_(len(data))

if dist.is_initialized():
dist.all_reduce(normalizer.row, op=dist.ReduceOp.AVG)
dist.all_reduce(normalizer.col, op=dist.ReduceOp.AVG)

return normalizers
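
A minimal usage sketch for the new `fit_normalizers` entry point follows. The model
name, dataset, and batching are hypothetical, and it assumes a tokenized dataset with
`input_ids` and `labels` columns in the form `pad_and_tensor` expects:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from bergson import fit_normalizers

model = AutoModelForCausalLM.from_pretrained("gpt2")
tok = AutoTokenizer.from_pretrained("gpt2")

# Tokenize a small slice of a public corpus and reuse the input ids as labels.
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:64]")
ds = ds.map(
    lambda ex: tok(ex["text"], truncation=True, max_length=128),
    remove_columns=ds.column_names,
)
ds = ds.filter(lambda ex: len(ex["input_ids"]) > 1)
ds = ds.map(lambda ex: {"labels": ex["input_ids"]})

# Contiguous batches of 8 documents; any partition of the indices works here.
batches = [list(range(i, min(i + 8, len(ds)))) for i in range(0, len(ds), 8)]

normalizers = fit_normalizers(model, ds, batches, kind="adafactor")
print({name: type(n).__name__ for name, n in normalizers.items()})
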
2 changes: 1 addition & 1 deletion bergson/reduce.py
@@ -59,7 +59,7 @@ def reduce_worker(
)

model, target_modules = setup_model_and_peft(index_cfg, rank)
processor = create_processor(index_cfg, rank)
processor = create_processor(model, ds, index_cfg, rank, target_modules)

attention_cfgs = {
module: index_cfg.attention for module in index_cfg.split_attention_modules