-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
33 lines (30 loc) · 1.08 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from pathlib import Path
def get_config():
    """Return a newly built dict holding the project's default settings.

    A new dict is created on every call, so callers may modify the result
    without affecting later calls.

    Returns:
        dict: settings keyed by name (numeric values for the batch/epoch/
        model sizes, strings for language codes, file names and folders).
    """
    # Numeric settings. NOTE(review): "lr" is presumably the learning rate
    # and "d_model" the model width -- confirm against the training code.
    numeric_settings = dict(
        batch_size=16,
        num_epochs=5,
        lr=1e-4,
        seq_len=40,
        d_model=512,
    )

    # Data file name and the source/target language codes.
    data_settings = dict(
        datasource="filtered_mt_en_vi.parquet",
        lang_src="en",
        lang_tgt="vi",
    )

    # Names used when building weight-file and tokenizer paths.
    # "tokenizer_file" keeps a "{0}" placeholder that is filled in elsewhere.
    artifact_settings = dict(
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )

    # Merge in the original key order.
    return {**numeric_settings, **data_settings, **artifact_settings}
def get_weights_file_path(config, epoch: str):
    """Build the path string of the weights file for one epoch.

    The folder name is ``"<datasource>_<model_folder>"`` and the file name is
    ``"<model_basename><epoch>.pt"``; both are joined onto ``Path('.')``.

    Args:
        config: Mapping providing the "datasource", "model_folder" and
            "model_basename" entries.
        epoch: Epoch label inserted verbatim into the file name.

    Returns:
        str: The joined path rendered as a string.
    """
    folder_name = "{}_{}".format(config["datasource"], config["model_folder"])
    file_name = "{}{}.pt".format(config["model_basename"], epoch)
    return str(Path(".").joinpath(folder_name, file_name))
# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the path of the most recent weights file, or None if none exist.

    Searches the folder ``"<datasource>_<model_folder>"`` for entries whose
    name starts with ``config['model_basename']`` and picks the one that
    orders last under a natural sort of the file name: runs of digits are
    compared as integers, everything else as text. This makes
    ``tmodel_10.pt`` rank after ``tmodel_9.pt`` even when epoch numbers are
    not zero-padded, which a plain lexicographic sort gets wrong.

    Args:
        config: Mapping providing the "datasource", "model_folder" and
            "model_basename" entries.

    Returns:
        str | None: The selected path as a string, or None when nothing in
        the folder matches.
    """
    import re  # local import so the block does not depend on file-level imports

    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if not weights_files:
        return None

    def natural_key(path):
        # Splitting on a capturing group yields text, digits, text, ... so
        # odd positions are always digit runs; converting only those keeps
        # every position type-consistent across keys (str vs str, int vs int).
        chunks = re.split(r"([0-9]+)", path.name)
        return [int(chunk) if index % 2 else chunk
                for index, chunk in enumerate(chunks)]

    return str(max(weights_files, key=natural_key))