Showing 8 changed files with 378 additions and 0 deletions.
@@ -0,0 +1,11 @@
import torch.nn as nn
import math


class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
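The lookup result is multiplied by sqrt(d_model), which keeps the embedding magnitudes comparable to the positional encodings added later in the model. A minimal usage sketch, assuming the class above lives in a module named Embeddings (toy sizes chosen purely for illustration):

import torch
from Embeddings import Embeddings  # assumed module name, matching the class

emb = Embeddings(d_model=512, vocab=11)   # toy vocabulary of 11 tokens
tokens = torch.randint(0, 11, (2, 10))    # batch of 2 sequences, length 10
out = emb(tokens)
print(out.shape)                          # torch.Size([2, 10, 512])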
@@ -0,0 +1,34 @@
import copy

import torch.nn as nn

# These classes are assumed to live in modules of the same name, mirroring the
# `from encoder import ...` / `from decoder import ...` pattern below; plain
# `import MultiHeadedAttention` would bind the module, not the class.
from MultiHeadedAttention import MultiHeadedAttention
from PositionwiseFeedForward import PositionwiseFeedForward
from PositionalEncoding import PositionalEncoding
from EncoderDecoder import EncoderDecoder
from Generator import Generator
from Embeddings import Embeddings
from encoder import Encoder, EncoderLayer
from decoder import Decoder, DecoderLayer


def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab),
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
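Because the attention, feed-forward, and positional-encoding blocks are deep-copied, each of the N encoder and decoder layers gets its own weights rather than sharing one module. A quick smoke test of make_model, assuming the companion modules above are importable (the small hyperparameters here are illustrative, not the Multi30k settings):

model = make_model(src_vocab=1000, tgt_vocab=1000, N=2, d_model=64, d_ff=256, h=4)
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params:,}")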
File renamed without changes.
@@ -0,0 +1,127 @@
import torch
import torchtext.datasets as datasets
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchtext.data.functional import to_map_style_dataset

from DataGen import tokenize


def collate_batch(
    batch,
    src_pipeline,
    tgt_pipeline,
    src_vocab,
    tgt_vocab,
    device,
    max_padding=128,
    pad_id=2,
):
    bs_id = torch.tensor([0], device=device)  # <s> token id
    eos_id = torch.tensor([1], device=device)  # </s> token id
    src_list, tgt_list = [], []
    for (_src, _tgt) in batch:
        processed_src = torch.cat(
            [
                bs_id,
                torch.tensor(
                    src_vocab(src_pipeline(_src)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        processed_tgt = torch.cat(
            [
                bs_id,
                torch.tensor(
                    tgt_vocab(tgt_pipeline(_tgt)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        src_list.append(
            # warning - pad truncates when max_padding - len(processed_src) is negative
            pad(
                processed_src,
                (
                    0,
                    max_padding - len(processed_src),
                ),
                value=pad_id,
            )
        )
        tgt_list.append(
            pad(
                processed_tgt,
                (0, max_padding - len(processed_tgt)),
                value=pad_id,
            )
        )

    src = torch.stack(src_list)
    tgt = torch.stack(tgt_list)
    return (src, tgt)


def create_dataloaders(
    device,
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    batch_size=12000,
    max_padding=128,
    is_distributed=True,
):
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_src,
            vocab_tgt,
            device,
            max_padding=max_padding,
            pad_id=vocab_src.get_stoi()["<blank>"],
        )

    train_iter, valid_iter, test_iter = datasets.Multi30k(
        language_pair=("de", "en")
    )

    train_iter_map = to_map_style_dataset(
        train_iter
    )  # DistributedSampler needs a dataset len()
    train_sampler = (
        DistributedSampler(train_iter_map) if is_distributed else None
    )
    valid_iter_map = to_map_style_dataset(valid_iter)
    valid_sampler = (
        DistributedSampler(valid_iter_map) if is_distributed else None
    )

    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        collate_fn=collate_fn,
    )
    return train_dataloader, valid_dataloader
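collate_batch only needs callables for the token pipelines and vocabularies, so its bracketing-and-padding behaviour can be checked on CPU without spaCy or the Multi30k download. A small sketch with hypothetical stand-in tokenizer and vocabulary functions (assuming the file above is importable as Iterators):

import torch
from Iterators import collate_batch  # assumed module name

def toy_tokenize(s):            # stand-in for the spaCy-based pipeline
    return s.split()

def toy_vocab(tokens):          # stand-in for a torchtext Vocab lookup
    return [10 + len(t) for t in tokens]

batch = [("ein kleines haus", "a small house"), ("hallo welt", "hello world")]
src, tgt = collate_batch(
    batch, toy_tokenize, toy_tokenize, toy_vocab, toy_vocab,
    device=torch.device("cpu"), max_padding=8, pad_id=2,
)
print(src.shape, tgt.shape)     # torch.Size([2, 8]) torch.Size([2, 8])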
@@ -0,0 +1,158 @@
import os
from os.path import exists

import GPUtil
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR

import DataGen
import Iterators
from model import ModelGen
from trainning import LabelSmoothing, Train, SimpleLossCompute, Batch
from Utils import DummyOptimizer, DummyScheduler


def train_worker(
    gpu,
    ngpus_per_node,
    vocab_src,
    vocab_tgt,
    spacy_de,
    spacy_en,
    config,
    is_distributed=False,
):
    print(f"Train worker process using GPU: {gpu} for training", flush=True)
    torch.cuda.set_device(gpu)

    pad_idx = vocab_tgt["<blank>"]
    d_model = 512
    model = ModelGen.make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.cuda(gpu)
    module = model
    is_main_process = True
    if is_distributed:
        dist.init_process_group(
            "nccl", init_method="env://", rank=gpu, world_size=ngpus_per_node
        )
        model = DDP(model, device_ids=[gpu])
        module = model.module
        is_main_process = gpu == 0

    criterion = LabelSmoothing(
        size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1
    )
    criterion.cuda(gpu)

    train_dataloader, valid_dataloader = Iterators.create_dataloaders(
        gpu,
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=config["batch_size"] // ngpus_per_node,
        max_padding=config["max_padding"],
        is_distributed=is_distributed,
    )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: Train.rate(
            step, d_model, factor=1, warmup=config["warmup"]
        ),
    )
    train_state = Train.TrainState()

    for epoch in range(config["num_epochs"]):
        if is_distributed:
            train_dataloader.sampler.set_epoch(epoch)
            valid_dataloader.sampler.set_epoch(epoch)

        model.train()
        print(f"[GPU{gpu}] Epoch {epoch} Training ====", flush=True)
        _, train_state = Train.run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in train_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
        )

        GPUtil.showUtilization()
        if is_main_process:
            file_path = "%s%.2d.pt" % (config["file_prefix"], epoch)
            torch.save(module.state_dict(), file_path)
        torch.cuda.empty_cache()

        print(f"[GPU{gpu}] Epoch {epoch} Validation ====", flush=True)
        model.eval()
        sloss = Train.run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )
        print(sloss)
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = "%sfinal.pt" % config["file_prefix"]
        torch.save(module.state_dict(), file_path)


def train_distributed_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config):
    ngpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    print(f"Number of GPUs detected: {ngpus}")
    print("Spawning training processes ...")
    mp.spawn(
        train_worker,
        nprocs=ngpus,
        args=(ngpus, vocab_src, vocab_tgt, spacy_de, spacy_en, config, True),
    )


def train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config):
    if config["distributed"]:
        train_distributed_model(
            vocab_src, vocab_tgt, spacy_de, spacy_en, config
        )
    else:
        train_worker(
            0, 1, vocab_src, vocab_tgt, spacy_de, spacy_en, config, False
        )


def load_trained_model():
    config = {
        "batch_size": 32,
        "distributed": False,
        "num_epochs": 8,
        "accum_iter": 10,
        "base_lr": 1.0,
        "max_padding": 72,
        "warmup": 3000,
        "file_prefix": "multi30k_model_",
    }
    model_path = "multi30k_model_final.pt"
    # Load tokenizers and vocabularies unconditionally: they are needed to size
    # the model below even when a trained checkpoint already exists on disk.
    spacy_de, spacy_en = DataGen.load_tokenizers()
    vocab_src, vocab_tgt = DataGen.load_vocab(spacy_de, spacy_en)
    if not exists(model_path):
        train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config)

    model = ModelGen.make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.load_state_dict(torch.load("multi30k_model_final.pt"))
    return model


model = load_trained_model()
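Train.rate, used by the LambdaLR scheduler above, is not among the files shown in this commit. For orientation, the schedule paired with these Adam settings is typically the inverse-square-root ("Noam") schedule with linear warmup from the Transformer paper; a sketch of what such a function could look like, under that assumption:

def rate(step, model_size, factor, warmup):
    # Inverse-square-root schedule with linear warmup; step 0 is bumped to 1
    # so LambdaLR's zeroth call does not yield a zero learning rate.
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )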
@@ -0,0 +1,16 @@
class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = self.generator(x)
        sloss = (
            self.criterion(
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            )
            / norm
        )
        return sloss.data * norm, sloss
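The callable takes the decoder's hidden states x, the gold target ids y, and a normalizer norm (in training this is typically the number of non-padding target tokens in the batch); it returns the un-normalized total loss for logging and the normalized loss node for backpropagation. A toy call with hypothetical stand-ins for the generator and criterion, assuming it runs in the same module as the class above (shape illustration only, not the LabelSmoothing setup used in training):

import torch
import torch.nn as nn

# Stand-ins: a linear projection + log-softmax as the generator, NLLLoss as the criterion.
generator = nn.Sequential(nn.Linear(16, 100), nn.LogSoftmax(dim=-1))
criterion = nn.NLLLoss(reduction="sum")

loss_compute = SimpleLossCompute(generator, criterion)
x = torch.randn(2, 7, 16)            # (batch, tgt_len, d_model)
y = torch.randint(0, 100, (2, 7))    # gold target token ids
ntokens = y.numel()                  # pretend every position is a real token
total_loss, loss_node = loss_compute(x, y, ntokens)
loss_node.backward()                 # an optimizer step would follow in training
print(total_loss.item())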