From 2596dbead344f4a502c274a3547387fc3b97192a Mon Sep 17 00:00:00 2001
From: fightnyy
Date: Sat, 17 Apr 2021 01:45:05 +0900
Subject: [PATCH] well-distributed data

---
 .idea/workspace.xml    | 184 +++++++++++++++++++++++++----------------
 main/bart_tokenizer.py |  62 ++++++++------
 main/ex.py             |  51 ++++++++++++
 main/get_bleu.py       |  32 ++++---
 main/main.py           |  15 ++--
 main/make_config.py    |  26 +++---
 main/preprocessing.py  |  96 ++++++++++++++-------
 main/run.py            |  10 ++-
 main/test.py           |  79 +++++++++---------
 9 files changed, 358 insertions(+), 197 deletions(-)
 create mode 100644 main/ex.py

diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 203c342..ce4cdcf 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
[Hunks omitted: PyCharm workspace-state XML whose markup was lost during extraction; only stray fragments survive (recent-search entries "print", "ja tok", "zh_CN").]
diff --git a/main/bart_tokenizer.py b/main/bart_tokenizer.py
index b6b7e85..7c9b736 100644
--- a/main/bart_tokenizer.py
+++ b/main/bart_tokenizer.py
@@ -10,10 +10,18 @@
     "hyunwoongko/asian-bart-ja",
 ]
 
 
-class AsianBartTokenizer(XLMRobertaTokenizer):
-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+class AsianBartTokenizer(XLMRobertaTokenizer):
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            **kwargs
+        )
         self.vocab_files_names = {"vocab_file": self.vocab_file}
         self.max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
         self.sp_model_size = len(self.sp_model)
@@ -48,9 +56,9 @@ def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **k
         }
         self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
 
