-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathconfig.py
49 lines (32 loc) · 1.54 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import json
with open("config.json", "r") as jsonfile:
config = json.load(jsonfile)
class TIMITConfig(object):
dir = config['dataDir']['dir']
# path to the unzipped TIMIT data folder
data_path = config['dataDir']['data_path'].replace('$dir', dir)
# path to csv file containing age, heights of timit speakers
speaker_csv_path = config['dataDir']['speaker_csv_path'].replace('$dir', dir)
batch_size = int(config['model_parameters']['batch_size'])
epochs = int(config['model_parameters']['epochs'])
model_type = config['model_parameters']['model_type']
# UncertaintyLoss
loss = "UncertaintyLoss"
# upstream model to be loaded from s3prl. Some of the upstream models are: wav2vec2, TERA, mockingjay etc.
#See the available models here: https://github.com/s3prl/s3prl/blob/master/s3prl/upstream/README.md
upstream_model = config['model_parameters']['upstream_model']
# number of layers in encoder (transformers)
num_layers = 6
# feature dimension of upstream model. For example,
# For wav2vec2, feature_dim = 768
feature_dim = int(config['model_parameters']['feature_dim'])
# No of GPUs for training and no of workers for datalaoders
gpu = int(config['gpu'])
n_workers = int(config['n_workers'])
# model checkpoint to continue from
model_checkpoint = None
# LR of optimizer
lr = float(config['model_parameters']['lr'])
narrow_band = config['model_parameters']['narrow_band']
run_name = 'multi-task' + '_' + model_type