From 3ea0b9c25839d035de595c195f9fd63f460abef5 Mon Sep 17 00:00:00 2001
From: Baber Abbasi
Date: Tue, 11 Nov 2025 17:31:52 +0000
Subject: [PATCH 1/9] add req

---
 pyproject.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index b45c4ed..0953c36 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,11 +12,13 @@ license = {text = "MIT License"}
 dependencies = [
     "accelerate",  # For device_map in from_pretrained
     "datasets",
+    "matplotlib>=3.10.7",
     "natsort",
     "peft>=0.17.0",
     "simple-parsing",
     "torch",
     "transformers",
+    "wandb>=0.22.3",
 ]
 version = "0.1.1"
 [project.optional-dependencies]

From e1bb1038dff20da8c6d9f29e714609d4d8fc9d41 Mon Sep 17 00:00:00 2001
From: Baber
Date: Fri, 14 Nov 2025 16:44:30 +0500
Subject: [PATCH 2/9] ci: add unit tests workflow and restrict build to upstream repo

---
 .github/workflows/build.yml      |  2 +-
 .github/workflows/unit_tests.yml | 54 ++++++++++++++++++++++++++++++++
 pyproject.toml                   |  5 +--
 3 files changed, 58 insertions(+), 3 deletions(-)
 create mode 100644 .github/workflows/unit_tests.yml

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index cafcef4..db5738b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -29,10 +29,10 @@ jobs:
         run: pip wheel --no-deps -w dist .
   release:
     needs: build
+    if: github.repository == 'EleutherAI/bergson' && github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
     permissions:
       contents: write
       id-token: write
-    if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
     runs-on: ubuntu-latest
     concurrency: release
     steps:
diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
new file mode 100644
index 0000000..b284055
--- /dev/null
+++ b/.github/workflows/unit_tests.yml
@@ -0,0 +1,54 @@
+name: Unit Tests
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    branches:
+      - 'main'
+  workflow_dispatch:
+
+jobs:
+  linter:
+    name: Linters
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+
+    steps:
+      - name: Pre-commit Checks
+        uses: actions/checkout@v5
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v6
+        with:
+          python-version: '3.10'
+          cache: pip
+          cache-dependency-path: pyproject.toml
+      - name: Pre-Commit
+        uses: pre-commit/action@v3.0.1
+
+  testcpu:
+    name: CPU Tests
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: true
+      matrix:
+        python-version: [ "3.10", "3.11", "3.12" ]
+
+    steps:
+      - uses: actions/checkout@v5
+
+      - name: Checkout
+        uses: astral-sh/setup-uv@v6
+        with:
+          python-version: ${{ matrix.python-version }}
+          enable-cache: true
+
+      - name: Install dependencies
+        run: uv sync --locked --dev
+
+      - name: Run tests
+        run: uv run pytest --showlocals -s -vv -n=auto tests
+
+      - name: Cleanup
+        run: uv cache prune --ci
diff --git a/pyproject.toml b/pyproject.toml
index d804354..cc38719 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,21 +12,22 @@ license = {text = "MIT License"}
 dependencies = [
     "accelerate",  # For device_map in from_pretrained
     "datasets",
-    "matplotlib>=3.10.7",
     "natsort",
     "peft>=0.17.0",
     "simple-parsing",
     "torch",
     "transformers",
-    "wandb>=0.22.3",
 ]
 version = "0.2.0"
 [project.optional-dependencies]
 dev = [
+    "matplotlib",
     "pre-commit",
     "pytest",
+    "pytest-xdist",
     "pyright",
     "trl",
+    "wandb"
 ]
 example = [
     "trl",
From 50d4d0d6ee06ae35f67e38a693937db1208495a6 Mon Sep 17 00:00:00 2001
From: Baber
Date: Fri, 14 Nov 2025 21:02:20 +0500
Subject: [PATCH 3/9] nit

---
 .github/workflows/build.yml      | 3 ---
 .github/workflows/unit_tests.yml | 9 ++++++---
 pyproject.toml                   | 1 +
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index db5738b..e3cb92f 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -4,9 +4,6 @@ on:
   push:
     branches:
       - main
-  pull_request:
-    branches:
-      - main
 jobs:
   build:
     runs-on: ubuntu-latest
diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index b284055..8db2d1a 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -1,5 +1,6 @@
 name: Unit Tests
-
+env:
+  UV_SYSTEM_PYTHON: 1
 on:
   push:
     branches:
@@ -26,6 +27,8 @@ jobs:
           cache-dependency-path: pyproject.toml
       - name: Pre-Commit
         uses: pre-commit/action@v3.0.1
+      - name: Type Checking
+        uses: jakebailey/pyright-action@v1

   testcpu:
     name: CPU Tests
@@ -45,10 +48,10 @@ jobs:
          enable-cache: true

      - name: Install dependencies
-        run: uv sync --locked --dev
+        run: uv sync --extra dev

      - name: Run tests
-        run: uv run pytest --showlocals -s -vv -n=auto tests
+        run: uv run pytest tests --showlocals -s -vv -n=auto

      - name: Cleanup
        run: uv cache prune --ci
diff --git a/pyproject.toml b/pyproject.toml
index cc38719..bffe0d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dev = [
     "pytest",
     "pytest-xdist",
     "pyright",
+    "setuptools",
     "trl",
     "wandb"
 ]

From 78fa64390f6b93ec6f8bfa27864c75524fb97f72 Mon Sep 17 00:00:00 2001
From: Baber
Date: Mon, 17 Nov 2025 18:58:13 +0500
Subject: [PATCH 4/9] nit

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5e5c4a4..101f107 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dev = [
     "pyright",
     "setuptools",
     "trl",
-    "wandb"
+    "wandb",
     # Generate documentation
     "furo",
     "myst-parser",
From 45841c295ab288be98f31ff6a2715943e2d14fc3 Mon Sep 17 00:00:00 2001
From: Baber
Date: Tue, 18 Nov 2025 04:40:10 +0500
Subject: [PATCH 5/9] use ruff-format

---
 .pre-commit-config.yaml | 24 +++++++++++-------------
 pyproject.toml          |  6 ++++--
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b58386c..2f1d88d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,18 +1,16 @@
 # See https://pre-commit.com for more information
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v5.0.0
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
    hooks:
-  - id: trailing-whitespace
-  - id: end-of-file-fixer
-    exclude: \.txt$
-- repo: https://github.com/psf/black
-  rev: 25.1.0
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+        exclude: \.txt$
+      - id: no-commit-to-branch
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.14.5
    hooks:
-  - id: black
-- repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: 'v0.11.9'
-  hooks:
-  - id: ruff
-    args: [--fix, --exit-non-zero-on-fix]
+      - id: ruff-check
+        args: [ --fix ]
+      - id: ruff-format
diff --git a/pyproject.toml b/pyproject.toml
index 101f107..9b02a51 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,10 +55,12 @@ reportPrivateImportUsage = false
 include = ["bergson*"]

 [tool.ruff]
-lint.ignore = ["E741"] # Ambiguous variable name
+lint.ignore = ["E741", # Ambiguous variable name
+    "E501", # line-too-long (formatter takes care of it)
+]
 # Enable pycodestyle (`E`), Pyflakes (`F`), and isort (`I`) codes
 # See https://beta.ruff.rs/docs/rules/ for more possible rules
-lint.select = ["E", "F", "I"]
+lint.extend-select = ["E", "F", "I"]

 # Same as Black.
 line-length = 88
""" # We assume avg_sq is a square matrix of shape [O, I] - assert ( - self.avg_sq.ndim == 2 - ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" + assert self.avg_sq.ndim == 2, ( + f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" + ) # Compute row and column means return AdafactorNormalizer( @@ -563,7 +563,6 @@ def _process_grad(self, module: nn.Module, _, grad_out): # If we are using AdamNormalizer, or including bias gradients # we need to materialize the full gradient and then project if isinstance(norm, AdamNormalizer) or include_bias: - P = G.mT @ I # [N, O, S] @ [N, S, I] → [N, O, I] if include_bias: # Append the bias gradient to the input diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 8dbb969..7c76a50 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -243,9 +243,9 @@ def on_step_end( # Record training order if enabled if self.order is not None: - assert ( - self.batch_indices is not None - ), "Batch indices are not available for training order tracking" + assert self.batch_indices is not None, ( + "Batch indices are not available for training order tracking" + ) epoch = int(state.epoch or 0) global_step = state.global_step diff --git a/bergson/query/faiss_index.py b/bergson/query/faiss_index.py index 58a1b65..073c5c9 100644 --- a/bergson/query/faiss_index.py +++ b/bergson/query/faiss_index.py @@ -242,9 +242,9 @@ def create_index( shard_sizes[-1] += remainder # Verify all gradients will be consumed - assert ( - sum(shard_sizes) == total_grads - ), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}" + assert sum(shard_sizes) == total_grads, ( + f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}" + ) dl = gradients_loader(gradients_path) buffer: list[NDArray] = [] diff --git a/tests/test_build.py b/tests/test_build.py index 6319499..393c1c7 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -80,9 +80,9 @@ def test_split_attention_build(tmp_path: Path, model, dataset): attention_cfgs=attention_cfgs, ) - assert any( - Path(cfg.partial_run_path).iterdir() - ), "Expected artifacts in the temp run_path" + assert any(Path(cfg.partial_run_path).iterdir()), ( + "Expected artifacts in the temp run_path" + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -107,9 +107,9 @@ def test_conv1d_build(tmp_path: Path, dataset): cfg=cfg, ) - assert any( - Path(cfg.partial_run_path).iterdir() - ), "Expected artifacts in the run path" + assert any(Path(cfg.partial_run_path).iterdir()), ( + "Expected artifacts in the run path" + ) index = load_gradients(cfg.partial_run_path) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index 6718a29..76ec0d6 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -31,6 +31,6 @@ def test_reduce_e2e(tmp_path: Path): text=True, ) - assert ( - "error" not in result.stderr.lower() - ), f"Error found in stderr: {result.stderr}" + assert "error" not in result.stderr.lower(), ( + f"Error found in stderr: {result.stderr}" + ) diff --git a/tests/test_score.py b/tests/test_score.py index 294d636..ee72730 100644 --- a/tests/test_score.py +++ b/tests/test_score.py @@ -64,9 +64,9 @@ def test_large_gradients_query(tmp_path: Path, dataset): ) assert result.returncode == 0 - assert ( - "error" not in result.stderr.lower() - ), f"Error found in stderr: {result.stderr}" + assert "error" not in result.stderr.lower(), ( + f"Error found in stderr: {result.stderr}" + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not 
available") From 40d05a46aa97907f6b71e07f17d79f5926a20fe9 Mon Sep 17 00:00:00 2001 From: Baber Date: Fri, 14 Nov 2025 21:46:27 +0500 Subject: [PATCH 7/9] fix: sum bias gradients over sequence dim only, not batch + tests --- bergson/gradients.py | 23 +++--- tests/test_gradients.py | 170 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 178 insertions(+), 15 deletions(-) diff --git a/bergson/gradients.py b/bergson/gradients.py index 2753364..c5c8005 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -334,8 +334,8 @@ class GradientCollector(ContextDecorator): of the parameters, which are expected to be precomputed and passed in. We assume that the input to `model` is of shape `[N, S, I]`, where `N` is the - batch size, `S` is the sequence length, and `I` is the input dimension. We take the - mean over the sequence length to obtain a single gradient per sequence. + batch size, `S` is the sequence length, and `I` is the input dimension. We + sum over the sequence dimension to obtain a single gradient per sequence. """ model: nn.Module @@ -564,24 +564,19 @@ def _process_grad(self, module: nn.Module, _, grad_out): # we need to materialize the full gradient and then project if isinstance(norm, AdamNormalizer) or include_bias: P = G.mT @ I # [N, O, S] @ [N, S, I] → [N, O, I] + if isinstance(norm, AdamNormalizer): + # Normalize the gradients using the second moment matrix + P /= norm.avg_sq.sqrt().add_(1e-8) + if include_bias: - # Append the bias gradient to the input + # TODO: should we normalize the bias gradients? + # Append the raw bias gradient to the input P = torch.cat( - [ - P, - G.sum(dim=(0, 1)) - .unsqueeze(0) - .unsqueeze(2) - .expand(P.shape[0], -1, 1), - ], + [P, G.sum(dim=1).unsqueeze(2)], # [N, S, O] -> [N, O] # [N, O, 1] dim=2, ) i += 1 - if isinstance(norm, AdamNormalizer): - # Normalize the gradients using the second moment matrix - P /= norm.avg_sq.sqrt().add_(1e-8) - if self.processor.reshape_to_square: P = reshape_to_nearest_square(P) o, i = P.shape[-2:] diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 19528c3..fa90e47 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,7 +1,10 @@ import tempfile +from collections import defaultdict from pathlib import Path +import pytest import torch +import torch.nn as nn from transformers import AutoConfig, AutoModelForCausalLM from bergson.gradients import ( @@ -13,7 +16,7 @@ ) -def test_phi3(): +def test_gradient_collector_proj_norm(): temp_dir = Path(tempfile.mkdtemp()) config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM") @@ -105,3 +108,168 @@ def closure(name: str, g: torch.Tensor): ) previous_collected_grads = collected_grads.copy() + + +@pytest.mark.parametrize("include_bias", [True, False]) +def test_gradient_collector_batched(include_bias: bool): + torch.manual_seed(42) + N = 4 + S = 6 + I = 5 + O = 3 + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(I, O * 2, bias=include_bias) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(O * 2, O, bias=include_bias) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + torch.manual_seed(42) + model = SimpleModel() + + optimizer = torch.optim.Adam(model.parameters()) + + # Run a few training steps to build up second moments + for _ in range(5): + optimizer.zero_grad() + out = model(torch.randn(N, S, I)) + loss = (out**2).sum() + loss.backward() + optimizer.step() + + normalizers = {} + for name, param in model.named_parameters(): + if 
"weight" in name: + layer_name = name.replace(".weight", "") + # Adam stores second moments as 'exp_avg_sq' + exp_avg_sq = optimizer.state[param]["exp_avg_sq"] + normalizers[layer_name] = AdamNormalizer(exp_avg_sq) + + # collect gradients + collected_grads = {} + + def closure(name: str, g: torch.Tensor): + """Store the gradients in a dictionary for later comparison.""" + collected_grads[name] = g + + processor = GradientProcessor( + normalizers=normalizers, projection_dim=None, include_bias=include_bias + ) + collector = GradientCollector(model, closure, processor) + + x = torch.randn(N, S, I) + with collector: + model.zero_grad() + out = model(x) + loss = (out**2).sum() + loss.backward() + + def compute_ground_truth(): + """Compute gradients using individual backward passes, with normalization.""" + model.zero_grad() + output = model(x) # [N, S, O] + + # Per-sample losses + per_sample_losses = (output**2).sum(dim=(1, 2)) # [N] + + ground_truth_grads = defaultdict(list) + for n in range(N): + model.zero_grad() + per_sample_losses[n].backward(retain_graph=True) + + # manually normalize + for layer_name in ["fc1", "fc2"]: + layer = model.get_submodule(layer_name) + grad = layer.weight.grad.clone() + + grad = normalizers[layer_name].normalize_(grad) + + if include_bias: + bias_grad = layer.bias.grad.clone() + bias_grad = bias_grad.unsqueeze(1) + grad = torch.cat([grad, bias_grad], dim=1) + + ground_truth_grads[layer_name].append(grad) + + for layer_name in ["fc1", "fc2"]: + ground_truth_grads[layer_name] = torch.stack(ground_truth_grads[layer_name]) + + return ground_truth_grads + + ground_truth = compute_ground_truth() + for layer_name in ["fc1", "fc2"]: + torch.testing.assert_close( + collected_grads[layer_name], ground_truth[layer_name] + ) + + +def test_bias_gradients(): + """Test that per-sample bias gradients are correctly computed.""" + torch.manual_seed(42) + N = 4 + S = 6 + I = 5 + O = 3 + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(I, O, bias=True) + + def forward(self, x): + return self.fc(x) + + model = SimpleModel() + x = torch.randn(N, S, I) + + # bias gradient is a sum over sequence dimension for each n + def compute_ground_truth(model) -> torch.Tensor: + """Compute gradients using individual backward passes.""" + model.zero_grad() + output = model(x) # [N, S, O] + + per_sample_losses = (output**2).sum(dim=(1, 2)) # [N] + + bias_grads = [] + for n in range(N): + model.zero_grad() + per_sample_losses[n].backward(retain_graph=True) + bias_grads.append(model.fc.bias.grad.clone()) + + return torch.stack(bias_grads, dim=0) # [N, O] + + ground_truth = compute_ground_truth(model) + + # GradientCollector with include_bias=True + collected_grads = {} + + def closure(name: str, g: torch.Tensor): + collected_grads[name] = g + + processor = GradientProcessor(include_bias=True, projection_dim=None) + collector = GradientCollector(model, closure, processor, target_modules={"fc"}) + + with collector: + model.zero_grad() + output = model(x) + loss = (output**2).sum() + loss.backward() + + # the last column is bias + bias_grads = collected_grads["fc"][..., -1] + + assert bias_grads.shape == ( + N, + 3, + ), f"Expected shape ({N}, {O}), got {bias_grads.shape}" + assert ground_truth.shape == ( + N, + 3, + ), f"Expected shape ({N}, {O}), got {ground_truth.shape}" + + # Compare to ground truth + torch.testing.assert_close(bias_grads, ground_truth) From e5f68699b4e503c94d57261592baa7bf64c81a7c Mon Sep 17 00:00:00 2001 From: Baber Date: Tue, 
From e5f68699b4e503c94d57261592baa7bf64c81a7c Mon Sep 17 00:00:00 2001
From: Baber
Date: Tue, 18 Nov 2025 02:53:00 +0500
Subject: [PATCH 8/9] fix: add normalizer bias support. fix trainer callback. add tests

---
 bergson/gradients.py           |  45 ++++++++--
 bergson/huggingface.py         |  79 +++++++++++++----
 tests/test_trainer_callback.py | 155 ++++++++++++++++++++++++++++++++-
 3 files changed, 257 insertions(+), 22 deletions(-)

""" # We assume avg_sq is a square matrix of shape [O, I] assert self.avg_sq.ndim == 2, ( @@ -163,6 +182,7 @@ def to_adafactor(self) -> AdafactorNormalizer: return AdafactorNormalizer( row=self.avg_sq.mean(dim=1), # shape [O] col=self.avg_sq.mean(dim=0), # shape [I] + bias_avg_sq=self.bias_avg_sq, # Preserve bias second moments ) @@ -551,8 +571,22 @@ def _process_grad(self, module: nn.Module, _, grad_out): i = getattr(module, LayerAdapter.in_attr(module)) o = getattr(module, LayerAdapter.out_attr(module)) - # Pre-scale G by the Adafactor row statistics + # Handle bias gradients if needed (must be computed from raw G) norm = self.processor.normalizers.get(name) + bias_grad = None + if include_bias: + # Compute bias from raw G (before any normalization) + bias_grad = G.sum(dim=1) # [N, S, O] -> [N, O] + + # Normalize bias with appropriate second moments + if ( + isinstance(norm, (AdamNormalizer, AdafactorNormalizer)) + and hasattr(norm, "bias_avg_sq") + and norm.bias_avg_sq is not None + ): + bias_grad = bias_grad / norm.bias_avg_sq.sqrt().add_(1e-8) + + # Pre-scale G by the Adafactor row statistics (for weight gradients) if isinstance(norm, AdafactorNormalizer): # Compare to the normalize_ method in AdafactorNormalizer r = norm.row.add(1e-30) @@ -568,11 +602,10 @@ def _process_grad(self, module: nn.Module, _, grad_out): # Normalize the gradients using the second moment matrix P /= norm.avg_sq.sqrt().add_(1e-8) - if include_bias: - # TODO: should we normalize the bias gradients? - # Append the raw bias gradient to the input + if include_bias and bias_grad is not None: + # Append pre-computed and normalized bias gradient P = torch.cat( - [P, G.sum(dim=1).unsqueeze(2)], # [N, S, O] -> [N, O] # [N, O, 1] + [P, bias_grad.unsqueeze(2)], # [N, O, 1] dim=2, ) i += 1 diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 7c76a50..cb36758 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -239,7 +239,6 @@ def on_step_end( **kwargs, ): self.on_substep_end(args, state, control) - print("Step end") # Record training order if enabled if self.order is not None: @@ -279,32 +278,82 @@ def on_step_end( # Read normalizers off of the optimizer state. We need to figure out # what type of optimizer this is first. 
diff --git a/bergson/huggingface.py b/bergson/huggingface.py
index 7c76a50..cb36758 100644
--- a/bergson/huggingface.py
+++ b/bergson/huggingface.py
@@ -239,7 +239,6 @@ def on_step_end(
         **kwargs,
     ):
         self.on_substep_end(args, state, control)
-        print("Step end")

         # Record training order if enabled
         if self.order is not None:
@@ -279,32 +278,82 @@ def on_step_end(

         # Read normalizers off of the optimizer state. We need to figure out
         # what type of optimizer this is first.
+        # Collect references to both weight and bias second moments per layer
+        layer_second_moments: dict[str, dict[str, Tensor]] = {}
+
         for group in optimizer.param_groups:
-            lr_sqrt = group["lr"] ** 0.5
+            group_lr = group["lr"]

             for param in group["params"]:
-                name = param_to_name[param].removesuffix(".weight")
-                if name not in self.collector.target_info:
+                param_name = param_to_name[param]
+
+                # Extract layer name (remove .weight or .bias suffix)
+                if param_name.endswith(".weight"):
+                    param_type = "weight"
+                    layer_name = param_name.removesuffix(".weight")
+                elif param_name.endswith(".bias"):
+                    param_type = "bias"
+                    layer_name = param_name.removesuffix(".bias")
+                else:
+                    continue
+
+                if layer_name not in self.collector.target_info:
                     continue

                 p_state = optimizer.state[param]

+                # Initialize layer dict if needed, storing this group's learning rate
+                if layer_name not in layer_second_moments:
+                    layer_second_moments[layer_name] = {"lr": group_lr}
+
                 # Adam-like optimizer
                 if (eas := p_state.get("exp_avg_sq")) is not None:
-                    norm = AdamNormalizer(eas).to_adafactor()
-
+                    layer_second_moments[layer_name][param_type] = eas
                 # Adafactor-like optimizer
                 elif (vr := p_state.get("exp_avg_sq_row")) is not None:
                     vc = p_state.get("exp_avg_sq_col")
-                    norm = AdafactorNormalizer(vr, vc)
-                else:
-                    continue
-
-                # Scale the gradient by the current learning rate. It's factorized
-                # so we multiply each factor by the square root of the LR.
-                norm.row *= lr_sqrt
-                norm.col *= lr_sqrt
-                normalizers[name] = norm
+                    if param_type == "weight":
+                        # Factorized second moments for weights
+                        layer_second_moments[layer_name]["row"] = vr
+                        layer_second_moments[layer_name]["col"] = vc
+                    elif param_type == "bias":
+                        # Adafactor stores bias as regular exp_avg_sq
+                        bias_eas = p_state.get("exp_avg_sq")
+                        if bias_eas is not None:
+                            layer_second_moments[layer_name]["bias"] = bias_eas
+
+        # Build normalizers from collected second moments
+        for layer_name, moments in layer_second_moments.items():
+            lr_sqrt = moments["lr"] ** 0.5
+
+            # Adam-like: has weight exp_avg_sq
+            if "weight" in moments:
+                weight_eas = moments["weight"]
+                bias_eas = moments.get("bias")  # May be None
+
+                # Create Adam normalizer with optional bias, then convert to Adafactor
+                # TODO: always convert to adafactor?
+                norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor()
+
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+
+            # Adafactor-like: has row/col
+            elif "row" in moments and "col" in moments:
+                bias_eas = moments.get("bias")  # May be present
+                norm = AdafactorNormalizer(moments["row"], moments["col"], bias_eas)
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+            else:
+                continue
+
+            normalizers[layer_name] = norm

         proc.normalizers = normalizers

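
The sketch below (editorial, not the callback's actual code) mirrors the grouping step introduced above on a toy torch.optim.Adam instance: walk optimizer.param_groups, map each parameter back to its name, and collect the exp_avg_sq second moments for the weight and bias of each layer together with that group's learning rate. The layer names ("0", "2") come from nn.Sequential and are illustrative only.

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One optimizer step so Adam has populated exp_avg_sq for every parameter.
model(torch.randn(16, 8)).pow(2).sum().backward()
optimizer.step()

param_to_name = {p: n for n, p in model.named_parameters()}
layer_second_moments: dict[str, dict] = {}

for group in optimizer.param_groups:
    for param in group["params"]:
        name = param_to_name[param]            # e.g. "0.weight" or "0.bias"
        layer, _, kind = name.rpartition(".")  # kind is "weight" or "bias"
        state = optimizer.state[param]
        if "exp_avg_sq" in state:              # Adam-style second moments
            stats = layer_second_moments.setdefault(layer, {"lr": group["lr"]})
            stats[kind] = state["exp_avg_sq"]

assert set(layer_second_moments) == {"0", "2"}
assert layer_second_moments["0"]["weight"].shape == (4, 8)
assert layer_second_moments["0"]["bias"].shape == (4,)
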
diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py
index 56f38cf..24984f5 100644
--- a/tests/test_trainer_callback.py
+++ b/tests/test_trainer_callback.py
@@ -1,4 +1,10 @@
 import os
+from pathlib import Path
+
+from torch import nn
+
+from bergson import GradientProcessor
+from bergson.gradients import AdafactorNormalizer

 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 os.environ["WANDB_MODE"] = "disabled"
@@ -6,7 +12,13 @@
 import pytest
 import torch
 from datasets import Dataset
-from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments
+from transformers import (
+    Adafactor,
+    AutoConfig,
+    AutoModelForCausalLM,
+    Trainer,
+    TrainingArguments,
+)
 from trl import SFTConfig, SFTTrainer

 from bergson.data import load_gradients
@@ -245,3 +257,144 @@ def test_sft_trainer(self, tmp_path, model, dataset):
         saved_order = Dataset.load_from_disk(str(order_file))
         assert len(saved_order) > 0
         assert all(key in saved_order[0] for key in ["_idx", "global_step", "epoch"])
+
+    @pytest.mark.parametrize("optimizer_name", ["adam", "adafactor"])
+    @pytest.mark.parametrize("include_bias", [True, False])
+    def test_optimizer_state_extraction(self, optimizer_name: str, include_bias: bool):
+        """Test that normalizers are correctly extracted from optimizer state.
+
+        This tests the huggingface.py callback by:
+        1. Training a model with an optimizer
+        2. Calling the callback's on_step_end method
+        3. Verifying against raw optimizer state
+        """
+        torch.manual_seed(42)
+        N = 4
+        S = 6
+        I = 5
+        O = 3
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(I, O * 2, bias=include_bias)
+                self.relu = nn.ReLU()
+                self.fc2 = nn.Linear(O * 2, O, bias=include_bias)
+
+            def forward(self, x):
+                return self.fc2(self.relu(self.fc1(x)))
+
+        torch.manual_seed(42)
+        model = SimpleModel()
+
+        # Create optimizer
+        if optimizer_name == "adam":
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        else:
+            optimizer = Adafactor(
+                model.parameters(), scale_parameter=False, relative_step=False, lr=0.001
+            )
+
+        # Train a few steps to build up second moments
+        for _ in range(5):
+            optimizer.zero_grad()
+            out = model(torch.randn(N, S, I))
+            loss = (out**2).sum()
+            loss.backward()
+            optimizer.step()
+
+        # Extract normalizers using the ACTUAL callback
+        from unittest.mock import Mock, patch
+
+        from bergson.huggingface import GradientCollectorCallback
+
+        # Create callback with minimal setup
+        callback = GradientCollectorCallback(
+            path=Path("/tmp/test"),
+            use_optimizer_state=True,
+            include_bias=include_bias,
+        )
+
+        # Mock the collector and processor
+        mock_collector = Mock()
+        mock_collector.processor = GradientProcessor(
+            normalizers={}, include_bias=include_bias
+        )
+        mock_collector.target_info = {"fc1": None, "fc2": None}  # Track these layers
+        callback.collector = mock_collector
+
+        # Mock on_substep_end to avoid needing train_grad_buffer
+        with patch.object(callback, "on_substep_end"):
+            # Call the ACTUAL callback method
+            callback.on_step_end(
+                args=Mock(),
+                state=Mock(epoch=0, global_step=1),
+                control=Mock(),
+                model=model,
+                optimizer=optimizer,
+            )
+
+        # Get the normalizers the callback extracted
+        normalizers = callback.collector.processor.normalizers
+
+        # Verify against raw optimizer state (independent ground truth)
+        for layer_name in ["fc1", "fc2"]:
+            layer = model.get_submodule(layer_name)
+            norm = normalizers[layer_name]
+
+            # Check normalizer type
+            assert isinstance(norm, AdafactorNormalizer)
+
+            # Get raw state from optimizer
+            weight_state = optimizer.state[layer.weight]
+            lr = optimizer.param_groups[0]["lr"]
+            lr_sqrt = lr**0.5
+
+            if optimizer_name == "adam":
+                # Ground truth: Adam stores full exp_avg_sq
+                raw_exp_avg_sq = weight_state["exp_avg_sq"]
+
+                # NOTE: We convert Adam's full second moments to Adafactor's factorized
+                # form (row + col vectors) for memory efficiency. This is a lossy
+                # rank-1 approximation that can have large reconstruction errors.
+                # We can't verify correctness here, only sanity check the factorization.
+
+                # Sanity checks on the factorized representation
+                assert norm.row.shape == (raw_exp_avg_sq.shape[0],)
+                assert norm.col.shape == (raw_exp_avg_sq.shape[1],)
+                assert (
+                    not torch.isnan(norm.row).any() and not torch.isinf(norm.row).any()
+                )
+                assert (
+                    not torch.isnan(norm.col).any() and not torch.isinf(norm.col).any()
+                )
+                assert (norm.row > 0).all() and (
+                    norm.col > 0
+                ).all()  # Second moments are positive
+
+            elif optimizer_name == "adafactor":
+                # Ground truth: Adafactor stores row/col directly
+                raw_row = weight_state["exp_avg_sq_row"]
+                raw_col = weight_state["exp_avg_sq_col"]
+
+                # Our normalizer should match (scaled by LR)
+                expected_row = raw_row * lr_sqrt
+                expected_col = raw_col * lr_sqrt
+
+                torch.testing.assert_close(norm.row, expected_row)
+                torch.testing.assert_close(norm.col, expected_col)
+
+            # Verify bias handling
+            if include_bias and layer.bias is not None:
+                bias_state = optimizer.state[layer.bias]
+                raw_bias_exp_avg_sq = bias_state["exp_avg_sq"]
+                expected_bias = raw_bias_exp_avg_sq * lr
+
+                assert norm.bias_avg_sq is not None, (
+                    f"Expected bias_avg_sq for {layer_name}"
+                )
+                torch.testing.assert_close(norm.bias_avg_sq, expected_bias)
+            else:
+                assert norm.bias_avg_sq is None, (
+                    f"Unexpected bias_avg_sq for {layer_name}"
+                )
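
For context, this is roughly how the callback exercised by the new test might be wired into a real Hugging Face Trainer run. Only the constructor arguments that appear in the test (path, use_optimizer_state, include_bias) are taken from the diff; the model, dataset, and TrainingArguments values below are placeholders.

from pathlib import Path

from transformers import Trainer, TrainingArguments

from bergson.huggingface import GradientCollectorCallback


def build_trainer(model, train_dataset, out_dir: Path) -> Trainer:
    callback = GradientCollectorCallback(
        path=out_dir / "gradients",  # where gradient artifacts are written
        use_optimizer_state=True,    # normalize with Adam/Adafactor second moments
        include_bias=True,           # also collect per-sample bias gradients
    )
    args = TrainingArguments(output_dir=str(out_dir), per_device_train_batch_size=4)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        callbacks=[callback],        # normalizers are refreshed in on_step_end
    )
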
From 31e50084eb6a9b2c4a64fc2402767bb58192a88b Mon Sep 17 00:00:00 2001
From: Baber
Date: Tue, 18 Nov 2025 03:48:02 +0500
Subject: [PATCH 9/9] add `scale_by_lr` method to AdafactorNormalizer

---
 bergson/gradients.py   | 15 +++++++++++++++
 bergson/huggingface.py | 21 +++++++-------------
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/bergson/gradients.py b/bergson/gradients.py
index 9ad9a6d..242023a 100644
--- a/bergson/gradients.py
+++ b/bergson/gradients.py
@@ -140,6 +140,21 @@ def to_adam(self) -> "AdamNormalizer":
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
         return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)

+    def scale_by_lr(self, lr: float | Tensor) -> "AdafactorNormalizer":
+        """Scale normalizer by learning rate.
+
+        Factorized dimensions (row, col) are scaled by sqrt(lr).
+        Bias is scaled by lr.
+        """
+        lr_sqrt = lr**0.5
+        return AdafactorNormalizer(
+            row=self.row * lr_sqrt,  # shape [O]
+            col=self.col * lr_sqrt,  # shape [I]
+            bias_avg_sq=self.bias_avg_sq * lr
+            if self.bias_avg_sq is not None
+            else None,  # shape [O]
+        )
+

 @dataclass
 class AdamNormalizer(Normalizer):
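
A minimal editorial usage sketch for the new scale_by_lr method: it returns a fresh normalizer with row/col scaled by sqrt(lr) and bias_avg_sq scaled by lr, leaving the original tensors untouched (which is what lets the callback below avoid mutating optimizer state).

import torch

from bergson.gradients import AdafactorNormalizer

norm = AdafactorNormalizer(
    row=torch.tensor([1.0, 4.0]),
    col=torch.tensor([1.0, 1.0, 1.0]),
    bias_avg_sq=torch.tensor([2.0, 2.0]),
)
scaled = norm.scale_by_lr(0.01)

assert torch.allclose(scaled.row, norm.row * 0.1)  # sqrt(0.01) = 0.1
assert torch.allclose(scaled.col, norm.col * 0.1)
assert torch.allclose(scaled.bias_avg_sq, norm.bias_avg_sq * 0.01)
assert torch.equal(norm.row, torch.tensor([1.0, 4.0]))  # original unchanged
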
diff --git a/bergson/huggingface.py b/bergson/huggingface.py
index cb36758..f780a8b 100644
--- a/bergson/huggingface.py
+++ b/bergson/huggingface.py
@@ -324,7 +324,7 @@ def on_step_end(

         # Build normalizers from collected second moments
         for layer_name, moments in layer_second_moments.items():
-            lr_sqrt = moments["lr"] ** 0.5
+            lr = moments["lr"]

             # Adam-like: has weight exp_avg_sq
             if "weight" in moments:
@@ -331,27 +331,20 @@ def on_step_end(
                 weight_eas = moments["weight"]
                 bias_eas = moments.get("bias")  # May be None

                 # Create Adam normalizer with optional bias, then convert to Adafactor
                 # TODO: always convert to adafactor?
-                norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor()
-
-                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
-                norm.row = norm.row * lr_sqrt
-                norm.col = norm.col * lr_sqrt
-                if norm.bias_avg_sq is not None:
-                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+                norm = (
+                    AdamNormalizer(weight_eas, bias_eas).to_adafactor().scale_by_lr(lr)
+                )

             # Adafactor-like: has row/col
             elif "row" in moments and "col" in moments:
                 bias_eas = moments.get("bias")  # May be present
-                norm = AdafactorNormalizer(moments["row"], moments["col"], bias_eas)
-                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
-                norm.row = norm.row * lr_sqrt
-                norm.col = norm.col * lr_sqrt
-                if norm.bias_avg_sq is not None:
-                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+                norm = AdafactorNormalizer(
+                    moments["row"], moments["col"], bias_eas
+                ).scale_by_lr(lr)
             else:
                 continue
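
To close the loop, an editorial sketch of the chain the callback now builds for an Adam-style layer — AdamNormalizer(...).to_adafactor().scale_by_lr(lr) — on toy second moments, checking that bias_avg_sq survives the factorization and that the scaling matches the docstring (sqrt(lr) for row/col, lr for the bias).

import torch

from bergson.gradients import AdamNormalizer

weight_eas = torch.rand(4, 6) + 0.1  # stand-in for exp_avg_sq of a [4, 6] weight
bias_eas = torch.rand(4) + 0.1       # stand-in for exp_avg_sq of the bias
lr = 1e-3

norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor().scale_by_lr(lr)

assert torch.allclose(norm.row, weight_eas.mean(dim=1) * lr**0.5)
assert torch.allclose(norm.col, weight_eas.mean(dim=0) * lr**0.5)
assert norm.bias_avg_sq is not None
assert torch.allclose(norm.bias_avg_sq, bias_eas * lr)
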