Commit
well distributed data
fightnyy committed Apr 16, 2021
1 parent 7f930e9 commit 2596dbe
Showing 9 changed files with 358 additions and 197 deletions.
184 changes: 112 additions & 72 deletions .idea/workspace.xml

Large diffs are not rendered by default.

62 changes: 37 additions & 25 deletions main/bart_tokenizer.py
@@ -10,10 +10,18 @@
"hyunwoongko/asian-bart-ja",
]

class AsianBartTokenizer(XLMRobertaTokenizer):

def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
class AsianBartTokenizer(XLMRobertaTokenizer):
def __init__(
self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs
):
super().__init__(
*args,
tokenizer_file=tokenizer_file,
src_lang=src_lang,
tgt_lang=tgt_lang,
**kwargs
)
self.vocab_files_names = {"vocab_file": self.vocab_file}
self.max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
self.sp_model_size = len(self.sp_model)
@@ -48,9 +56,9 @@ def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **k
}

self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.fairseq_tokens_to_ids["<mask>"] = (len(self.sp_model) +
len(self.lang_code_to_id) +
self.fairseq_offset)
self.fairseq_tokens_to_ids["<mask>"] = (
len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
)

self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {
@@ -78,7 +86,7 @@ def prepare_seq2seq_batch(
src_langs: List[str],
tgt_texts: List[str] = None,
tgt_langs: List[str] = None,
padding = "max_length",
padding="max_length",
max_len: int = 256,
) -> Dict:

@@ -158,16 +166,17 @@ def add_language_tokens(self, tokens, langs):

special_tokens = torch.tensor([eos, lang], requires_grad=False)
input_id = torch.cat(
[input_id[:idx_to_add], special_tokens,
input_id[idx_to_add:]]).long()

additional_attention_mask = torch.tensor([1, 1],
requires_grad=False)
atn_mask = torch.cat([
atn_mask[:idx_to_add],
additional_attention_mask,
atn_mask[idx_to_add:],
]).long()
[input_id[:idx_to_add], special_tokens, input_id[idx_to_add:]]
).long()

additional_attention_mask = torch.tensor([1, 1], requires_grad=False)
atn_mask = torch.cat(
[
atn_mask[:idx_to_add],
additional_attention_mask,
atn_mask[idx_to_add:],
]
).long()

token_added_ids.append(input_id.unsqueeze(0))
token_added_masks.append(atn_mask.unsqueeze(0))
@@ -177,9 +186,8 @@ def add_language_tokens(self, tokens, langs):
return tokens

def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:

if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
@@ -201,13 +209,17 @@ def get_special_tokens_mask(
)
return list(
map(
lambda x: 1
if x in [self.sep_token_id, self.cls_token_id] else 0,
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
token_ids_0,
))
)
)
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return (prefix_ones + ([0] * len(token_ids_0)) +
([0] * len(token_ids_1)) + suffix_ones)
return (
prefix_ones
+ ([0] * len(token_ids_0))
+ ([0] * len(token_ids_1))
+ suffix_ones
)
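
For reference, a minimal usage sketch of the tokenizer API touched above; it is not part of the commit. The first positional argument of prepare_seq2seq_batch is cut off in the hunk header, so src_texts, the language codes, and the keys of the returned dict are assumptions based on the mBART conventions this class follows.

from bart_tokenizer import AsianBartTokenizer  # import path assumed (the file lives in main/)

tokenizer = AsianBartTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk")

# src_texts and the language codes below are assumptions, not taken from the diff.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["딥러닝은 재미있다."],
    src_langs=["ko_KR"],
    tgt_texts=["Deep learning is fun."],
    tgt_langs=["en_XX"],
    padding="max_length",
    max_len=256,
)
print(batch.keys())  # expected to include input_ids, attention_mask and labels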
51 changes: 51 additions & 0 deletions main/ex.py
@@ -0,0 +1,51 @@
import torch
import numpy as np
from torch.utils.data import WeightedRandomSampler, DataLoader

if __name__ == "__main__":
numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack(
(
np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
np.ones(int(numDataPoints * 0.1), dtype=np.int32),
)
)

print(
"target train 0/1: {}/{}".format(
len(np.where(target == 0)[0]), len(np.where(target == 1)[0])
)
)

class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)]
)
weight = 1.0 / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
    samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)
import pdb

pdb.set_trace()
train_loader = DataLoader(
train_dataset, batch_size=bs, num_workers=1, sampler=sampler
)

for i, (data, target) in enumerate(train_loader):
print(
"batch index {}, 0/1: {}/{}".format(
i,
len(np.where(target.numpy() == 0)[0]),
len(np.where(target.numpy() == 1)[0]),
)
)
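
The arithmetic behind the sampler in ex.py, as a standalone check (the numbers follow directly from the 9-to-1 split above):

import numpy as np

target = np.array([0] * 900 + [1] * 100)   # same 9:1 imbalance as ex.py
class_sample_count = np.array([900, 100])
weight = 1.0 / class_sample_count          # ~[0.00111, 0.01]
samples_weight = weight[target]            # one weight per data point

# A class-1 sample carries 9x the weight of a class-0 sample, so
# WeightedRandomSampler draws the two classes roughly 50/50 per batch.
print(weight, samples_weight[0], samples_weight[-1])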
32 changes: 20 additions & 12 deletions main/get_bleu.py
@@ -1,8 +1,8 @@
'''
"""
python get_bleu_score.py --hyp hyp.txt \
--ref ref.txt \
--lang en
'''
"""
import argparse
from pororo import Pororo
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
@@ -16,21 +16,24 @@ def get_hypotheses(lines, tokenizer):
hypotheses.append(tokens)
return hypotheses


def get_list_of_references(lines, tokenizer):
list_of_references = []
for line in lines:
tokens = tokenizer(line)
list_of_references.append([tokens])
return list_of_references


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--hyp', type=str, required=True,
help="system output file path")
parser.add_argument('--ref', type=str, required=True,
help="reference file path")
parser.add_argument('--lang', type=str, required=True, default="en",
help="en | ko | ja | zh")
parser.add_argument(
"--hyp", type=str, required=True, help="system output file path"
)
parser.add_argument("--ref", type=str, required=True, help="reference file path")
parser.add_argument(
"--lang", type=str, required=True, default="en", help="en | ko | ja | zh"
)
args = parser.parse_args()

hyp = args.hyp
@@ -42,12 +45,17 @@ def get_list_of_references(lines, tokenizer):
else:
tokenizer = Pororo(task="tokenization", lang="en")

hyp_lines = open(hyp, 'r', encoding="utf8").read().strip().splitlines()[:1000]
ref_lines = open(ref, 'r', encoding="utf8").read().strip().splitlines()[:1000]
hyp_lines = open(hyp, "r", encoding="utf8").read().strip().splitlines()[:1000]
ref_lines = open(ref, "r", encoding="utf8").read().strip().splitlines()[:1000]

assert len(hyp_lines) == len(ref_lines)

list_of_hypotheses = get_hypotheses(hyp_lines, tokenizer)
list_of_references = get_list_of_references(ref_lines, tokenizer)
score = corpus_bleu(list_of_references, list_of_hypotheses, auto_reweigh=True, smoothing_function=SmoothingFunction().method3)
print("BLEU SCORE = %.2f" % score)
score = corpus_bleu(
list_of_references,
list_of_hypotheses,
auto_reweigh=True,
smoothing_function=SmoothingFunction().method3,
)
print("BLEU SCORE = %.2f" % score)
15 changes: 9 additions & 6 deletions main/main.py
@@ -9,12 +9,13 @@


class DistillBart(pl.LightningModule):

def __init__(self, n_encoder: int, n_decoder: int):
super().__init__()
self.batch_size = 16
self.lr = 3e-5
self.tokenizer = AsianBartTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk")
self.tokenizer = AsianBartTokenizer.from_pretrained(
"hyunwoongko/asian-bart-ecjk"
)
self.model = start(n_encoder, n_decoder)

def forward(self, batch):
@@ -30,8 +31,11 @@ def forward(self, batch):
for key, v in model_inputs.items():
model_inputs[key] = model_inputs[key].to("cuda")

out = self.model(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'],
labels=model_inputs['labels'])
out = self.model(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
labels=model_inputs["labels"],
)
return out

def training_step(self, batch, batch_idx):
@@ -56,10 +60,9 @@ def validation_step(self, batch, batch_idx) -> Dict:
"""
out = self.forward(batch)
loss = out["loss"]
self.log('val_loss', loss, on_step=True, prog_bar=True, logger=True)
self.log("val_loss", loss, on_step=True, prog_bar=True, logger=True)
return loss

def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.lr)
return {"optimizer": optimizer}

26 changes: 13 additions & 13 deletions main/make_config.py
@@ -16,19 +16,20 @@
"""

teacher_model = AsianBartForConditionalGeneration.from_pretrained(
"hyunwoongko/asian-bart-ecjk"
)
"hyunwoongko/asian-bart-ecjk"
)
decoder_layer_3 = [
"decoder.layers.0",
"decoder.layers.1",
"decoder.layers.2",
"decoder.layers.4",
"decoder.layers.5",
"decoder.layers.6",
"decoder.layers.8",
"decoder.layers.9",
"decoder.layers.10",
]
"decoder.layers.0",
"decoder.layers.1",
"decoder.layers.2",
"decoder.layers.4",
"decoder.layers.5",
"decoder.layers.6",
"decoder.layers.8",
"decoder.layers.9",
"decoder.layers.10",
]


def start(num_encoder: int, num_decoder: int) -> nn.Module:
distill_12_3_config = make_config(num_decoder, num_encoder)
@@ -74,4 +75,3 @@ def check(k: List[str], decoder_layer_3: List[str]):
if except_layer in k:
return True
return False
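
As a quick illustration of the filter above — under the reading that check() substring-matches a parameter name against the decoder_layer_3 prefixes, which is what the visible body suggests (the full function is collapsed in this diff) — the parameter names below are hypothetical:

for name in [
    "model.decoder.layers.0.fc1.weight",  # listed prefix  -> excluded
    "model.decoder.layers.3.fc1.weight",  # not listed     -> kept (layer 3 survives in the student)
    "model.encoder.layers.5.fc1.weight",  # encoder layer  -> kept
]:
    print(name, "skip" if check(name, decoder_layer_3) else "keep")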

