5 changes: 1 addition & 4 deletions .github/workflows/build.yml
@@ -4,9 +4,6 @@ on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
@@ -29,10 +26,10 @@ jobs:
run: pip wheel --no-deps -w dist .
release:
needs: build
if: github.repository == 'EleutherAI/bergson' && github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
permissions:
contents: write
id-token: write
if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
runs-on: ubuntu-latest
concurrency: release
steps:
62 changes: 62 additions & 0 deletions .github/workflows/unit_tests.yml
@@ -0,0 +1,62 @@
name: Unit Tests
env:
UV_SYSTEM_PYTHON: 1
on:
push:
branches:
- 'main'
pull_request:
branches:
- 'main'
workflow_dispatch:

jobs:
linter:
name: Linters
runs-on: ubuntu-latest
timeout-minutes: 5

steps:
- uses: actions/checkout@v5
- name: "Set up Python"
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Checkout
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
- name: Install dependencies
run: uv pip install -e ".[dev]" --torch-backend=auto
- name: Pre-Commit
uses: pre-commit/action@v3.0.1
- name: Type Checking
uses: jakebailey/pyright-action@v2
- name: Cleanup
run: uv cache prune --ci

testcpu:
name: CPU Tests
runs-on: ubuntu-latest
strategy:
fail-fast: true
matrix:
python-version: [ "3.10", "3.11", "3.12" ]

steps:
- uses: actions/checkout@v5

- name: Checkout
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ matrix.python-version }}
enable-cache: true

- name: Install dependencies
run: uv sync --extra dev

- name: Run tests
run: uv run pytest tests --showlocals -s -vv -n=auto

- name: Cleanup
run: uv cache prune --ci
24 changes: 11 additions & 13 deletions .pre-commit-config.yaml
@@ -1,18 +1,16 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: \.txt$
- repo: https://github.com/psf/black
rev: 25.1.0
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: \.txt$
- id: no-commit-to-branch
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.5
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.11.9'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-check
args: [ --fix ]
- id: ruff-format
3 changes: 1 addition & 2 deletions bergson/__main__.py
@@ -57,8 +57,7 @@ def execute(self):

if self.index_cfg.projection_dim != 0:
print(
"Warning: projection_dim is not 0. "
"Compressed gradients will be scored."
"Warning: projection_dim is not 0. Compressed gradients will be scored."
)

score_dataset(self.index_cfg, self.score_cfg)
84 changes: 63 additions & 21 deletions bergson/gradients.py
@@ -68,14 +68,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.

Args:
row: Row statistics [O]
col: Column statistics [I]
bias_avg_sq: Optional second moments for bias [O]
"""

row: Tensor # shape [O]
col: Tensor # shape [I]
bias_avg_sq: Tensor | None = None # shape [O]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
if self.bias_avg_sq is not None:
assert self.bias_avg_sq.ndim == 1, (
f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
)

@torch.compile
def normalize_(
Expand Down Expand Up @@ -120,22 +130,44 @@ def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.

Preserves bias_avg_sq if present.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)
return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)

def scale_by_lr(self, lr: float | Tensor) -> "AdafactorNormalizer":
"""Scale normalizer by learning rate.

Factorized dimensions (row, col) are scaled by sqrt(lr).
Bias is scaled by lr.
"""
lr_sqrt = lr**0.5
return AdafactorNormalizer(
row=self.row * lr_sqrt, # shape [O]
col=self.col * lr_sqrt, # shape [I]
bias_avg_sq=self.bias_avg_sq * lr
if self.bias_avg_sq is not None
else None, # shape [O]
)


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.

Args:
avg_sq: Second moments for weights [O, I]
bias_avg_sq: Optional second moments for bias [O]
"""

avg_sq: Tensor
bias_avg_sq: Tensor | None = None

@torch.compile
def normalize_(
Expand All @@ -153,16 +185,19 @@ def to_adafactor(self) -> AdafactorNormalizer:
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.

Preserves bias_avg_sq if present.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
self.avg_sq.ndim == 2
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
assert self.avg_sq.ndim == 2, (
f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
)

# Compute row and column means
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
bias_avg_sq=self.bias_avg_sq, # Preserve bias second moments
)
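
A quick sanity-check of the conversion paths above, written as a standalone sketch rather than anything taken from the PR's test suite. The shapes and values are made up; it only assumes that the two classes are importable from `bergson.gradients` as in this diff.

```python
import torch

from bergson.gradients import AdafactorNormalizer, AdamNormalizer

O, I = 4, 3  # hypothetical out/in dimensions
adam = AdamNormalizer(
    avg_sq=torch.rand(O, I),    # second moments of the weight gradient
    bias_avg_sq=torch.rand(O),  # second moments of the bias gradient
)

# Adam -> Adafactor: row/col means factorize avg_sq; bias moments ride along unchanged.
fac = adam.to_adafactor()
assert fac.row.shape == (O,) and fac.col.shape == (I,)
assert fac.bias_avg_sq is adam.bias_avg_sq

# Adafactor -> Adam: materialize the rank-one estimate outer(row, col) / row.mean(),
# again carrying bias_avg_sq through.
dense = fac.to_adam()
assert dense.avg_sq.shape == (O, I)

# scale_by_lr: each factor picks up sqrt(lr), so outer(row, col) scales by lr,
# consistent with the bias second moments being scaled by lr directly.
scaled = fac.scale_by_lr(1e-3)
assert scaled.bias_avg_sq is not None
```

Keeping `bias_avg_sq` optional means layers without a bias need no special-casing in either conversion direction.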


@@ -334,8 +369,8 @@ class GradientCollector(ContextDecorator):
of the parameters, which are expected to be precomputed and passed in.

We assume that the input to `model` is of shape `[N, S, I]`, where `N` is the
batch size, `S` is the sequence length, and `I` is the input dimension. We take the
mean over the sequence length to obtain a single gradient per sequence.
batch size, `S` is the sequence length, and `I` is the input dimension. We
sum over the sequence dimension to obtain a single gradient per sequence.
"""

model: nn.Module
@@ -551,8 +586,22 @@ def _process_grad(self, module: nn.Module, _, grad_out):
i = getattr(module, LayerAdapter.in_attr(module))
o = getattr(module, LayerAdapter.out_attr(module))

# Pre-scale G by the Adafactor row statistics
# Handle bias gradients if needed (must be computed from raw G)
norm = self.processor.normalizers.get(name)
bias_grad = None
if include_bias:
# Compute bias from raw G (before any normalization)
bias_grad = G.sum(dim=1) # [N, S, O] -> [N, O]

# Normalize bias with appropriate second moments
if (
isinstance(norm, (AdamNormalizer, AdafactorNormalizer))
and hasattr(norm, "bias_avg_sq")
and norm.bias_avg_sq is not None
):
bias_grad = bias_grad / norm.bias_avg_sq.sqrt().add_(1e-8)

# Pre-scale G by the Adafactor row statistics (for weight gradients)
if isinstance(norm, AdafactorNormalizer):
# Compare to the normalize_ method in AdafactorNormalizer
r = norm.row.add(1e-30)
@@ -563,26 +612,19 @@ def _process_grad(self, module: nn.Module, _, grad_out):
# If we are using AdamNormalizer, or including bias gradients
# we need to materialize the full gradient and then project
if isinstance(norm, AdamNormalizer) or include_bias:

P = G.mT @ I # [N, O, S] @ [N, S, I] → [N, O, I]
if include_bias:
# Append the bias gradient to the input
if isinstance(norm, AdamNormalizer):
# Normalize the gradients using the second moment matrix
P /= norm.avg_sq.sqrt().add_(1e-8)

if include_bias and bias_grad is not None:
# Append pre-computed and normalized bias gradient
P = torch.cat(
[
P,
G.sum(dim=(0, 1))
.unsqueeze(0)
.unsqueeze(2)
.expand(P.shape[0], -1, 1),
],
[P, bias_grad.unsqueeze(2)], # [N, O, 1]
dim=2,
)
i += 1

if isinstance(norm, AdamNormalizer):
# Normalize the gradients using the second moment matrix
P /= norm.avg_sq.sqrt().add_(1e-8)

if self.processor.reshape_to_square:
P = reshape_to_nearest_square(P)
o, i = P.shape[-2:]
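
The shape bookkeeping in `_process_grad` is easier to see in isolation. Below is a minimal sketch of the new bias path, not the hook itself; tensor names and sizes are illustrative. The key fix is that the per-sequence bias gradient is now `G.sum(dim=1)` (one row per example) instead of the old `G.sum(dim=(0, 1))`, which also summed over the batch before being broadcast back.

```python
import torch

N, S, O, I = 2, 5, 4, 3      # batch, sequence length, out_features, in_features
G = torch.randn(N, S, O)     # per-token gradients w.r.t. the layer output
X = torch.randn(N, S, I)     # per-token layer inputs

P = G.mT @ X                 # [N, O, S] @ [N, S, I] -> [N, O, I] weight grads per sequence
bias_grad = G.sum(dim=1)     # [N, O] bias grads per sequence, taken from raw (pre-normalization) G

# Adam-style normalization of the bias, mirroring the hook (epsilon added outside the sqrt):
bias_avg_sq = torch.rand(O)
bias_grad = bias_grad / bias_avg_sq.sqrt().add_(1e-8)

# Append the bias gradient as one extra "input" column, giving in_features + 1.
P = torch.cat([P, bias_grad.unsqueeze(2)], dim=2)
assert P.shape == (N, O, I + 1)
```

Appending the bias as an extra input column is what the `i += 1` bookkeeping in the hook accounts for.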
76 changes: 59 additions & 17 deletions bergson/huggingface.py
@@ -239,13 +239,12 @@ def on_step_end(
**kwargs,
):
self.on_substep_end(args, state, control)
print("Step end")

# Record training order if enabled
if self.order is not None:
assert (
self.batch_indices is not None
), "Batch indices are not available for training order tracking"
assert self.batch_indices is not None, (
"Batch indices are not available for training order tracking"
)

epoch = int(state.epoch or 0)
global_step = state.global_step
@@ -279,32 +278,75 @@

# Read normalizers off of the optimizer state. We need to figure out
# what type of optimizer this is first.
# Collect references to both weight and bias second moments per layer
layer_second_moments: dict[str, dict[str, Tensor]] = {}

for group in optimizer.param_groups:
lr_sqrt = group["lr"] ** 0.5
group_lr = group["lr"]

for param in group["params"]:
name = param_to_name[param].removesuffix(".weight")
if name not in self.collector.target_info:
param_name = param_to_name[param]

# Extract layer name (remove .weight or .bias suffix)
if param_name.endswith(".weight"):
param_type = "weight"
layer_name = param_name.removesuffix(".weight")
elif param_name.endswith(".bias"):
param_type = "bias"
layer_name = param_name.removesuffix(".bias")
else:
continue

if layer_name not in self.collector.target_info:
continue

p_state = optimizer.state[param]

# Initialize layer dict if needed, storing this group's learning rate
if layer_name not in layer_second_moments:
layer_second_moments[layer_name] = {"lr": group_lr}

# Adam-like optimizer
if (eas := p_state.get("exp_avg_sq")) is not None:
norm = AdamNormalizer(eas).to_adafactor()

layer_second_moments[layer_name][param_type] = eas
# Adafactor-like optimizer
elif (vr := p_state.get("exp_avg_sq_row")) is not None:
vc = p_state.get("exp_avg_sq_col")
norm = AdafactorNormalizer(vr, vc)
else:
continue
if param_type == "weight":
# Factorized second moments for weights
layer_second_moments[layer_name]["row"] = vr
layer_second_moments[layer_name]["col"] = vc
elif param_type == "bias":
# Adafactor stores bias as regular exp_avg_sq
bias_eas = p_state.get("exp_avg_sq")
if bias_eas is not None:
layer_second_moments[layer_name]["bias"] = bias_eas

# Build normalizers from collected second moments
for layer_name, moments in layer_second_moments.items():
lr = moments["lr"]

# Adam-like: has weight exp_avg_sq
if "weight" in moments:
weight_eas = moments["weight"]
bias_eas = moments.get("bias") # May be None

# Create Adam normalizer with optional bias, then convert to Adafactor
# TODO: always convert to adafactor?
norm = (
AdamNormalizer(weight_eas, bias_eas).to_adafactor().scale_by_lr(lr)
)

# Adafactor-like: has row/col
elif "row" in moments and "col" in moments:
bias_eas = moments.get("bias") # May be present
norm = AdafactorNormalizer(
moments["row"], moments["col"], bias_eas
).scale_by_lr(lr)
else:
continue

# Scale the gradient by the current learning rate. It's factorized
# so we multiply each factor by the square root of the LR.
norm.row *= lr_sqrt
norm.col *= lr_sqrt
normalizers[name] = norm
normalizers[layer_name] = norm

proc.normalizers = normalizers

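
To make the rewritten collection loop above concrete, here is a small sketch of what it ends up building per layer. The tensors stand in for values read from `optimizer.state[param]`, the layer name and shapes are invented, and this illustrates the two branches rather than reproducing code from the PR.

```python
import torch

from bergson.gradients import AdafactorNormalizer, AdamNormalizer

O, I, lr = 4, 3, 1e-3  # invented shapes and learning rate

# Adam-like optimizer: dense exp_avg_sq for the weight, dense exp_avg_sq for the bias.
weight_eas, bias_eas = torch.rand(O, I), torch.rand(O)
norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor().scale_by_lr(lr)

# Adafactor-like optimizer: factorized exp_avg_sq_row / exp_avg_sq_col for the weight,
# plus a dense exp_avg_sq for the bias (Adafactor stores bias moments unfactorized).
vr, vc = torch.rand(O), torch.rand(I)
norm = AdafactorNormalizer(vr, vc, bias_eas).scale_by_lr(lr)

# Either way, the value stored under the layer name is an AdafactorNormalizer with the
# group's learning rate already folded into its statistics.
normalizers = {"model.layers.0.mlp.up_proj": norm}
```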