
Commit 437b9cd

Add back optimizer-aware gradients
1 parent 0efc830 commit 437b9cd

7 files changed: 259 additions & 114 deletions

bergson/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,14 +11,15 @@
 )
 from .data import load_gradients
 from .gradcheck import FiniteDiff
-from .gradients import GradientCollector, GradientProcessor
+from .gradients import GradientCollector, GradientProcessor, fit_normalizers
 from .query.attributor import Attributor
 from .query.faiss_index import FaissConfig
 from .score.scorer import Scorer

 __all__ = [
     "collect_gradients",
     "load_gradients",
+    "fit_normalizers",
     "Attributor",
     "FaissConfig",
     "FiniteDiff",

bergson/build.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def build_worker(
     )

     model, target_modules = setup_model_and_peft(cfg, rank)
-    processor = create_processor(cfg, rank)
+    processor = create_processor(model, ds, cfg, rank, target_modules)

     attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

bergson/config.py

Lines changed: 3 additions & 0 deletions
@@ -88,6 +88,9 @@ class IndexConfig:
     processor_path: str = ""
     """Path to a precomputed processor."""

+    normalizer: Literal["adafactor", "adam", "none"] = "none"
+    """Type of normalizer to use for the gradients."""
+
     skip_preconditioners: bool = False
     """Whether to skip computing preconditioners for the gradients."""
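A minimal sketch of how the new knob might be set, assuming the remaining IndexConfig fields have usable defaults (only the normalizer field and its allowed values come from this diff):

    from bergson.config import IndexConfig

    # "adafactor" and "adam" request optimizer-aware normalization;
    # "none" (the default) disables it.
    cfg = IndexConfig(normalizer="adafactor")
    assert cfg.normalizer in ("adafactor", "adam", "none")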

bergson/gradients.py

Lines changed: 197 additions & 102 deletions
@@ -1,17 +1,23 @@
 import json
+import random
 from abc import ABC, abstractmethod
 from contextlib import ContextDecorator
 from dataclasses import asdict, astuple, dataclass, field
 from pathlib import Path
 from typing import Callable, Literal, Mapping

 import torch
+import torch.distributed as dist
 import torch.nn as nn
+from datasets import Dataset
 from torch import Tensor
 from torch.utils.hooks import RemovableHandle
+from tqdm.auto import tqdm
+from transformers import PreTrainedModel
 from transformers.pytorch_utils import Conv1D as HFConv1D

 from .config import AttentionConfig
+from .data import pad_and_tensor
 from .math import reshape_to_nearest_square
 from .utils import assert_type, create_projection_matrix

@@ -64,108 +70,6 @@ def state_dict(self) -> dict[str, str | Tensor]:
         }


-@dataclass
-class AdafactorNormalizer(Normalizer):
-    """
-    Row and column sums of second moments of gradients for a matrix-valued parameter.
-    """
-
-    row: Tensor  # shape [O]
-    col: Tensor  # shape [I]
-
-    def __post_init__(self):
-        assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
-        assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
-
-    @torch.compile
-    def normalize_(
-        self,
-        grad: Tensor,
-        eps: float = 1e-30,
-    ) -> Tensor:
-        """
-        Normalize the row and column sums by adding a small epsilon.
-
-        Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
-        recommend 1e-30, but we use 1e-16 for extra numerical stability.
-        """
-        # We follow the Adafactor implementation in the tensor2tensor repo, which is
-        # different from the paper and from the PyTorch implementation. First add eps
-        # to ensure these second moments are sufficiently far from zero. Then we don't
-        # need to worry about numerical stability anywhere else, and we don't need to
-        # materialize the outer product at any point.
-        r, c = self.row.add(eps), self.col.add(eps)
-
-        # This is the denominator for V, the rank-one matrix of second moment estimates:
-        # V = torch.outer(r, c) / denom
-        # V_ij = r_i * c_j / denom
-        # But we want to (implicitly) take the Hadamard product with the elementwise
-        # reciprocal square root of V:
-        # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
-        denom = r.mean()
-
-        # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
-        # by diag(a) and right-multiplying by diag(b). In this case we can represent
-        # the elementwise reciprocal square root of V as ab^T where:
-        # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
-        a = denom.sqrt() * r.rsqrt_()  # shape [O]
-        b = c.rsqrt_()
-
-        # Implicitly do the Hadamard product
-        grad *= a[:, None]  # [N, O] * [O] → [N, O]
-        grad *= b[None, :]
-        return grad
-
-    def to_adam(self) -> "AdamNormalizer":
-        """
-        Convert this Adafactor normalizer to an Adam normalizer by materializing the
-        rank-one second moment matrix.
-        """
-        # Compute the second moment matrix as a square matrix of shape [O, I]
-        # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
-        # add it outside the square root. This could cause infs though if there are
-        # any exactly zero rows or columns, so we should be careful.
-        avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
-
-
-@dataclass
-class AdamNormalizer(Normalizer):
-    """
-    Contains the second moments of the gradients.
-    """
-
-    avg_sq: Tensor
-
-    @torch.compile
-    def normalize_(
-        self,
-        grad: Tensor,
-        eps: float = 1e-8,
-    ) -> Tensor:
-        """Normalize the gradients by the square root of the second moments."""
-        # Adam-style epsilon is added outside the square root
-        denom = self.avg_sq.sqrt()
-        return grad.div_(denom.add_(eps))
-
-    def to_adafactor(self) -> AdafactorNormalizer:
-        """
-        Convert this Adam normalizer to an Adafactor normalizer, minimizing the
-        I-divergence (generalized Kullback-Leibler divergence) between the original
-        and the factored second moments.
-        """
-        # We assume avg_sq is a square matrix of shape [O, I]
-        assert (
-            self.avg_sq.ndim == 2
-        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
-
-        # Compute row and column means
-        return AdafactorNormalizer(
-            row=self.avg_sq.mean(dim=1),  # shape [O]
-            col=self.avg_sq.mean(dim=0),  # shape [I]
-        )
-
-
 @dataclass
 class GradientProcessor:
     """Configuration for processing and compressing gradients."""
@@ -626,3 +530,194 @@ def __exit__(self, exc_type, exc, tb):
         self._bwd_hooks.clear()

         return False
+
+
+@dataclass
+class AdafactorNormalizer(Normalizer):
+    """
+    Row and column sums of second moments of gradients for a matrix-valued parameter.
+    """
+
+    row: Tensor  # shape [O]
+    col: Tensor  # shape [I]
+
+    def __post_init__(self):
+        assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
+        assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+
+    @torch.compile
+    def normalize_(
+        self,
+        grad: Tensor,
+        eps: float = 1e-30,
+    ) -> Tensor:
+        """
+        Normalize the row and column sums by adding a small epsilon.
+
+        Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
+        recommend 1e-30, but we use 1e-16 for extra numerical stability.
+        """
+        # We follow the Adafactor implementation in the tensor2tensor repo, which is
+        # different from the paper and from the PyTorch implementation. First add eps
+        # to ensure these second moments are sufficiently far from zero. Then we don't
+        # need to worry about numerical stability anywhere else, and we don't need to
+        # materialize the outer product at any point.
+        r, c = self.row.add(eps), self.col.add(eps)
+
+        # This is the denominator for V, the rank-one matrix of second moment estimates:
+        # V = torch.outer(r, c) / denom
+        # V_ij = r_i * c_j / denom
+        # But we want to (implicitly) take the Hadamard product with the elementwise
+        # reciprocal square root of V:
+        # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
+        denom = r.mean()
+
+        # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
+        # by diag(a) and right-multiplying by diag(b). In this case we can represent
+        # the elementwise reciprocal square root of V as ab^T where:
+        # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
+        a = denom.sqrt() * r.rsqrt_()  # shape [O]
+        b = c.rsqrt_()
+
+        # Implicitly do the Hadamard product
+        grad *= a[:, None]  # [N, O] * [O] → [N, O]
+        grad *= b[None, :]
+        return grad
+
+    def to_adam(self) -> "AdamNormalizer":
+        """
+        Convert this Adafactor normalizer to an Adam normalizer by materializing the
+        rank-one second moment matrix.
+        """
+        # Compute the second moment matrix as a square matrix of shape [O, I]
+        # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
+        # add it outside the square root. This could cause infs though if there are
+        # any exactly zero rows or columns, so we should be careful.
+        avg_sq = torch.outer(self.row, self.col) / self.row.mean()
+        return AdamNormalizer(avg_sq=avg_sq)
+
+
+@dataclass
+class AdamNormalizer(Normalizer):
+    """
+    Contains the second moments of the gradients.
+    """
+
+    avg_sq: Tensor
+
+    @torch.compile
+    def normalize_(
+        self,
+        grad: Tensor,
+        eps: float = 1e-8,
+    ) -> Tensor:
+        """Normalize the gradients by the square root of the second moments."""
+        # Adam-style epsilon is added outside the square root
+        denom = self.avg_sq.sqrt()
+        return grad.div_(denom.add_(eps))
+
+    def to_adafactor(self) -> AdafactorNormalizer:
+        """
+        Convert this Adam normalizer to an Adafactor normalizer, minimizing the
+        I-divergence (generalized Kullback-Leibler divergence) between the original
+        and the factored second moments.
+        """
+        # We assume avg_sq is a square matrix of shape [O, I]
+        assert (
+            self.avg_sq.ndim == 2
+        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+
+        # Compute row and column means
+        return AdafactorNormalizer(
+            row=self.avg_sq.mean(dim=1),  # shape [O]
+            col=self.avg_sq.mean(dim=0),  # shape [I]
+        )
+
+
+def fit_normalizers(
+    model: PreTrainedModel,
+    data: Dataset,
+    batches: list[list[int]],
+    *,
+    kind: Literal["adafactor", "adam"] = "adafactor",
+    target_modules: set[str] | None = None,
+) -> dict[str, Normalizer]:
+    """
+    Estimate the second moments of the model's gradients using a subset of the dataset.
+    """
+    normalizers: dict[str, Normalizer] = {}
+    rank = dist.get_rank() if dist.is_initialized() else 0
+
+    # Just to make the pbar more accurate
+    rng = random.Random(0)
+    rng.shuffle(batches)
+
+    def adafactor_update(name: str, g: torch.Tensor):
+        # We follow the tensor2tensor implementation of Adafactor, which
+        # takes the mean rather than summing over the rows and columns.
+        # row: mean over columns, shape [O]
+        sq = g.float().square_().sum(0)
+        row_acc = sq.mean(dim=1)
+        # col: mean over rows, shape [I]
+        col_acc = sq.mean(dim=0)
+
+        if (normalizer := normalizers.get(name)) is None:
+            # initialize accumulators at zero
+            normalizers[name] = normalizer = AdafactorNormalizer(
+                torch.zeros_like(row_acc),
+                torch.zeros_like(col_acc),
+            )
+        else:
+            assert isinstance(normalizer, AdafactorNormalizer)
+
+        # in-place accumulate
+        normalizer.row.add_(row_acc)
+        normalizer.col.add_(col_acc)
+
+    def adam_update(name: str, g: torch.Tensor):
+        sq = g.square_().float().sum(0)
+
+        # initialize accumulators at zero
+        if (normalizer := normalizers.get(name)) is None:
+            normalizers[name] = normalizer = AdamNormalizer(torch.zeros_like(sq))
+        else:
+            assert isinstance(normalizer, AdamNormalizer)
+
+        # in-place accumulate
+        normalizer.avg_sq.add_(sq)
+
+    callback = adafactor_update if kind == "adafactor" else adam_update
+
+    for indices in tqdm(batches, disable=rank != 0, desc="Estimating normalizers"):
+        batch = data[indices]
+
+        with GradientCollector(
+            model.base_model,
+            closure=callback,
+            target_modules=target_modules,
+        ):
+            x, y = pad_and_tensor(
+                batch["input_ids"],  # type: ignore
+                labels=batch.get("labels", None),  # type: ignore
+                device=model.device,
+            )
+            model(x, labels=y).loss.backward()
+            model.zero_grad()
+
+    # Divide by the number of documents processed and average across all ranks
+    for normalizer in normalizers.values():
+        if isinstance(normalizer, AdamNormalizer):
+            normalizer.avg_sq.div_(len(data))
+
+            if dist.is_initialized():
+                dist.all_reduce(normalizer.avg_sq, op=dist.ReduceOp.AVG)
+
+        elif isinstance(normalizer, AdafactorNormalizer):
+            normalizer.row.div_(len(data))
+            normalizer.col.div_(len(data))
+
+            if dist.is_initialized():
+                dist.all_reduce(normalizer.row, op=dist.ReduceOp.AVG)
+                dist.all_reduce(normalizer.col, op=dist.ReduceOp.AVG)
+
+    return normalizers
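For orientation, here is a hypothetical end-to-end sketch of calling the re-added fit_normalizers. Only the signature and the expectation that each dataset row carries input_ids (and optionally labels) come from the diff above; the tiny checkpoint name, the toy dataset, and the batch index lists are illustrative assumptions, not part of the commit:

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from bergson import fit_normalizers

    # Hypothetical small causal LM and toy dataset; any model/data would do.
    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    texts = ["hello world", "optimizer-aware gradients", "row and column second moments"]
    ids = [tok(t)["input_ids"] for t in texts]
    data = Dataset.from_dict({"input_ids": ids, "labels": ids})

    # Each batch is a list of row indices into `data`, matching the loop above.
    batches = [[0, 1], [2]]

    normalizers = fit_normalizers(
        model,
        data,
        batches,
        kind="adafactor",     # or "adam" for full second-moment matrices
        target_modules=None,  # assumed to mean "all supported modules"
    )
    print({name: type(n).__name__ for name, n in normalizers.items()})

The returned mapping is keyed by module name, with one AdafactorNormalizer per module here (AdamNormalizer when kind="adam").

Separately, the comments in AdafactorNormalizer.normalize_ rely on a rank-one identity: dividing elementwise by sqrt(V), where V = outer(r, c) / mean(r), equals scaling rows by a = sqrt(mean(r)) * rsqrt(r) and columns by b = rsqrt(c). A standalone check of that identity (not part of the commit, epsilon terms omitted):

    import torch

    torch.manual_seed(0)
    r = torch.rand(4) + 0.1   # stand-in row second moments, shape [O]
    c = torch.rand(3) + 0.1   # stand-in column second moments, shape [I]
    g = torch.randn(4, 3)     # stand-in gradient for a weight of shape [O, I]

    # Implicit form used in normalize_: scale rows by `a`, columns by `b`
    a = r.mean().sqrt() * r.rsqrt()
    b = c.rsqrt()
    implicit = g * a[:, None] * b[None, :]

    # Explicit form: materialize the rank-one estimate V and divide by its sqrt
    V = torch.outer(r, c) / r.mean()
    explicit = g / V.sqrt()

    assert torch.allclose(implicit, explicit, atol=1e-6)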

bergson/reduce.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def reduce_worker(
     )

     model, target_modules = setup_model_and_peft(index_cfg, rank)
-    processor = create_processor(index_cfg, rank)
+    processor = create_processor(model, ds, index_cfg, rank, target_modules)

     attention_cfgs = {
         module: index_cfg.attention for module in index_cfg.split_attention_modules

bergson/score/score.py

Lines changed: 2 additions & 4 deletions
@@ -4,13 +4,12 @@
 from dataclasses import asdict
 from datetime import timedelta
 from pathlib import Path
-from typing import Literal, cast
+from typing import Literal

 import torch
 import torch.distributed as dist
 from datasets import Dataset, IterableDataset
 from tqdm.auto import tqdm
-from transformers import PreTrainedModel

 from bergson.collection import collect_gradients
 from bergson.config import IndexConfig, ScoreConfig
@@ -250,8 +249,7 @@ def score_worker(
     )

     model, target_modules = setup_model_and_peft(index_cfg, rank)
-    model = cast(PreTrainedModel, model)
-    processor = create_processor(index_cfg, rank)
+    processor = create_processor(model, ds, index_cfg, rank, target_modules)

     attention_cfgs = {
         module: index_cfg.attention for module in index_cfg.split_attention_modules
