# Sam Greydanus | 2024
########## IMPORTS AND A FEW GLOBAL VARIABLES ##########

import os, sys, time, getpass
from typing import Optional
from dataclasses import dataclass

import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from model import get_checkpoint, save_checkpoint, get_all_args
from sample import save_samples
from data import InfiniteDataLoader, create_datasets
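
# NOTE: evaluate() relies on the module-level `args` (defined in the __main__ block below)
# for args.device, so it should only be called after get_all_args() has run.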
@torch.inference_mode()
def evaluate(model, dataset, batch_size=15, max_batches=None):
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    for i, batch in enumerate(loader):
        batch = [t.to(args.device) for t in batch]
        X, C, Y = batch
        logits, loss = model(X, C, Y)
        losses.append(loss.item())
        if max_batches is not None and i >= max_batches:
            break
    mean_loss = torch.tensor(losses).mean().item()
    model.train()  # reset model back to training mode
    return mean_loss

########## ARGS, LOGGING, AND TRAIN LOOP ##########

if __name__ == '__main__':
    args = get_all_args()

    torch.manual_seed(args.seed)  # system inits
    torch.cuda.manual_seed_all(args.seed)

    wandb_init_args = {"project": args.wandb_project, "entity": args.wandb_entity, "config": args}
    if args.load_from_run_id:
        wandb_init_args["id"] = args.load_from_run_id
        wandb_init_args["resume"] = "must"
    else:
        wandb_init_args["name"] = args.wandb_run_name
    wandb.init(**wandb_init_args)

    train_dataset, test_dataset = create_datasets(args)  # init datasets
    args.vocab_size = train_dataset.get_vocab_size()
    args.block_size = train_dataset.get_stroke_seq_length()
    args.context_block_size = train_dataset.get_text_seq_length()
    args.context_vocab_size = train_dataset.get_char_vocab_size()
    print(f"Dataset determined that: {args.vocab_size=}, {args.block_size=}")

    model, optimizer, scheduler, step, best_loss = get_checkpoint(args, sample_only=False)
    batch_loader = InfiniteDataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=4)
    wandb.watch(model, log="all", log_freq=args.log_every, log_graph=False)  # track gradients and parameters in W&B

    ########## ARGS, LOGGING, AND TRAIN LOOP ##########

    # training loop
    while True:
        t0 = time.time()

        # get the next batch, ship to device, and unpack it to input and target
        batch = batch_loader.next()
        X, C, Y = [t.to(args.device) for t in batch]
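        # (X appears to hold the stroke-sequence inputs, C the character/text context used
        # for conditioning, and Y the prediction targets -- inferred from the dataset
        # accessors get_stroke_seq_length() and get_text_seq_length() above.)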

        # feed into the model
        logits, loss = model(X, C, Y)

        # calculate the gradient, update the weights
        model.zero_grad(set_to_none=True) ; loss.backward()
        optimizer.step() ; scheduler.step()
        wandb.log({"train_loss_step": loss.item(), "step": step})
        t1 = time.time()

        # logging
        if step % args.print_every == 0:
            print(f"step {step} | loss {loss.item():.4f} | step time {(t1-t0)*1000:.2f}ms | lr {scheduler.get_last_lr()[0]:.6f}")

        # evaluate the model
        if step > 0 and step % args.log_every == 0:
            train_loss = evaluate(model, train_dataset, batch_size=100, max_batches=10)
            test_loss = evaluate(model, test_dataset, batch_size=100, max_batches=10)
            wandb.log({"train_loss": train_loss, "test_loss": test_loss, "step": step})
            print(f"step {step} train loss: {train_loss:.4f} test loss: {test_loss:.4f}")

            if best_loss is None or test_loss < best_loss:  # save the model to W&B if it has improved
                best_loss = test_loss
                print(f"Test loss {test_loss:.4f} is the best so far, saving checkpoint to {args.local_checkpoint_path}")
                save_checkpoint(model, args.local_checkpoint_path, optimizer, scheduler, step, best_loss)
                artifact = wandb.Artifact('best_checkpoint', type='model')
                artifact.add_file(args.local_checkpoint_path)
                wandb.log_artifact(artifact)

        # sample from the model
        if step > 0 and step % args.log_every == 0:
            save_samples(model, test_dataset, num=6, do_sample=True)
            save_samples(model, test_dataset, num=6, do_sample=False)
            save_samples(model, train_dataset, num=3, do_sample=True)
            save_samples(model, train_dataset, num=3, do_sample=False)

        step += 1

        # termination conditions
        if args.max_steps >= 0 and step >= args.max_steps:
            break

    wandb.finish()