-
Notifications
You must be signed in to change notification settings - Fork 0
/
accelerate_train.py
457 lines (379 loc) · 14.9 KB
/
accelerate_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import logging
import math
import os
from argparse import Namespace
from pathlib import Path
import json
import pandas as pd
import transformers
import evaluate
import torch
from torch.optim import AdamW
from transformers import (
VisionEncoderDecoderModel,
TrOCRProcessor,
HfArgumentParser,
get_scheduler,
default_data_collator,
set_seed,
)
from accelerate.logging import get_logger
from accelerate import Accelerator, DistributedType
from huggingface_hub import Repository, create_repo
import datasets
from torch.utils.data.dataloader import DataLoader
from arguments import TrainingArguments
from dataset import HandWrittenDataset, KTRDataset
from tqdm.auto import tqdm
from transformers.utils import check_min_version, get_full_repo_name
# Guard: this script targets transformers >= 4.31.0.dev0.
check_min_version("4.31.0.dev0")
# accelerate's logger wrapper; its calls accept `main_process_only`
# (this script uses logger.info(..., main_process_only=False)).
logger = get_logger(__name__)
def create_dataloaders(args, processor=None):
    """Build the train and eval DataLoaders for the configured dataset.

    Two dataset layouts are supported:

    * ``args.handwritten_dataset`` false: labels come from the CSV at
      ``args.csv_path``. Rows are split with sklearn's ``train_test_split``
      (``test_size=args.test_split``, ``random_state=args.seed``), each split
      gets a fresh 0..n-1 index, and both are wrapped in ``KTRDataset``.
      Batches are collated with ``default_data_collator``.
    * ``args.handwritten_dataset`` true: ``HandWrittenDataset`` performs the
      split itself (``train=True`` / ``train=False`` with ``test_split``).
      Batches use the DataLoader's default collation.

    Args:
        args: Parsed training arguments. Reads ``handwritten_dataset``,
            ``root_dir``, ``max_length``, ``test_split``,
            ``train_batch_size`` and ``valid_batch_size``. In CSV mode it
            also reads ``csv_path`` and ``seed``.
        processor: Processor handed to the datasets. When omitted, the
            module-level ``processor`` is used, so existing callers that
            relied on the global keep working.

    Returns:
        ``(train_dataloader, eval_dataloader)``; only the train loader
        shuffles.

    Raises:
        NameError: if no processor is passed and no module-level
            ``processor`` has been defined yet.
    """
    if processor is None:
        # Backward compatibility: this function used to read the global.
        processor = globals().get("processor")
        if processor is None:
            raise NameError(
                "create_dataloaders needs a processor: pass one explicitly "
                "or define the module-level `processor` first"
            )

    shared_kwargs = {
        "root_dir": args.root_dir,
        "processor": processor,
        "max_target_length": args.max_length,
    }
    if args.handwritten_dataset:
        train_dataset = HandWrittenDataset(
            train=True, test_split=args.test_split, **shared_kwargs
        )
        eval_dataset = HandWrittenDataset(
            train=False, test_split=args.test_split, **shared_kwargs
        )
        collate_fn = None  # DataLoader's default collation, as before
    else:
        from sklearn.model_selection import train_test_split

        df = pd.read_csv(args.csv_path)
        train_df, test_df = train_test_split(
            df, test_size=args.test_split, random_state=args.seed
        )
        train_dataset = KTRDataset(
            df=train_df.reset_index(drop=True), **shared_kwargs
        )
        eval_dataset = KTRDataset(
            df=test_df.reset_index(drop=True), **shared_kwargs
        )
        collate_fn = default_data_collator

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        collate_fn=collate_fn,
        shuffle=True,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.valid_batch_size,
        collate_fn=collate_fn,
    )
    return train_dataloader, eval_dataloader
