
Commit 7c59820

Reduce query indices
1 parent 3655ea7 commit 7c59820

7 files changed, +71 -150 lines changed

examples/uci/README.md

Lines changed: 1 addition & 2 deletions
@@ -4,7 +4,7 @@ This directory contains scripts designed for training a regression model and con
 
 ## Training
 
-To initiate the training of a regression model using the Concrete dataset, execute the following command:
+To train a regression model on the Concrete dataset, run the following command:
 ```bash
 python train.py --dataset_name concrete \
     --dataset_dir ./data \
@@ -16,7 +16,6 @@ python train.py --dataset_name concrete \
     --num_train_epochs 20 \
     --seed 1004
 ```
-Alternatively, you can download the model checkpoint.
 
 # Influence Analysis

examples/uci/analyze.py

Lines changed: 9 additions & 78 deletions
@@ -2,13 +2,12 @@
 import logging
 import math
 import os
-from typing import Dict, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
 from analyzer import Analyzer, prepare_model
 from arguments import FactorArguments, ScoreArguments
-from module.utils import wrap_tracked_modules
 from task import Task
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -96,14 +95,12 @@ def compute_measurement(
 
 def main():
     args = parse_args()
-
     logging.basicConfig(level=logging.INFO)
 
-    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
-    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
+    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir)
+    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)
 
     model = construct_regression_mlp()
-
     checkpoint_path = os.path.join(args.checkpoint_dir, "model.pth")
     if not os.path.isfile(checkpoint_path):
         raise ValueError(f"No checkpoint found at {checkpoint_path}.")
@@ -120,91 +117,25 @@ def main():
     )
     factor_args = FactorArguments(
         strategy=args.factor_strategy,
-        covariance_data_partition_size=5,
-        covariance_module_partition_size=4,
     )
-    # with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
-    #     with record_function("covariance"):
-    #         analyzer.fit_covariance_matrices(
-    #             factors_name=args.factor_strategy,
-    #             dataset=train_dataset,
-    #             factor_args=factor_args,
-    #             per_device_batch_size=args.batch_size,
-    #             overwrite_output_dir=True,
-    #         )
-    #
-    #     print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
-    # cov_factors = analyzer.fit_covariance_matrices(
-    #     factors_name=args.factor_strategy,
-    #     dataset=train_dataset,
-    #     factor_args=factor_args,
-    #     per_device_batch_size=args.batch_size,
-    #     overwrite_output_dir=True,
-    # )
-    # print(cov_factors)
-
-    with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
-        with record_function("eigen"):
-            res = analyzer.perform_eigendecomposition(
-                factors_name=args.factor_strategy,
-                factor_args=factor_args,
-                overwrite_output_dir=True,
-            )
-    # print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
-    # print(res)
-    res = analyzer.fit_lambda_matrices(
+    analyzer.fit_all_factors(
         factors_name=args.factor_strategy,
         dataset=train_dataset,
-        # factor_args=factor_args,
         per_device_batch_size=None,
+        factor_args=factor_args,
        overwrite_output_dir=True,
     )
-    # print(res)
-    #
-    score_args = ScoreArguments(data_partition_size=2, module_partition_size=2)
-    analyzer.compute_pairwise_scores(
-        scores_name="hello",
+
+    scores = analyzer.compute_pairwise_scores(
+        scores_name="pairwise",
         factors_name=args.factor_strategy,
         query_dataset=eval_dataset,
         train_dataset=train_dataset,
         per_device_query_batch_size=16,
         per_device_train_batch_size=8,
-        score_args=score_args,
         overwrite_output_dir=True,
     )
-    # scores = analyzer.load_pairwise_scores(scores_name="hello")
-    # print(scores)
-    #
-    # analyzer.compute_self_scores(
-    #     scores_name="hello",
-    #     factors_name=args.factor_strategy,
-    #     # query_dataset=eval_dataset,
-    #     train_dataset=train_dataset,
-    #     # per_device_query_batch_size=16,
-    #     per_device_train_batch_size=8,
-    #     overwrite_output_dir=True,
-    # )
-    # # scores = analyzer.load_self_scores(scores_name="hello")
-    # # print(scores)
-
-    # analyzer.fit_all_factors(
-    #     factor_name=args.factor_strategy,
-    #     dataset=train_dataset,
-    #     factor_args=factor_args,
-    #     per_device_batch_size=None,
-    #     overwrite_output_dir=True,
-    # )
-    #
-    # score_name = "full_pairwise"
-    # analyzer.compute_pairwise_scores(
-    #     score_name=score_name,
-    #     query_dataset=eval_dataset,
-    #     per_device_query_batch_size=len(eval_dataset),
-    #     train_dataset=train_dataset,
-    #     per_device_train_batch_size=len(train_dataset),
-    # )
-    # scores = analyzer.load_pairwise_scores(score_name=score_name)
-    # print(scores.shape)
+    logging.info(f"Scores: {scores}")
 
 
 if __name__ == "__main__":
examples/uci/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -38,12 +38,12 @@ def get_regression_dataset(
     data_name: str,
     split: str,
     indices: List[int] = None,
-    data_path: str = "data/",
+    dataset_dir: str = "data/",
 ) -> Dataset:
     assert split in ["train", "eval_train", "valid"]
 
     # Load the dataset from the `.data` file.
-    data = np.loadtxt(os.path.join(data_path, data_name + ".data"), delimiter=None)
+    data = np.loadtxt(os.path.join(dataset_dir, data_name + ".data"), delimiter=None)
     data = data.astype(np.float32)
 
     # Shuffle the dataset.

examples/uci/train.py

Lines changed: 42 additions & 59 deletions
@@ -1,11 +1,12 @@
 import argparse
 import logging
 import os
-from torch.utils import data
+
 import torch
 import torch.nn.functional as F
-from torch import nn
 from accelerate.utils import set_seed
+from torch import nn
+from torch.utils import data
 from tqdm import tqdm
 
 from examples.uci.pipeline import construct_regression_mlp, get_regression_dataset
@@ -82,7 +83,13 @@ def parse_args():
     return args
 
 
-def train(dataset: data.Dataset, batch_size: int, num_train_epochs: int, learning_rate: float, weight_decay: float) -> nn.Module:
+def train(
+    dataset: data.Dataset,
+    batch_size: int,
+    num_train_epochs: int,
+    learning_rate: float,
+    weight_decay: float,
+) -> nn.Module:
     train_dataloader = data.DataLoader(
         dataset=dataset,
         batch_size=batch_size,
@@ -110,6 +117,25 @@ def train(dataset: data.Dataset, batch_size: int, num_train_epochs: int, learning_rate: float, weight_decay: float) -> nn.Module:
     return model
 
 
+def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> float:
+    dataloader = data.DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        drop_last=False,
+    )
+
+    model.eval()
+    total_loss = 0
+    for batch in dataloader:
+        with torch.no_grad():
+            inputs, targets = batch
+            outputs = model(inputs)
+            loss = F.mse_loss(outputs, targets, reduction="sum")
+            total_loss += loss.detach().float()
+
+    return total_loss.item() / len(dataloader.dataset)
+
 
 def main():
     args = parse_args()
@@ -120,68 +146,25 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
-    train_dataloader = data.DataLoader(
+    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir)
+
+    model = train(
         dataset=train_dataset,
         batch_size=args.train_batch_size,
-        shuffle=True,
-        drop_last=True,
+        num_train_epochs=args.num_train_epochs,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
     )
-    model = construct_regression_mlp()
-    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
 
-    logger.info("Start training the model.")
-    model.train()
-    for epoch in range(args.num_train_epochs):
-        total_loss = 0
-        with tqdm(train_dataloader, unit="batch") as tepoch:
-            for batch in tepoch:
-                tepoch.set_description(f"Epoch {epoch}")
-                inputs, targets = batch
-                outputs = model(inputs)
-                loss = F.mse_loss(outputs, targets)
-                total_loss += loss.detach().float()
-                loss.backward()
-                optimizer.step()
-                optimizer.zero_grad()
-                tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
-
-    logger.info("Start evaluating the model.")
-    model.eval()
-    train_eval_dataset = get_regression_dataset(
-        data_name=args.dataset_name, split="eval_train", data_path=args.dataset_dir
+    eval_train_dataset = get_regression_dataset(
+        data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir
     )
-    train_eval_dataloader = DataLoader(
-        dataset=train_eval_dataset,
-        batch_size=args.eval_batch_size,
-        shuffle=False,
-        drop_last=False,
-    )
-    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
-    eval_dataloader = DataLoader(
-        dataset=eval_dataset,
-        batch_size=args.eval_batch_size,
-        shuffle=False,
-        drop_last=False,
-    )
-
-    total_loss = 0
-    for batch in train_eval_dataloader:
-        with torch.no_grad():
-            inputs, targets = batch
-            outputs = model(inputs)
-            loss = F.mse_loss(outputs, targets, reduction="sum")
-            total_loss += loss.detach().float()
-    logger.info(f"Train loss {total_loss.item() / len(train_eval_dataloader.dataset)}")
+    train_loss = evaluate(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
+    logger.info(f"Train loss: {train_loss}")
 
-    total_loss = 0
-    for batch in eval_dataloader:
-        with torch.no_grad():
-            inputs, targets = batch
-            outputs = model(inputs)
-            loss = F.mse_loss(outputs, targets, reduction="sum")
-            total_loss += loss.detach().float()
-    logger.info(f"Evaluation loss {total_loss.item() / len(eval_dataloader.dataset)}")
+    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)
+    eval_loss = evaluate(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
+    logger.info(f"Evaluation loss: {eval_loss}")
 
     if args.checkpoint_dir is not None:
         torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
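
Aside (not part of the commit): the new evaluate() helper sums per-batch losses with reduction="sum" and divides by the dataset size, so the reported loss does not depend on the evaluation batch size. A small self-contained check of that accounting, using a toy dataset and model chosen purely for illustration:

import torch
import torch.nn.functional as F
from torch.utils import data

inputs = torch.randn(10, 3)
targets = torch.randn(10, 1)
dataset = data.TensorDataset(inputs, targets)
model = torch.nn.Linear(3, 1)

def mean_loss(batch_size: int) -> float:
    # Mirror evaluate(): sum per-batch losses, then divide by the dataset size.
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    total = 0.0
    with torch.no_grad():
        for x, y in loader:
            total += F.mse_loss(model(x), y, reduction="sum").item()
    return total / len(dataset)

# The reported loss is identical whether evaluation uses batches of 3 or one batch of 10.
assert abs(mean_loss(batch_size=3) - mean_loss(batch_size=10)) < 1e-5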

kronfluence/analyzer.py

Lines changed: 12 additions & 1 deletion
@@ -1,6 +1,9 @@
 from typing import Optional
 
 from accelerate.utils import extract_model_from_parallel
+from kronfluence.module.constants import FACTOR_TYPE
+
+from factor.config import FactorConfig
 from safetensors.torch import save_file
 from torch import nn
 from torch.utils import data
@@ -119,7 +122,7 @@ def fit_all_factors(
         dataloader_kwargs: Optional[DataLoaderKwargs] = None,
         factor_args: Optional[FactorArguments] = None,
         overwrite_output_dir: bool = False,
-    ) -> None:
+    ) -> Optional[FACTOR_TYPE]:
         """Computes all necessary factors for the given factor strategy. As an example, EK-FAC
         requires (1) computing covariance matrices, (2) performing Eigendecomposition, and
        (3) computing Lambda (corrected-eigenvalues) matrices.
@@ -161,3 +164,11 @@ def fit_all_factors(
             factor_args=factor_args,
             overwrite_output_dir=overwrite_output_dir,
         )
+
+        if factor_args is None:
+            factor_args = FactorArguments()
+        strategy = factor_args.strategy
+        factor_config = FactorConfig.CONFIGS[strategy]
+        return self._load_all_required_factors(
+            factors_name=factors_name, strategy=strategy, factor_config=factor_config
+        )
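
For context (not part of the commit): the docstring's three EK-FAC steps correspond to the per-step Analyzer methods that appear in the removed examples/uci/analyze.py code above. A rough sketch of the step-by-step sequence that fit_all_factors wraps, with keyword arguments copied from that removed code; treat the exact signatures as assumptions rather than the library's documented API:

def fit_ekfac_factors_stepwise(analyzer, train_dataset, factors_name="ekfac", factor_args=None):
    # (1) Covariance matrices over the training data.
    analyzer.fit_covariance_matrices(
        factors_name=factors_name,
        dataset=train_dataset,
        factor_args=factor_args,
        per_device_batch_size=None,
        overwrite_output_dir=True,
    )
    # (2) Eigendecomposition of the covariance factors.
    analyzer.perform_eigendecomposition(
        factors_name=factors_name,
        factor_args=factor_args,
        overwrite_output_dir=True,
    )
    # (3) Lambda (corrected-eigenvalues) matrices.
    analyzer.fit_lambda_matrices(
        factors_name=factors_name,
        dataset=train_dataset,
        per_device_batch_size=None,
        overwrite_output_dir=True,
    )

With this commit, a single fit_all_factors call performs the same sequence and now also returns the loaded factors instead of None.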

tests/gpu_tests/cpu_test.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
     construct_mnist_mlp,
     get_mnist_dataset,
 )
-from tests.gpu_tests.prepare_tests import TRAIN_INDICES, QUERY_INDICES
+from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES
 from tests.utils import check_tensor_dict_equivalence
 
 logging.basicConfig(level=logging.DEBUG)

tests/gpu_tests/prepare_tests.py

Lines changed: 4 additions & 7 deletions
@@ -11,10 +11,9 @@
     get_mnist_dataset,
 )
 
-
 # Pick difficult cases where the dataset is not perfectly divisible by batch size.
-TRAIN_INDICES = 59_999
-QUERY_INDICES = 50
+TRAIN_INDICES = 5_003
+QUERY_INDICES = 51
 
 
 def train() -> None:
@@ -82,6 +81,8 @@ def run_analysis() -> None:
 
     train_dataset = get_mnist_dataset(split="train", data_path="data")
     eval_dataset = get_mnist_dataset(split="valid", data_path="data")
+    train_dataset = Subset(train_dataset, indices=list(range(TRAIN_INDICES)))
+    eval_dataset = Subset(eval_dataset, indices=list(range(QUERY_INDICES)))
 
     task = ClassificationTask()
     model = model.double()
@@ -99,7 +100,6 @@ def run_analysis() -> None:
         gradient_covariance_dtype=torch.float64,
         lambda_dtype=torch.float64,
         lambda_iterative_aggregate=False,
-        lambda_max_examples=1_000
     )
     analyzer.fit_all_factors(
         factors_name="single_gpu",
@@ -119,8 +119,6 @@ def run_analysis() -> None:
         factors_name="single_gpu",
         query_dataset=eval_dataset,
         train_dataset=train_dataset,
-        train_indices=list(range(TRAIN_INDICES)),
-        query_indices=list(range(QUERY_INDICES)),
         per_device_query_batch_size=12,
         per_device_train_batch_size=512,
         score_args=score_args,
@@ -130,7 +128,6 @@ def run_analysis() -> None:
         scores_name="single_gpu",
         factors_name="single_gpu",
         train_dataset=train_dataset,
-        train_indices=list(range(TRAIN_INDICES)),
         per_device_train_batch_size=512,
         score_args=score_args,
         overwrite_output_dir=True,
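
Side note (not part of the commit): instead of passing train_indices / query_indices at score time, the test now trims the datasets up front with torch.utils.data.Subset. A tiny, self-contained illustration of that pattern; the toy dataset below is an assumption used purely for demonstration:

import torch
from torch.utils.data import Subset, TensorDataset

full_dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
QUERY_INDICES = 3

small_dataset = Subset(full_dataset, indices=list(range(QUERY_INDICES)))

assert len(small_dataset) == QUERY_INDICES
# Items are read through to the underlying dataset unchanged.
assert torch.equal(small_dataset[0][0], full_dataset[0][0])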

0 commit comments
