run_lrs.py
import argparse
import csv
import random

import numpy as np
import torch

from BYOD import lrs_metric
from BYOD.utils import get_model_n_tokenizer

def main(args):
    # Seed all RNGs for reproducibility.
    torch.manual_seed(args.set_seed)
    torch.cuda.manual_seed(args.set_seed)
    random.seed(args.set_seed)
    np.random.seed(args.set_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    model, tokenizer = get_model_n_tokenizer(args.model_name, args=args)

    print("Loading the wikitext dataset")
    from datasets import load_dataset

    # Use the train split because it is the largest.
    wiki = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").with_format("torch")
    # Keep only passages long enough to contain several sentences.
    wiki = wiki.filter(lambda example: len(example["text"].split()) > 100)

    lrs_mean, lrs_stderr, logits_diff = lrs_metric(
        model, wiki, tokenizer, args.num_sentences_input, args.num_sentences_swap, args.max_examples
    )
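
    # Interpretation (an assumption from the argument names, not BYOD's docs):
    # lrs_metric appears to swap num_sentences_swap of the first
    # num_sentences_input sentences in each passage and measure how much the
    # model's logits change; logits_diff holds one difference per example.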
with open("results.csv", mode="a") as file:
writer = csv.writer(file)
writer.writerow(
[args.model_name, "context", len(logits_diff), np.mean(logits_diff), np.std(logits_diff) / np.sqrt(len(logits_diff))]
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--dataset_name", type=str, default="wiki", help="wiki")
parser.add_argument("--num_sentences_input", type=int, default=3, help="Number of sentences in input")
parser.add_argument("--num_sentences_swap", type=int, default=2, help="Number of sentences in input")
parser.add_argument("--max_examples", type=int, default=1000)
parser.add_argument("--set_seed", type=int, default=42)
parser.add_argument("--fp16", default=False, type=bool)
parser.add_argument("--cache_dir_model", type=str, default="models")
parser.add_argument("--cache_dir_dataset", type=str, default="datasets")
args = parser.parse_args()
main(args)
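
A typical invocation might look like the following (assuming the BYOD package and its torch/datasets dependencies are installed; all flags are the ones defined above):

python run_lrs.py --model_name gpt2 --num_sentences_input 3 --num_sentences_swap 2 --max_examples 1000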