def compute_tflops(elapsed_time, accelerator, args):
    """Estimate the achieved TFLOPs per process for one optimisation step.

    Adapts the throughput estimate of Equation 3 in Section 5.1 of
    https://arxiv.org/pdf/2104.04473.pdf to an encoder-decoder model by
    summing the encoder and decoder layer counts and hidden sizes.

    The module-level ``model`` supplies the config, and the module-level
    ``processor`` supplies the tokenizer vocabulary size.

    Args:
        elapsed_time: Wall-clock seconds taken by the step.
        accelerator: Provides ``unwrap_model`` and ``state.num_processes``.
        args: Reads ``gradient_checkpointing``, ``train_batch_size``,
            ``gradient_accumulation_steps`` and ``max_length``.

    Returns:
        Estimated TFLOPs as a float, divided by the number of processes.
    """
    cfg = accelerator.unwrap_model(model).config
    n_procs = accelerator.state.num_processes

    # The formula's activation-recomputation factor: 4 with gradient
    # checkpointing, 3 without.
    ckpt_factor = 3
    if args.gradient_checkpointing:
        ckpt_factor = 4

    global_batch = args.train_batch_size * n_procs * args.gradient_accumulation_steps
    seq_len = args.max_length
    n_layers = cfg.encoder.num_hidden_layers + cfg.decoder.num_hidden_layers
    # NOTE(review): summing hidden sizes before squaring is a rough
    # encoder-decoder adaptation of the decoder-only formula — confirm intent.
    hidden = cfg.encoder.hidden_size + cfg.decoder.hidden_size
    vocab = processor.tokenizer.vocab_size

    base = 24 * ckpt_factor * global_batch * seq_len * n_layers * (hidden**2)
    correction = 1.0 + seq_len / (6.0 * hidden) + vocab / (16.0 * n_layers * hidden)
    return base * correction / (elapsed_time * n_procs * (10**12))
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()
# Sanity check: a non-handwritten dataset reads its labels from a CSV file,
# so csv_path is mandatory in that mode. parser.error (usage + exit code 2)
# is used instead of `assert`, which is stripped under `python -O`.
if not args.handwritten_dataset and args.csv_path is None:
    parser.error("Please provide csv_path")
# ---------------------------------------------------------------------------
# Accelerator, logging and seeding
# ---------------------------------------------------------------------------
# The Accelerator itself must know gradient_accumulation_steps:
# `accelerator.accumulate(model)` and `accelerator.sync_gradients` in the loop
# below use the Accelerator's own setting. Without it every batch syncs and
# counts as an optimisation step, so `max_train_steps` (computed per
# `gradient_accumulation_steps` batches) would be reached too early.
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    log_with=["wandb", "tensorboard"],
    project_dir=f"{args.output_dir}/log",
)
# Fold the (stringified) accelerator state into args so it ends up in the
# tracker config.
# NOTE(review): Namespace(**a, **b) raises TypeError if a training argument
# shares a name with an AcceleratorState attribute.
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# Only the local main process emits verbose library logs.
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# ---------------------------------------------------------------------------
# Hub repository / output directory (main process only)
# ---------------------------------------------------------------------------
if accelerator.is_main_process:
    if args.push_to_hub:
        if args.model_ckpt is None:
            repo_name = get_full_repo_name(
                Path(args.output_dir).name, token=args.hub_token
            )
        else:
            # NOTE(review): this pushes to the same repo the weights are
            # loaded from (args.model_ckpt) — confirm that is intended.
            repo_name = args.model_ckpt
        create_repo(repo_name, exist_ok=True, token=args.hub_token)
        hf_repo = Repository(
            args.output_dir, clone_from=repo_name, token=args.hub_token
        )
        # Keep intermediate checkpoints out of the pushed repo. The clone may
        # already ship a .gitignore, so read it and append only the missing
        # patterns instead of truncating it.
        gitignore_path = os.path.join(args.output_dir, ".gitignore")
        try:
            with open(gitignore_path, encoding="utf-8") as gitignore:
                gitignore_text = gitignore.read()
        except FileNotFoundError:
            gitignore_text = ""
        present_patterns = {line.strip() for line in gitignore_text.splitlines()}
        missing_patterns = [
            pattern
            for pattern in ("step_*", "epoch_*")
            if pattern not in present_patterns
        ]
        if missing_patterns:
            with open(gitignore_path, "a", encoding="utf-8") as gitignore:
                if gitignore_text and not gitignore_text.endswith("\n"):
                    gitignore.write("\n")
                gitignore.writelines(f"{pattern}\n" for pattern in missing_patterns)
    elif args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

