-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
34 lines (29 loc) · 884 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
def get_config(img_path: str, cap_path: str) -> dict:
    """Build the configuration dictionary for a run.

    A new dictionary is created on every call, so callers may mutate the
    result without affecting other callers.

    Args:
        img_path: Stored unchanged under the ``'image_path'`` key.
        cap_path: Stored unchanged under the ``'caption_path'`` key.

    Returns:
        A dict with the two paths above, a ``'device'`` entry holding a
        ``torch.device``, and the fixed settings listed below.
    """
    # Select CUDA when torch reports it as available; otherwise use the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    return {
        'debug': False,
        # Caller-supplied locations.
        'image_path': img_path,
        'caption_path': cap_path,
        'batch_size': 32,
        'num_workers': 4,
        'lr': 1e-3,
        'image_encoder_lr': 1e-4,
        'text_encoder_lr': 1e-5,
        'weight_decay': 1e-3,
        'patience': 1,
        'factor': 0.8,
        'epoch': 4,
        'device': device,
        'img_model': 'resnet50',
        'img_embedding_dim': 2048,
        'txt_embedding_dim': 768,
        'txt_model': 'distilbert-base-uncased',
        'max_length': 200,
        'pretrained': True,
        'trainable': True,
        # NOTE(review): the key is spelled 'temprature' (sic). It is kept
        # as-is because consumers not visible here may look it up by this
        # exact name; rename only together with every reader.
        'temprature': 1.0,
        'size': 224,
        'num_projection_layers': 1,
        'projection_dim': 256,
        'dropout': 0.1,
    }