Finalize llama3 example #31

Merged: 7 commits, Jul 12, 2024
Changes from all commits

3 changes: 0 additions & 3 deletions README.md
@@ -27,9 +27,6 @@ For detailed description of the methodology, see the [**paper**](https://arxiv.o

---

> [!WARNING]
> This repository is under development and has not reached its first stable release.

## Installation

> [!IMPORTANT]
199 changes: 191 additions & 8 deletions examples/openwebtext/README.md

Large diffs are not rendered by default.

26 changes: 9 additions & 17 deletions examples/openwebtext/compute_scores.py
@@ -1,12 +1,9 @@
import argparse
import logging
from datetime import timedelta
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator, InitProcessGroupKwargs
from torch import nn
from transformers import default_data_collator

from examples.openwebtext.pipeline import (
@@ -16,18 +13,11 @@
)
from examples.openwebtext.task import LanguageModelingTask
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task
from kronfluence.utils.common.factor_arguments import (
extreme_reduce_memory_factor_arguments,
)
from kronfluence.utils.common.score_arguments import (
all_low_precision_score_arguments,
extreme_reduce_memory_score_arguments,
)
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Dict[str, torch.Tensor]

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

@@ -36,9 +26,9 @@ def parse_args():
parser = argparse.ArgumentParser(description="Influence score computation on Openwebtext dataset.")

parser.add_argument(
"--factor_strategy",
"--factors_name",
type=str,
default="ekfac",
default="july_11",
help="Strategy to compute influence factors.",
)
parser.add_argument(
@@ -50,7 +40,7 @@ def parse_args():
parser.add_argument(
"--train_batch_size",
type=int,
default=4,
default=8,
help="Batch size for computing query gradients.",
)
parser.add_argument(
@@ -93,22 +83,24 @@ def main():
dataloader_kwargs = DataLoaderKwargs(num_workers=4, collate_fn=default_data_collator, pin_memory=True)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

scores_name = args.factor_strategy
scores_name = args.factors_name
rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
# We set the damping term used for LLMs.
score_args = extreme_reduce_memory_score_arguments(
damping_factor=None, module_partitions=1, query_gradient_low_rank=rank, dtype=torch.bfloat16
)
# score_args.module_partitions = 2
score_args.query_gradient_accumulation_steps = 10
# We can invest some time in getting more accurate SVD results.
score_args.use_full_svd = True
analyzer.compute_pairwise_scores(
scores_name=scores_name,
score_args=score_args,
factors_name=args.factor_strategy,
factors_name=args.factors_name,
query_dataset=eval_dataset,
train_dataset=train_dataset,
per_device_query_batch_size=1,
per_device_train_batch_size=args.train_batch_size,
overwrite_output_dir=False,
overwrite_output_dir=True,
)
scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
logging.info(f"Scores shape: {scores.shape}")
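
For context on the two options toggled above: a small, self-contained sketch (plain PyTorch, not kronfluence internals) of what `query_gradient_low_rank` controls. It keeps only the top singular directions of each query gradient, and, per the comment in the script, `use_full_svd = True` spends extra time on a more accurate SVD. The matrix size below is an arbitrary stand-in.

```python
import torch

# Conceptual sketch only; this matrix is a stand-in, not a real query gradient.
rank = 64
query_grad = torch.randn(512, 2048)
U, S, Vh = torch.linalg.svd(query_grad, full_matrices=False)
low_rank = (U[:, :rank] * S[:rank]) @ Vh[:rank]  # rank-64 approximation
rel_err = torch.linalg.norm(query_grad - low_rank) / torch.linalg.norm(query_grad)
print(f"relative reconstruction error: {rel_err:.3f}")
```
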
Binary file added examples/openwebtext/figure/eigenvalue.png
Binary file added examples/openwebtext/figure/lambda_matrix.png
2,212 changes: 2,212 additions & 0 deletions examples/openwebtext/files/canada.txt

Large diffs are not rendered by default.

2,212 changes: 2,212 additions & 0 deletions examples/openwebtext/files/database.txt

Large diffs are not rendered by default.

1,784 changes: 1,784 additions & 0 deletions examples/openwebtext/files/doctor.txt

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions examples/openwebtext/files/factor_arguments.json
@@ -0,0 +1,20 @@
{
"strategy": "ekfac",
"use_empirical_fisher": false,
"amp_dtype": "torch.bfloat16",
"amp_scale": 65536.0,
"has_shared_parameters": false,
"covariance_max_examples": 100000,
"covariance_data_partitions": 1,
"covariance_module_partitions": 2,
"activation_covariance_dtype": "torch.bfloat16",
"gradient_covariance_dtype": "torch.bfloat16",
"eigendecomposition_dtype": "torch.float64",
"lambda_max_examples": 100000,
"lambda_data_partitions": 1,
"lambda_module_partitions": 4,
"use_iterative_lambda_aggregation": true,
"offload_activations_to_cpu": true,
"per_sample_gradient_dtype": "torch.bfloat16",
"lambda_dtype": "torch.bfloat16"
}
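
The dump above records dtypes as strings such as "torch.bfloat16". A minimal sketch of reading the file back for inspection, using only the standard library plus torch; the `to_dtype` helper is illustrative and not part of kronfluence.

```python
import json

import torch

with open("examples/openwebtext/files/factor_arguments.json") as f:
    recorded = json.load(f)

def to_dtype(name: str) -> torch.dtype:
    # "torch.bfloat16" -> torch.bfloat16
    return getattr(torch, name.split(".", 1)[1])

print(recorded["strategy"])                            # ekfac
print(to_dtype(recorded["eigendecomposition_dtype"]))  # torch.float64
```
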
2,112 changes: 2,112 additions & 0 deletions examples/openwebtext/files/inflation.txt

Large diffs are not rendered by default.

2,272 changes: 2,272 additions & 0 deletions examples/openwebtext/files/ml.txt

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions examples/openwebtext/files/query_dataset_metadata.json
@@ -0,0 +1,5 @@
{
"type": "Dataset",
"dataset_size": 5,
"indices": null
}
19 changes: 19 additions & 0 deletions examples/openwebtext/files/score_arguments.json
@@ -0,0 +1,19 @@
{
"damping_factor": null,
"amp_dtype": "torch.bfloat16",
"offload_activations_to_cpu": true,
"data_partitions": 1,
"module_partitions": 1,
"compute_per_module_scores": false,
"compute_per_token_scores": false,
"query_gradient_accumulation_steps": 10,
"query_gradient_low_rank": 64,
"use_full_svd": true,
"aggregate_query_gradients": false,
"aggregate_train_gradients": false,
"use_measurement_for_self_influence": false,
"query_gradient_svd_dtype": "torch.float32",
"per_sample_gradient_dtype": "torch.bfloat16",
"precondition_dtype": "torch.bfloat16",
"score_dtype": "torch.bfloat16"
}
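
A back-of-the-envelope view of why `query_gradient_low_rank: 64` matters: a rank-r factorization of a query gradient for a d_out x d_in weight stores roughly r * (d_out + d_in) numbers instead of d_out * d_in. The layer shape below is an assumption (Llama-3-8B MLP down-projection), not something recorded in this PR.

```python
d_out, d_in, r = 4096, 14336, 64     # assumed Llama-3-8B down_proj shape
full = d_out * d_in                  # dense query gradient entries
factored = r * (d_out + d_in)        # rank-64 factorization entries
print(f"compression ~ {full / factored:.1f}x")  # roughly 50x fewer numbers
```
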
5 changes: 5 additions & 0 deletions examples/openwebtext/files/train_dataset_metadata.json
@@ -0,0 +1,5 @@
{
"type": "Dataset",
"dataset_size": 100000,
"indices": null
}
5 changes: 0 additions & 5 deletions examples/openwebtext/fit_factors.py
@@ -1,12 +1,9 @@
import argparse
import logging
from datetime import timedelta
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator, InitProcessGroupKwargs
from torch import nn
from transformers import default_data_collator

from examples.openwebtext.pipeline import construct_llama3, get_openwebtext_dataset
@@ -17,8 +14,6 @@
)
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Dict[str, torch.Tensor]

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

39 changes: 39 additions & 0 deletions examples/openwebtext/inpsect_factors.py
@@ -0,0 +1,39 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from tueplots import markers

from kronfluence.analyzer import Analyzer


def main():
plt.rcParams.update({"figure.dpi": 150})
plt.rcParams.update(markers.with_edge())
plt.rcParams["axes.axisbelow"] = True

layer_num = 18
module_name = f"model.layers.{layer_num}.mlp.down_proj"
# module_name = f"model.layers.{layer_num}.mlp.up_proj"
lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name]
lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name]
lambda_matrix.div_(lambda_processed)
lambda_matrix = lambda_matrix.float()
plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm())

plt.title(module_name)
plt.colorbar()
plt.show()
plt.clf()

lambda_matrix = lambda_matrix.view(-1).numpy()
sorted_lambda_matrix = np.sort(lambda_matrix)
plt.plot(sorted_lambda_matrix)
plt.title(module_name)
plt.grid()
plt.yscale("log")
plt.ylabel("Eigenvalues")
plt.show()


if __name__ == "__main__":
main()
42 changes: 42 additions & 0 deletions examples/openwebtext/inspect_scores.py
@@ -0,0 +1,42 @@
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer

from examples.openwebtext.pipeline import (
MODEL_NAME,
get_custom_dataset,
get_openwebtext_dataset,
)
from kronfluence.analyzer import Analyzer


def main():
scores = Analyzer.load_file("influence_results/scores_jul_11_2024/pairwise_scores.safetensors")["all_modules"].float()

train_dataset = get_openwebtext_dataset()
eval_dataset = get_custom_dataset()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)

eval_idx = 4
sorted_scores = torch.sort(scores[eval_idx], descending=True)
top_indices = sorted_scores.indices

plt.plot(sorted_scores.values)
plt.grid()
plt.ylabel("IF Score")
plt.show()

print("Query Sequence:")
print(
"Prompt: " + eval_dataset[eval_idx]["prompt"] + "; Completion: " + eval_dataset[eval_idx]["completion"] + "\n"
)

print("Top Influential Sequences:")
for i in range(100):
print("=" * 80)
print(f"Rank = {i}; Score = {scores[eval_idx][int(top_indices[i])].item()}")
print(tokenizer.decode(train_dataset[int(top_indices[i])]["input_ids"]))


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions examples/openwebtext/requirements.txt
@@ -1,2 +1,4 @@
transformers
datasets
matplotlib
tueplots
4 changes: 4 additions & 0 deletions kronfluence/arguments.py
@@ -151,6 +151,10 @@ def __post_init__(self) -> None:
):
raise ValueError("All data and module partitions must be positive.")

# For backward compatibility:
if not hasattr(self, "amp_scale"):
self.amp_scale = 2.0**16


@dataclass
class ScoreArguments(Arguments):
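
On the restored `amp_scale` default: 2.0**16 (65536.0) is also the default initial scale of PyTorch's `GradScaler`, so the fallback mainly keeps previously serialized `FactorArguments` loading with the same behavior. Whether kronfluence forwards the value to a `GradScaler` exactly like this is an assumption.

```python
from torch.cuda.amp import GradScaler

# Sketch: a loss scaler initialized with the same value the fallback restores.
scaler = GradScaler(init_scale=2.0**16)
```
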
7 changes: 4 additions & 3 deletions kronfluence/factor/config.py
@@ -9,6 +9,7 @@
ACTIVATION_EIGENVECTORS_NAME,
GRADIENT_EIGENVALUES_NAME,
GRADIENT_EIGENVECTORS_NAME,
HEURISTIC_DAMPING_SCALE,
LAMBDA_MATRIX_NAME,
NUM_LAMBDA_PROCESSED,
)
@@ -199,7 +200,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
damping_factor = score_args.damping_factor
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
storage[NUM_LAMBDA_PROCESSED] = None
@@ -259,7 +260,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
damping_factor = score_args.damping_factor
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
storage[NUM_LAMBDA_PROCESSED] = None
@@ -328,7 +329,7 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
damping_factor = score_args.damping_factor
if damping_factor is None:
damping_factor = 0.1 * torch.mean(lambda_matrix)
damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
lambda_matrix.add_(damping_factor)
storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
storage[NUM_LAMBDA_PROCESSED] = None
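
The three edits above replace a bare 0.1 with the named constant HEURISTIC_DAMPING_SCALE; the heuristic itself is unchanged. A standalone sketch of what it computes when no explicit `damping_factor` is supplied:

```python
import torch

HEURISTIC_DAMPING_SCALE = 0.1

lambda_matrix = torch.rand(32, 32, dtype=torch.float64)  # stand-in eigenvalue matrix
damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
damped = lambda_matrix + damping_factor                  # added before preconditioning
```
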
2 changes: 2 additions & 0 deletions kronfluence/utils/constants.py
@@ -18,6 +18,8 @@
# The total iteration step to synchronize the process when using distributed setting.
DISTRIBUTED_SYNC_INTERVAL = 1_000

HEURISTIC_DAMPING_SCALE = 0.1

# Activation covariance matrix.
ACTIVATION_COVARIANCE_MATRIX_NAME = "activation_covariance"
# Pseudo-gradient covariance matrix.
1 change: 1 addition & 0 deletions tests/test_analyzer.py
@@ -104,6 +104,7 @@ def test_default_factor_arguments() -> None:
assert factor_args.strategy == "ekfac"
assert factor_args.use_empirical_fisher is False
assert factor_args.amp_dtype is None
assert factor_args.amp_scale == 2.0**16
assert factor_args.has_shared_parameters is False

assert factor_args.covariance_max_examples == 100_000