-        self.fairseq_tokens_to_ids["<mask>"] = (len(self.sp_model) +
-                                                len(self.lang_code_to_id) +
-                                                self.fairseq_offset)
+        self.fairseq_tokens_to_ids["<mask>"] = (
+            len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+        )
 
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {
@@ -78,7 +86,7 @@ def prepare_seq2seq_batch(
         src_langs: List[str],
         tgt_texts: List[str] = None,
         tgt_langs: List[str] = None,
-        padding = "max_length",
+        padding="max_length",
         max_len: int = 256,
     ) -> Dict:
 
@@ -158,16 +166,17 @@ def add_language_tokens(self, tokens, langs):
             special_tokens = torch.tensor([eos, lang], requires_grad=False)
 
             input_id = torch.cat(
-                [input_id[:idx_to_add], special_tokens,
-                 input_id[idx_to_add:]]).long()
-
-            additional_attention_mask = torch.tensor([1, 1],
-                                                     requires_grad=False)
-            atn_mask = torch.cat([
-                atn_mask[:idx_to_add],
-                additional_attention_mask,
-                atn_mask[idx_to_add:],
-            ]).long()
+                [input_id[:idx_to_add], special_tokens, input_id[idx_to_add:]]
+            ).long()
+
+            additional_attention_mask = torch.tensor([1, 1], requires_grad=False)
+            atn_mask = torch.cat(
+                [
+                    atn_mask[:idx_to_add],
+                    additional_attention_mask,
atn_mask[idx_to_add:], + ] + ).long() token_added_ids.append(input_id.unsqueeze(0)) token_added_masks.append(atn_mask.unsqueeze(0)) @@ -177,9 +186,8 @@ def add_language_tokens(self, tokens, langs): return tokens def build_inputs_with_special_tokens( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: if token_ids_1 is None: return self.prefix_tokens + token_ids_0 + self.suffix_tokens @@ -201,13 +209,17 @@ def get_special_tokens_mask( ) return list( map( - lambda x: 1 - if x in [self.sep_token_id, self.cls_token_id] else 0, + lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0, - )) + ) + ) prefix_ones = [1] * len(self.prefix_tokens) suffix_ones = [1] * len(self.suffix_tokens) if token_ids_1 is None: return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones - return (prefix_ones + ([0] * len(token_ids_0)) + - ([0] * len(token_ids_1)) + suffix_ones) \ No newline at end of file + return ( + prefix_ones + + ([0] * len(token_ids_0)) + + ([0] * len(token_ids_1)) + + suffix_ones + ) diff --git a/main/ex.py b/main/ex.py new file mode 100644 index 0000000..26786a5 --- /dev/null +++ b/main/ex.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from torch.utils.data import WeightedRandomSampler, DataLoader + +if __name__ == "__main__": + numDataPoints = 1000 + data_dim = 5 + bs = 100 + + # Create dummy data with class imbalance 9 to 1 + data = torch.FloatTensor(numDataPoints, data_dim) + target = np.hstack( + ( + np.zeros(int(numDataPoints * 0.9), dtype=np.int32), + np.ones(int(numDataPoints * 0.1), dtype=np.int32), + ) + ) + + print( + "target train 0/1: {}/{}".format( + len(np.where(target == 0)[0]), len(np.where(target == 1)[0]) + ) + ) + + class_sample_count = np.array( + [len(np.where(target == t)[0]) for t in np.unique(target)] + ) + weight = 1.0 / class_sample_count + samples_weight = np.array([weight[t] for t in target]) + + samples_weight = torch.from_numpy(samples_weight) + samples_weigth = samples_weight.double() + sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) + + target = torch.from_numpy(target).long() + train_dataset = torch.utils.data.TensorDataset(data, target) + import pdb + + pdb.set_trace() + train_loader = DataLoader( + train_dataset, batch_size=bs, num_workers=1, sampler=sampler + ) + + for i, (data, target) in enumerate(train_loader): + print( + "batch index {}, 0/1: {}/{}".format( + i, + len(np.where(target.numpy() == 0)[0]), + len(np.where(target.numpy() == 1)[0]), + ) + ) diff --git a/main/get_bleu.py b/main/get_bleu.py index 8b337da..2fec894 100644 --- a/main/get_bleu.py +++ b/main/get_bleu.py @@ -1,8 +1,8 @@ -''' +""" python get_bleu_score.py --hyp hyp.txt \ --ref ref.txt \ --lang en -''' +""" import argparse from pororo import Pororo from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction @@ -16,6 +16,7 @@ def get_hypotheses(lines, tokenizer): hypotheses.append(tokens) return hypotheses + def get_list_of_references(lines, tokenizer): list_of_references = [] for line in lines: @@ -23,14 +24,16 @@ def get_list_of_references(lines, tokenizer): list_of_references.append([tokens]) return list_of_references + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--hyp', type=str, required=True, - help="system output file path") - parser.add_argument('--ref', type=str, required=True, - help="reference file path") - parser.add_argument('--lang', type=str, 
required=True, default="en", - help="en | ko | ja | zh") + parser.add_argument( + "--hyp", type=str, required=True, help="system output file path" + ) + parser.add_argument("--ref", type=str, required=True, help="reference file path") + parser.add_argument( + "--lang", type=str, required=True, default="en", help="en | ko | ja | zh" + ) args = parser.parse_args() hyp = args.hyp @@ -42,12 +45,17 @@ def get_list_of_references(lines, tokenizer): else: tokenizer = Pororo(task="tokenization", lang="en") - hyp_lines = open(hyp, 'r', encoding="utf8").read().strip().splitlines()[:1000] - ref_lines = open(ref, 'r', encoding="utf8").read().strip().splitlines()[:1000] + hyp_lines = open(hyp, "r", encoding="utf8").read().strip().splitlines()[:1000] + ref_lines = open(ref, "r", encoding="utf8").read().strip().splitlines()[:1000] assert len(hyp_lines) == len(ref_lines) list_of_hypotheses = get_hypotheses(hyp_lines, tokenizer) list_of_references = get_list_of_references(ref_lines, tokenizer) - score = corpus_bleu(list_of_references, list_of_hypotheses, auto_reweigh=True, smoothing_function=SmoothingFunction().method3) - print("BLEU SCORE = %.2f" % score) \ No newline at end of file + score = corpus_bleu( + list_of_references, + list_of_hypotheses, + auto_reweigh=True, + smoothing_function=SmoothingFunction().method3, + ) + print("BLEU SCORE = %.2f" % score) diff --git a/main/main.py b/main/main.py index c4d78bc..9e37143 100644 --- a/main/main.py +++ b/main/main.py @@ -9,12 +9,13 @@ class DistillBart(pl.LightningModule): - def __init__(self, n_encoder: int, n_decoder: int): super().__init__() self.batch_size = 16 self.lr = 3e-5 - self.tokenizer = AsianBartTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk") + self.tokenizer = AsianBartTokenizer.from_pretrained( + "hyunwoongko/asian-bart-ecjk" + ) self.model = start(n_encoder, n_decoder) def forward(self, batch): @@ -30,8 +31,11 @@ def forward(self, batch): for key, v in model_inputs.items(): model_inputs[key] = model_inputs[key].to("cuda") - out = self.model(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'], - labels=model_inputs['labels']) + out = self.model( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + labels=model_inputs["labels"], + ) return out def training_step(self, batch, batch_idx): @@ -56,10 +60,9 @@ def validation_step(self, batch, batch_idx) -> Dict: """ out = self.forward(batch) loss = out["loss"] - self.log('val_loss', loss, on_step=True, prog_bar=True, logger=True) + self.log("val_loss", loss, on_step=True, prog_bar=True, logger=True) return loss def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), lr=self.lr) return {"optimizer": optimizer} - diff --git a/main/make_config.py b/main/make_config.py index ea6c312..f236096 100644 --- a/main/make_config.py +++ b/main/make_config.py @@ -16,19 +16,20 @@ """ teacher_model = AsianBartForConditionalGeneration.from_pretrained( - "hyunwoongko/asian-bart-ecjk" - ) + "hyunwoongko/asian-bart-ecjk" +) decoder_layer_3 = [ - "decoder.layers.0", - "decoder.layers.1", - "decoder.layers.2", - "decoder.layers.4", - "decoder.layers.5", - "decoder.layers.6", - "decoder.layers.8", - "decoder.layers.9", - "decoder.layers.10", - ] + "decoder.layers.0", + "decoder.layers.1", + "decoder.layers.2", + "decoder.layers.4", + "decoder.layers.5", + "decoder.layers.6", + "decoder.layers.8", + "decoder.layers.9", + "decoder.layers.10", +] + def start(num_encoder: int, num_decoder: int) -> nn.Module: distill_12_3_config = 
make_config(num_decoder, num_encoder) @@ -74,4 +75,3 @@ def check(k: List[str], decoder_layer_3: List[str]): if except_layer in k: return True return False - diff --git a/main/preprocessing.py b/main/preprocessing.py index 34bcdae..de5492d 100644 --- a/main/preprocessing.py +++ b/main/preprocessing.py @@ -1,9 +1,9 @@ from typing import Tuple -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, WeightedRandomSampler from pdb import set_trace - +import numpy as np import math -import os +import torch paws_train_set = [] paws_val_set = [] @@ -38,32 +38,72 @@ def load_multilingual_dataset( .splitlines() ) - train_set = [ - [src, tgt, lang] - for src, tgt, lang in zip(train_src_set, train_tgt_set, train_lang_set) - if len(src) < max_length - and len(tgt) < max_length - and len(src.replace(" ", "")) != 0 - and len(tgt.replace(" ", "")) != 0 - and lang[0:5] == lang[6:] - ] + # train_set = [ + # [src, tgt, lang] + # for src, tgt, lang in zip(train_src_set, train_tgt_set, train_lang_set) + # if len(src) < max_length + # and len(tgt) < max_length + # and len(src.replace(" ", "")) != 0 + # and len(tgt.replace(" ", "")) != 0 + # and lang[0:5] == lang[6:] + # ] + train_set=[] + ko = 0 + en = 0 + zh = 0 + total = 0 + for src, tgt, lang in zip(train_src_set, train_tgt_set, train_lang_set): + if len(src) < max_length and len(tgt) < max_length and len(src.replace(" ", "")) != 0 and len(tgt.replace(" ", "")) != 0 and lang[0:5] == lang[6:]: + if lang == "ko_KR": + ko +=1 + elif lang == "en_XX": + en += 1 + elif lang == "zh_CN": + zh += 1 + total+=1 + train_set.append([src,tgt,lang]) + + + # ko = ko / total + # en = en / total + # zh = zh / total + + + ko=get_pawsx_data(dataset_path,"ko", "train",ko) + en=get_pawsx_data(dataset_path,"en", "train",en) + zh=get_pawsx_data(dataset_path,"zh", "train",zh) + + lang_sample_count = np.array([ko, en, zh]) + weight = 1. 
/ lang_sample_count + + train_weight = [] + for _, _, lang in train_set: + if lang == "ko_KR": + train_weight.append(weight[0]) + elif lang == "en_XX": + train_weight.append(weight[1]) + else : + train_weight.append(weight[2]) + train_weight = np.array(train_weight) + train_weight = torch.from_numpy(train_weight) + train_weight = train_weight.double() + + sampler = WeightedRandomSampler(train_weight, len(train_weight)) - get_pawsx_data(dataset_path,"ko" " train") - get_pawsx_data(dataset_path,"en", "train") - get_pawsx_data(dataset_path,"zh", "train") get_pawsx_data(dataset_path,"ko", "validation") get_pawsx_data(dataset_path,"en", "validation") get_pawsx_data(dataset_path,"zh", "validation") train_set.extend(paws_train_set) - train_loader = DataLoader( train_set, batch_size=batch_size, - shuffle=True, + shuffle=False, num_workers=8, pin_memory=True, + sampler = sampler + ) val_loader = DataLoader( @@ -71,7 +111,7 @@ def load_multilingual_dataset( batch_size=batch_size, shuffle=True, num_workers=8, - pin_memory=True + pin_memory=True, ) @@ -80,30 +120,26 @@ def load_multilingual_dataset( return train_loader, val_loader -def get_test_data(dataset_path,batch_size): +def get_test_data(dataset_path): get_pawsx_data(dataset_path,"ko", "test") get_pawsx_data(dataset_path,"en", "test") get_pawsx_data(dataset_path,"zh", "test") - test_loader = DataLoader( - paws_test_set, - batch_size=batch_size, - num_workers=8, - pin_memory=True - ) - return test_loader + return paws_test_set -def get_pawsx_data(dataset_path : str,lang : str, mode: str): +def get_pawsx_data(dataset_path : str,lang : str, mode: str, count = None): """ load paws_x dataset and create DataLoader Args: lang (str): language that you want to make dataset Ex) ko -> Korean, en-> English, zh -> chinese mode (int): the mode that you want to choose Ex) train, validation """ - for val in (open(f"{os.getcwd()}/{dataset_path}/{lang}/{mode}.tsv", "r", encoding="utf8").read().splitlines()): + for val in (open(f"{dataset_path}/{lang}/{mode}.tsv", "r", encoding="utf8").read().splitlines()): + if len(val.split('\t'))!=4: + continue _, s1, s2, label = val.split('\t') if label == 0 or s1 == 'sentence1': continue @@ -124,7 +160,11 @@ def get_pawsx_data(dataset_path : str,lang : str, mode: str): paws_test_set.append([s1, s2, lang_code]) else: raise ValueError("parameter mode only allow \"test\", \"validation\", \"test\"") + if count is not None: + count += 1 + if count is not None: + return count diff --git a/main/run.py b/main/run.py index 44b083e..18d6b3b 100644 --- a/main/run.py +++ b/main/run.py @@ -20,8 +20,12 @@ mode="min", ), ], - progress_bar_refresh_rate=20 + progress_bar_refresh_rate=20, + ) + train_dataloader, validation_dataloader = load_multilingual_dataset( + dataset_path=f"{os.getcwd()}/drive/MyDrive/dataset", batch_size=4 ) - train_dataloader, validation_dataloader = load_multilingual_dataset(dataset_path=f'{os.getcwd()}/drive/MyDrive/dataset', batch_size=4) model = DistillBart(12, 3) - trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=validation_dataloader) + trainer.fit( + model, train_dataloader=train_dataloader, val_dataloaders=validation_dataloader + ) diff --git a/main/test.py b/main/test.py index 635c1f0..b3c38c5 100644 --- a/main/test.py +++ b/main/test.py @@ -11,46 +11,49 @@ print("test_start") print("model_loading") check_list = [ - "paraphrase_mlbart_epoch=00-val_loss=0.14.ckpt", - "paraphrase_mlbart_epoch=01-val_loss=0.25.ckpt", - "paraphrase_mlbart_epoch=02-val_loss=0.34.ckpt", - 
"paraphrase_mlbart_epoch=03-val_loss=0.29.ckpt" + "paraphrase_mlbart_epoch=00-val_loss=0.14.ckpt", + "paraphrase_mlbart_epoch=01-val_loss=0.25.ckpt", + "paraphrase_mlbart_epoch=02-val_loss=0.34.ckpt", + "paraphrase_mlbart_epoch=03-val_loss=0.29.ckpt", ] for i in range(4): - path = check_list[i] - model = DistillBart().load_from_checkpoint(checkpoint_path=f"drive/MyDrive/mlbart_ckpt/{check_list[i]}").model - print("model_loaded") + path = check_list[i] + model = ( + DistillBart() + .load_from_checkpoint( + checkpoint_path=f"drive/MyDrive/mlbart_ckpt/{check_list[i]}" + ) + .model + ) + print("model_loaded") - print("tokenizer_loading") - tokenizer = AsianBartTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk") - print("tokenizer_loaded") + print("tokenizer_loading") + tokenizer = AsianBartTokenizer.from_pretrained("hyunwoongko/asian-bart-ecjk") + print("tokenizer_loaded") + print("dataset_loading") + dataset = get_test_data("drive/MyDrive/dataset/") + print("dataset_loaded") - print("dataset_loading") - dataset = get_test_data('drive/MyDrive/dataset/') - print("dataset_loaded") - - src_list = [] - src_lang_list = [] - print(f"{i+1}'s turn") - print(f"{check_list[i]} file opend") - f = open(f"{check_list[i]}gen{i}.txt", "w") - label = open('label.txt','w') - print("*****************") - print("Writing Start!") - print("*****************") - for s1,s2,lang_code in dataset: - inputs = tokenizer.prepare_seq2seq_batch( - src_texts=s1, - src_langs=lang_code - ) - gen_token=model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code]) - # import pdb;pdb.set_trace() - f.write(tokenizer.decode(gen_token[0][2:], skip_special_tokens=True)+"\n") - label.write(s2+"\n") - print("*****************") - print("Writing end!") - print("*****************") - f.close() - label.close() - + src_list = [] + src_lang_list = [] + print(f"{i+1}'s turn") + print(f"{check_list[i]} file opend") + f = open(f"{check_list[i]}gen{i}.txt", "w") + label = open("label.txt", "w") + print("*****************") + print("Writing Start!") + print("*****************") + for s1, s2, lang_code in dataset: + inputs = tokenizer.prepare_seq2seq_batch(src_texts=s1, src_langs=lang_code) + gen_token = model.generate( + **inputs, forced_bos_token_id=tokenizer.lang_code_to_id[lang_code] + ) + # import pdb;pdb.set_trace() + f.write(tokenizer.decode(gen_token[0][2:], skip_special_tokens=True) + "\n") + label.write(s2 + "\n") + print("*****************") + print("Writing end!") + print("*****************") + f.close() + label.close()