# ---------------------------------------------------------------------------
# Model, processor, data, optimizer
# ---------------------------------------------------------------------------
model = VisionEncoderDecoderModel.from_pretrained(args.model_ckpt)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
processor = TrOCRProcessor.from_pretrained(args.model_ckpt)
if accelerator.distributed_type == DistributedType.TPU:
    model.tie_weights()

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)
# Count examples on the plain DataLoader, before `prepare` wraps/shards it.
num_train_examples = len(train_dataloader.dataset)
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
model.to(device)


def get_lr():
    """Return the current learning rate of the first optimizer param group."""
    return optimizer.param_groups[0]["lr"]


model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Step bookkeeping uses the *prepared* train loader: in distributed runs it
# only yields this process's share of batches.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / args.gradient_accumulation_steps
)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# The scheduler is not passed to `accelerator.prepare`, and the loop calls
# `lr_scheduler.step()` on every micro-batch, hence the
# `* gradient_accumulation_steps` scaling. It is built only now so that these
# counts come from the per-process loader length computed above.
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
accelerator.register_for_checkpointing(lr_scheduler)

# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
    checkpointing_steps = int(checkpointing_steps)

if args.with_tracking:
    experiment_config = vars(args)
    accelerator.init_trackers("trocr-ckb", experiment_config)

cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

