# train.py
import argparse
from pathlib import Path

from datasets import load_dataset

from src import BaseTokenizer
from utils.app_settings import available_tokenizers
from utils.settings import *  # expected to provide DATA_FOLDER, VOCAB_FILE, VOCAB_SIZE, CONTROL_TOKENS_LIST


def main():
    parser = argparse.ArgumentParser(
        prog='Tokenizer Trainer',
        description='A tokenizer trainer that trains tokenizer models from different implementations',
        epilog=''
    )
    parser.add_argument('-n', '--tokenizer_name', default="custom_bpe", choices=available_tokenizers.keys())
    parser.add_argument('-d', '--directory', default=DATA_FOLDER)
    parser.add_argument('-f', '--vocab_file', default=VOCAB_FILE)
    parser.add_argument('-m', '--model_file', required=False)
    parser.add_argument('-s', '--vocab_size', default=VOCAB_SIZE)
    args = parser.parse_args()
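    # Example invocation (hypothetical paths/values; the real defaults come from utils.settings):
    #   python train.py -n custom_bpe -d data -f vocab.json -s 32000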

    assert args.tokenizer_name in available_tokenizers, \
        f"{args.tokenizer_name} is not an implemented tokenizer that can be trained"
    tokenizer_name = args.tokenizer_name
    dir_name = Path(args.directory)
    vocab_file = Path(args.vocab_file)
    vocab_size = int(args.vocab_size)
    model_file = Path(args.model_file) if args.model_file else None

    print(f"{args.tokenizer_name} training starts...")
    print("Directory:", args.directory)
    print("Dataset:", "wikipedia - 20220301.en")

    wikipedia_dataset = load_dataset("wikipedia", "20220301.en", trust_remote_code=True)
    wiki_set = wikipedia_dataset['train']
    # Train the custom BPE implementation on a smaller sample.
    data_size = 1_000 if tokenizer_name == "custom_bpe" else 25_000
    set_for_train = [text for text in wiki_set[:data_size]["text"]]
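    # Note: slicing a Hugging Face Dataset returns a dict of columns, so
    # wiki_set[:data_size]["text"] is a plain list of raw article strings.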

    tk_params = {
        "directory": dir_name.joinpath(tokenizer_name),
        "vocab_file": vocab_file,
        "vocab_size": vocab_size
    }
    tk: BaseTokenizer = available_tokenizers.get(tokenizer_name)(**tk_params)
    tk.train(set_for_train, verbose=False)
    # tk.register_special_tokens(CONTROL_TOKENS_LIST)
    # tk.save()
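    # If the BaseTokenizer implementations expose these methods (as the commented
    # lines above suggest), enabling them would register the CONTROL_TOKENS_LIST
    # special tokens and persist the trained tokenizer to disk.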


if __name__ == "__main__":
    main()
    print("Job's done")