1 | 1 | import json |
| 2 | +import random |
2 | 3 | from abc import ABC, abstractmethod |
3 | 4 | from contextlib import ContextDecorator |
4 | 5 | from dataclasses import asdict, astuple, dataclass, field |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import Callable, Literal, Mapping |
7 | 8 | |
8 | 9 | import torch |
| 10 | +import torch.distributed as dist |
9 | 11 | import torch.nn as nn |
| 12 | +from datasets import Dataset |
10 | 13 | from torch import Tensor |
11 | 14 | from torch.utils.hooks import RemovableHandle |
| 15 | +from tqdm.auto import tqdm |
| 16 | +from transformers import PreTrainedModel |
12 | 17 | from transformers.pytorch_utils import Conv1D as HFConv1D |
13 | 18 | |
14 | 19 | from .config import AttentionConfig |
| 20 | +from .data import pad_and_tensor |
15 | 21 | from .math import reshape_to_nearest_square |
16 | 22 | from .utils import assert_type, create_projection_matrix |
17 | 23 | |
@@ -64,108 +70,6 @@ def state_dict(self) -> dict[str, str | Tensor]: |
64 | 70 | } |
65 | 71 | |
66 | 72 | |
67 | | -@dataclass |
68 | | -class AdafactorNormalizer(Normalizer): |
69 | | - """ |
70 | | - Row and column sums of second moments of gradients for a matrix-valued parameter. |
71 | | - """ |
72 | | - |
73 | | - row: Tensor # shape [O] |
74 | | - col: Tensor # shape [I] |
75 | | - |
76 | | - def __post_init__(self): |
77 | | - assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D" |
78 | | - assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D" |
79 | | - |
80 | | - @torch.compile |
81 | | - def normalize_( |
82 | | - self, |
83 | | - grad: Tensor, |
84 | | - eps: float = 1e-30, |
85 | | - ) -> Tensor: |
86 | | - """ |
87 | | - Normalize the row and column sums by adding a small epsilon. |
88 | | - |
89 | | - Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They |
90 | | - recommend 1e-30, but we use 1e-16 for extra numerical stability. |
91 | | - """ |
92 | | - # We follow the Adafactor implementation in the tensor2tensor repo, which is |
93 | | - # different from the paper and from the PyTorch implementation. First add eps |
94 | | - # to ensure these second moments are sufficiently far from zero. Then we don't |
95 | | - # need to worry about numerical stability anywhere else, and we don't need to |
96 | | - # materialize the outer product at any point. |
97 | | - r, c = self.row.add(eps), self.col.add(eps) |
98 | | - |
99 | | - # This is the denominator for V, the rank-one matrix of second moment estimates: |
100 | | - # V = torch.outer(r, c) / denom |
101 | | - # V_ij = r_i * c_j / denom |
102 | | - # But we want to (implicitly) take the Hadamard product with the elementwise |
103 | | - # reciprocal square root of V: |
104 | | - # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt() |
105 | | - denom = r.mean() |
106 | | - |
107 | | - # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying |
108 | | - # by diag(a) and right-multiplying by diag(b). In this case we can represent |
109 | | - # the elementwise reciprocal square root of V as ab^T where: |
110 | | - # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt() |
111 | | - a = denom.sqrt() * r.rsqrt_() # shape [O] |
112 | | - b = c.rsqrt_() |
113 | | - |
114 | | - # Implicitly do the Hadamard product |
115 | | - grad *= a[:, None] # [N, O] * [O] → [N, O] |
116 | | - grad *= b[None, :] |
117 | | - return grad |
118 | | - |
119 | | - def to_adam(self) -> "AdamNormalizer": |
120 | | - """ |
121 | | - Convert this Adafactor normalizer to an Adam normalizer by materializing the |
122 | | - rank-one second moment matrix. |
123 | | - """ |
124 | | - # Compute the second moment matrix as a square matrix of shape [O, I] |
125 | | - # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to |
126 | | - # add it outside the square root. This could cause infs though if there are |
127 | | - # any exactly zero rows or columns, so we should be careful. |
128 | | - avg_sq = torch.outer(self.row, self.col) / self.row.mean() |
129 | | - return AdamNormalizer(avg_sq=avg_sq) |
130 | | - |
131 | | - |
132 | | -@dataclass |
133 | | -class AdamNormalizer(Normalizer): |
134 | | - """ |
135 | | - Contains the second moments of the gradients. |
136 | | - """ |
137 | | - |
138 | | - avg_sq: Tensor |
139 | | - |
140 | | - @torch.compile |
141 | | - def normalize_( |
142 | | - self, |
143 | | - grad: Tensor, |
144 | | - eps: float = 1e-8, |
145 | | - ) -> Tensor: |
146 | | - """Normalize the gradients by the square root of the second moments.""" |
147 | | - # Adam-style epsilon is added outside the square root |
148 | | - denom = self.avg_sq.sqrt() |
149 | | - return grad.div_(denom.add_(eps)) |
150 | | - |
151 | | - def to_adafactor(self) -> AdafactorNormalizer: |
152 | | - """ |
153 | | - Convert this Adam normalizer to an Adafactor normalizer, minimizing the |
154 | | - I-divergence (generalized Kullback-Leibler divergence) between the original |
155 | | - and the factored second moments. |
156 | | - """ |
157 | | - # We assume avg_sq is a square matrix of shape [O, I] |
158 | | - assert ( |
159 | | - self.avg_sq.ndim == 2 |
160 | | - ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" |
161 | | - |
162 | | - # Compute row and column means |
163 | | - return AdafactorNormalizer( |
164 | | - row=self.avg_sq.mean(dim=1), # shape [O] |
165 | | - col=self.avg_sq.mean(dim=0), # shape [I] |
166 | | - ) |
167 | | - |
168 | | - |
169 | 73 | @dataclass |
170 | 74 | class GradientProcessor: |
171 | 75 | """Configuration for processing and compressing gradients.""" |
@@ -626,3 +530,194 @@ def __exit__(self, exc_type, exc, tb): |
626 | 530 | self._bwd_hooks.clear() |
627 | 531 | |
628 | 532 | return False |
| 533 | + |
| 534 | + |
| 535 | +@dataclass |
| 536 | +class AdafactorNormalizer(Normalizer): |
| 537 | + """ |
| 538 | + Row and column sums of second moments of gradients for a matrix-valued parameter. |
| 539 | + """ |
| 540 | + |
| 541 | + row: Tensor # shape [O] |
| 542 | + col: Tensor # shape [I] |
| 543 | + |
| 544 | + def __post_init__(self): |
| 545 | + assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D" |
| 546 | + assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D" |
| 547 | + |
| 548 | + @torch.compile |
| 549 | + def normalize_( |
| 550 | + self, |
| 551 | + grad: Tensor, |
| 552 | + eps: float = 1e-30, |
| 553 | + ) -> Tensor: |
| 554 | + """ |
| 555 | + Normalize `grad` in place by the square root of the factored second moments. |
| 556 | + |
| 557 | + Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper, |
| 558 | + which recommends 1e-30; that is the default used here. |
| 559 | + """ |
| 560 | + # We follow the Adafactor implementation in the tensor2tensor repo, which is |
| 561 | + # different from the paper and from the PyTorch implementation. First add eps |
| 562 | + # to ensure these second moments are sufficiently far from zero. Then we don't |
| 563 | + # need to worry about numerical stability anywhere else, and we don't need to |
| 564 | + # materialize the outer product at any point. |
| 565 | + r, c = self.row.add(eps), self.col.add(eps) |
| 566 | + |
| 567 | + # This is the denominator for V, the rank-one matrix of second moment estimates: |
| 568 | + # V = torch.outer(r, c) / denom |
| 569 | + # V_ij = r_i * c_j / denom |
| 570 | + # But we want to (implicitly) take the Hadamard product with the elementwise |
| 571 | + # reciprocal square root of V: |
| 572 | + # (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt() |
| 573 | + denom = r.mean() |
| 574 | + |
| 575 | + # Hadamard product with a rank-one matrix ab^T is the same as left-multiplying |
| 576 | + # by diag(a) and right-multiplying by diag(b). In this case we can represent |
| 577 | + # the elementwise reciprocal square root of V as ab^T where: |
| 578 | + # a = denom.sqrt() * r.rsqrt() and b = c.rsqrt() |
| 579 | + a = denom.sqrt() * r.rsqrt_() # shape [O] |
| 580 | + b = c.rsqrt_() |
| 581 | + |
| 582 | + # Implicitly apply the Hadamard product with V^{-1/2} |
| 583 | + grad *= a[:, None] # [..., O, I] * [O, 1] → [..., O, I] |
| 584 | + grad *= b[None, :] # [..., O, I] * [1, I] → [..., O, I] |
| 585 | + return grad |
| 586 | + |
| 587 | + def to_adam(self) -> "AdamNormalizer": |
| 588 | + """ |
| 589 | + Convert this Adafactor normalizer to an Adam normalizer by materializing the |
| 590 | + rank-one second moment matrix. |
| 591 | + """ |
| 592 | + # Compute the second moment matrix as a square matrix of shape [O, I] |
| 593 | + # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to |
| 594 | + # add it outside the square root. This could cause infs though if there are |
| 595 | + # any exactly zero rows or columns, so we should be careful. |
| 596 | + avg_sq = torch.outer(self.row, self.col) / self.row.mean() |
| 597 | + return AdamNormalizer(avg_sq=avg_sq) |
| 598 | + |
| 599 | + |
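A minimal standalone sketch (not part of this commit) of the trick `normalize_` relies on: applying V = outer(r, c) / r.mean() implicitly through two broadcasted multiplies matches materializing V and dividing by its elementwise square root.

    import torch

    n_out, n_in, eps = 4, 3, 1e-30
    row, col = torch.rand(n_out), torch.rand(n_in)  # stand-ins for the stored moments
    grad = torch.randn(n_out, n_in)

    r, c = row + eps, col + eps
    V = torch.outer(r, c) / r.mean()                # rank-one second-moment estimate
    explicit = grad / V.sqrt()                      # materialized normalization

    implicit = grad.clone()
    implicit *= (r.mean().sqrt() * r.rsqrt())[:, None]  # a = denom.sqrt() * r.rsqrt()
    implicit *= c.rsqrt()[None, :]                      # b = c.rsqrt()

    torch.testing.assert_close(implicit, explicit)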
| 600 | +@dataclass |
| 601 | +class AdamNormalizer(Normalizer): |
| 602 | + """ |
| 603 | + Contains the second moments of the gradients. |
| 604 | + """ |
| 605 | + |
| 606 | + avg_sq: Tensor |
| 607 | + |
| 608 | + @torch.compile |
| 609 | + def normalize_( |
| 610 | + self, |
| 611 | + grad: Tensor, |
| 612 | + eps: float = 1e-8, |
| 613 | + ) -> Tensor: |
| 614 | + """Normalize the gradients by the square root of the second moments.""" |
| 615 | + # Adam-style epsilon is added outside the square root |
| 616 | + denom = self.avg_sq.sqrt() |
| 617 | + return grad.div_(denom.add_(eps)) |
| 618 | + |
| 619 | + def to_adafactor(self) -> AdafactorNormalizer: |
| 620 | + """ |
| 621 | + Convert this Adam normalizer to an Adafactor normalizer, minimizing the |
| 622 | + I-divergence (generalized Kullback-Leibler divergence) between the original |
| 623 | + and the factored second moments. |
| 624 | + """ |
| 625 | + # We assume avg_sq is a square matrix of shape [O, I] |
| 626 | + assert ( |
| 627 | + self.avg_sq.ndim == 2 |
| 628 | + ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" |
| 629 | + |
| 630 | + # Compute row and column means |
| 631 | + return AdafactorNormalizer( |
| 632 | + row=self.avg_sq.mean(dim=1), # shape [O] |
| 633 | + col=self.avg_sq.mean(dim=0), # shape [I] |
| 634 | + ) |
| 635 | + |
| 636 | + |
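A companion sketch (also not part of this commit) of why `to_adafactor` keeps only row and column means: when the second-moment matrix is exactly rank-one, the factorization round-trips losslessly through the reconstruction used in `to_adam`.

    import torch

    u, v = torch.rand(4), torch.rand(3)             # nonnegative rank-one factors
    avg_sq = torch.outer(u, v)                      # [O, I] second moments

    row = avg_sq.mean(dim=1)                        # what to_adafactor() stores
    col = avg_sq.mean(dim=0)
    recovered = torch.outer(row, col) / row.mean()  # what to_adam() rebuilds

    torch.testing.assert_close(recovered, avg_sq)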
| 637 | +def fit_normalizers( |
| 638 | + model: PreTrainedModel, |
| 639 | + data: Dataset, |
| 640 | + batches: list[list[int]], |
| 641 | + *, |
| 642 | + kind: Literal["adafactor", "adam"] = "adafactor", |
| 643 | + target_modules: set[str] | None = None, |
| 644 | +) -> dict[str, Normalizer]: |
| 645 | + """ |
| 646 | + Estimate the second moments of the model's gradients using a subset of the dataset. |
| 647 | + """ |
| 648 | + normalizers: dict[str, Normalizer] = {} |
| 649 | + rank = dist.get_rank() if dist.is_initialized() else 0 |
| 650 | + |
| 651 | + # Just to make the pbar more accurate |
| 652 | + rng = random.Random(0) |
| 653 | + rng.shuffle(batches) |
| 654 | + |
| 655 | + def adafactor_update(name: str, g: torch.Tensor): |
| 656 | + # We follow the tensor2tensor implementation of Adafactor, which |
| 657 | + # takes the mean rather than summing over the rows and columns. |
| 658 | + # row: mean over columns, shape [O] |
| 659 | + sq = g.float().square_().sum(0) |
| 660 | + row_acc = sq.mean(dim=1) |
| 661 | + # col: mean over rows, shape [I] |
| 662 | + col_acc = sq.mean(dim=0) |
| 663 | + |
| 664 | + if (normalizer := normalizers.get(name)) is None: |
| 665 | + # initialize accumulators at zero |
| 666 | + normalizers[name] = normalizer = AdafactorNormalizer( |
| 667 | + torch.zeros_like(row_acc), |
| 668 | + torch.zeros_like(col_acc), |
| 669 | + ) |
| 670 | + else: |
| 671 | + assert isinstance(normalizer, AdafactorNormalizer) |
| 672 | + |
| 673 | + # in‐place accumulate |
| 674 | + normalizer.row.add_(row_acc) |
| 675 | + normalizer.col.add_(col_acc) |
| 676 | + |
| 677 | + def adam_update(name: str, g: torch.Tensor): |
| 678 | + sq = g.float().square_().sum(0) # square a float copy; don't mutate g in place |
| 679 | + |
| 680 | + # initialize accumulators at zero |
| 681 | + if (normalizer := normalizers.get(name)) is None: |
| 682 | + normalizers[name] = normalizer = AdamNormalizer(torch.zeros_like(sq)) |
| 683 | + else: |
| 684 | + assert isinstance(normalizer, AdamNormalizer) |
| 685 | + |
| 686 | + # in‐place accumulate |
| 687 | + normalizer.avg_sq.add_(sq) |
| 688 | + |
| 689 | + callback = adafactor_update if kind == "adafactor" else adam_update |
| 690 | + |
| 691 | + for indices in tqdm(batches, disable=rank != 0, desc="Estimating normalizers"): |
| 692 | + batch = data[indices] |
| 693 | + |
| 694 | + with GradientCollector( |
| 695 | + model.base_model, |
| 696 | + closure=callback, |
| 697 | + target_modules=target_modules, |
| 698 | + ): |
| 699 | + x, y = pad_and_tensor( |
| 700 | + batch["input_ids"], # type: ignore |
| 701 | + labels=batch.get("labels", None), # type: ignore |
| 702 | + device=model.device, |
| 703 | + ) |
| 704 | + model(x, labels=y).loss.backward() |
| 705 | + model.zero_grad() |
| 706 | + |
| 707 | + # Divide by the number of documents processed and average across all ranks |
| 708 | + for normalizer in normalizers.values(): |
| 709 | + if isinstance(normalizer, AdamNormalizer): |
| 710 | + normalizer.avg_sq.div_(len(data)) |
| 711 | + |
| 712 | + if dist.is_initialized(): |
| 713 | + dist.all_reduce(normalizer.avg_sq, op=dist.ReduceOp.AVG) |
| 714 | + |
| 715 | + elif isinstance(normalizer, AdafactorNormalizer): |
| 716 | + normalizer.row.div_(len(data)) |
| 717 | + normalizer.col.div_(len(data)) |
| 718 | + |
| 719 | + if dist.is_initialized(): |
| 720 | + dist.all_reduce(normalizer.row, op=dist.ReduceOp.AVG) |
| 721 | + dist.all_reduce(normalizer.col, op=dist.ReduceOp.AVG) |
| 722 | + |
| 723 | + return normalizers |
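A hypothetical usage sketch for `fit_normalizers`; the checkpoint name and tokenization below are illustrative assumptions rather than anything from this repo. The only requirement on `data` visible here is an "input_ids" column plus an optional "labels" column.

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "EleutherAI/pythia-14m"                  # hypothetical small checkpoint
    model = AutoModelForCausalLM.from_pretrained(name)
    tok = AutoTokenizer.from_pretrained(name)

    texts = ["Hello world.", "Second moments are fit from squared gradients."]
    data = Dataset.from_dict(dict(tok(texts)))      # columns: input_ids, attention_mask

    batches = [[0], [1]]                            # lists of row indices into `data`
    normalizers = fit_normalizers(model, data, batches, kind="adafactor")

Each returned entry maps a module name to an AdafactorNormalizer (or an AdamNormalizer with kind="adam") whose accumulated statistics are divided by len(data) and, when torch.distributed is initialized, averaged across ranks.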