total_batch_size = (
    args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_train_examples}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(
    f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
progress_bar = tqdm(
    range(args.max_train_steps), disable=not accelerator.is_local_main_process
)

# ---------------------------------------------------------------------------
# Resume from checkpoint
# ---------------------------------------------------------------------------
# Train model
completed_steps = 0
starting_epoch = 0
resume_step = None
if args.resume_from_checkpoint:
    # Expects an `epoch_{i}` or `step_{i}` directory written by save_state.
    checkpoint_path = args.resume_from_checkpoint
    accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
    accelerator.load_state(checkpoint_path)
    # normpath drops a trailing separator; otherwise basename would be "".
    path = os.path.basename(os.path.normpath(checkpoint_path))
    # Extract `epoch_{i}` or `step_{i}`
    training_difference = os.path.splitext(path)[0]
    if "epoch" in training_difference:
        starting_epoch = int(training_difference.replace("epoch_", "")) + 1
        resume_step = None
        completed_steps = starting_epoch * num_update_steps_per_epoch
    else:
        # need to multiply `gradient_accumulation_steps` to reflect real steps
        resume_step = (
            int(training_difference.replace("step_", ""))
            * args.gradient_accumulation_steps
        )
        starting_epoch = resume_step // len(train_dataloader)
        # Global optimisation steps already done; computed before resume_step
        # is reduced to the offset within the starting epoch.
        completed_steps = resume_step // args.gradient_accumulation_steps
        resume_step -= starting_epoch * len(train_dataloader)
progress_bar.update(completed_steps)

# ---------------------------------------------------------------------------
# Train / evaluate
# ---------------------------------------------------------------------------
for epoch in range(starting_epoch, args.num_train_epochs):
    model.train()
    if args.with_tracking:
        total_loss = 0
        num_train_batches = 0
    if (
        args.resume_from_checkpoint
        and epoch == starting_epoch
        and resume_step is not None
    ):
        # We skip the first `n` batches in the dataloader when resuming from a checkpoint
        active_dataloader = accelerator.skip_first_batches(
            train_dataloader, resume_step
        )
    else:
        active_dataloader = train_dataloader
    for step, batch in enumerate(active_dataloader):
        with accelerator.accumulate(model):
            loss = model(**batch).loss
            if args.with_tracking:
                total_loss += loss.detach().float()
                num_train_batches += 1
            accelerator.backward(loss)
            # Clip only the fully accumulated gradient, i.e. on sync steps.
            # Clipping partial sums is wrong, and under fp16 repeated
            # unscaling before an optimizer update raises.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1
            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)
        if completed_steps >= args.max_train_steps:
            break

    model.eval()
    # The prepared model may be wrapped (e.g. DistributedDataParallel), which
    # has no `generate`; call it on the unwrapped model instead.
    generator = accelerator.unwrap_model(model)
    pad_token_id = processor.tokenizer.pad_token_id
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.inference_mode():
            loss = model(**batch).loss
            # NOTE(review): no max_length is passed, so output length follows
            # the model's generation defaults — confirm it covers the labels.
            generated_ids = generator.generate(batch["pixel_values"])
            # -100 marks ignored label positions; map them to pad for decoding
            # (out-of-place, so the batch itself is left untouched).
            label_ids = batch["labels"].masked_fill(
                batch["labels"] == -100, pad_token_id
            )
            # Lengths may differ across processes; pad so they can be gathered.
            generated_ids = accelerator.pad_across_processes(
                generated_ids, dim=1, pad_index=pad_token_id
            )
            label_ids = accelerator.pad_across_processes(
                label_ids, dim=1, pad_index=pad_token_id
            )
            generated_ids, label_ids = accelerator.gather_for_metrics(
                (generated_ids, label_ids)
            )
            losses.append(
                accelerator.gather_for_metrics(loss.repeat(args.valid_batch_size))
            )
        predictions = processor.batch_decode(generated_ids, skip_special_tokens=True)
        references = processor.batch_decode(label_ids, skip_special_tokens=True)
        cer_metric.add_batch(predictions=predictions, references=references)
        wer_metric.add_batch(predictions=predictions, references=references)

    losses = torch.cat(losses)
    eval_loss = torch.mean(losses)
    cer = cer_metric.compute()
    wer = wer_metric.compute()
    logger.info(
        f"epoch {epoch}: cer: {cer:.3f} wer: {wer:.3f} eval_loss: {eval_loss:.5f}"
    )

    if args.with_tracking:
        # Log numbers, not formatted strings, so trackers record scalars.
        # Train loss is averaged over the batches actually run this epoch
        # (fewer than len(train_dataloader) when resuming mid-epoch).
        accelerator.log(
            {
                "train_loss": float(total_loss) / max(num_train_batches, 1),
                "eval_loss": eval_loss.item(),
                "cer": cer,
                "wer": wer,
                "epoch": epoch,
                "completed_steps": completed_steps,
            },
            step=completed_steps,
        )

    if args.push_to_hub and epoch < args.num_train_epochs - 1:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            processor.save_pretrained(args.output_dir)
            hf_repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}",
                blocking=False,
                auto_lfs_prune=True,
            )

    if args.checkpointing_steps == "epoch":
        output_dir = f"epoch_{epoch}"
        if args.output_dir is not None:
            output_dir = os.path.join(args.output_dir, output_dir)
        accelerator.save_state(output_dir)
# Close the experiment trackers, then persist the final artifacts.
if args.with_tracking:
    accelerator.end_training()

if args.output_dir is not None:
    # Every process must reach this point before the weights are written.
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        args.output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        processor.save_pretrained(args.output_dir)
        if args.push_to_hub:
            hf_repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
        # Metrics from the last completed evaluation pass.
        final_results = {"wer": wer, "cer": cer}
        results_file = os.path.join(args.output_dir, "all_results.json")
        with open(results_file, "w") as f:
            json.dump(final_results, f)