-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_byol.py
64 lines (45 loc) · 2.28 KB
/
train_byol.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import torch
import yaml
from torchvision import datasets
from data.multi_view_data_injector import MultiViewDataInjector
from data.transforms import get_simclr_data_transforms
from models.mlp_head import MLPHead
from models.resnet_base_network import ResNet18
from trainer.byol_trainer import BYOLTrainer
def _load_pretrained_weights(network, pretrained_folder, device):
    """Load ``online_network_state_dict`` from ./runs/<folder>/checkpoints/model.pth.

    If the checkpoint file does not exist, a notice is printed and ``network``
    is left with its current (freshly initialised) parameters.

    Args:
        network: module whose ``load_state_dict`` receives the stored weights.
        pretrained_folder: run sub-directory name under ``./runs``.
        device: device string used as ``map_location`` for ``torch.load``.
    """
    checkpoint_path = os.path.join('./runs', pretrained_folder, 'checkpoints', 'model.pth')
    try:
        load_params = torch.load(checkpoint_path, map_location=torch.device(device))
    except FileNotFoundError:
        print("Pre-trained weights not found. Training from scratch.")
        return
    network.load_state_dict(load_params['online_network_state_dict'])


def main(config_path="./config/config.yaml", data_root='/home/thalles/Downloads/'):
    """Build the BYOL networks and optimizer from a YAML config and start training.

    The config must provide ``data_transforms``, ``network`` (including
    ``projection_head``), ``optimizer.params`` and ``trainer`` sections.
    ``network.fine_tune_from`` is optional.

    Args:
        config_path: path of the YAML configuration file.
        data_root: directory where STL10 is stored (downloaded if missing).
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    data_transform = get_simclr_data_transforms(**config['data_transforms'])
    # Each sample yields two independently augmented views of the same image.
    train_dataset = datasets.STL10(data_root, split='train+unlabeled', download=True,
                                   transform=MultiViewDataInjector([data_transform, data_transform]))

    # online network
    online_network = ResNet18(**config['network']).to(device)

    # Optionally warm-start the online network from a previous run.
    pretrained_folder = config['network'].get('fine_tune_from')
    if pretrained_folder:
        _load_pretrained_weights(online_network, pretrained_folder, device)

    # predictor network: its input width matches the online projection output.
    # NOTE(review): `projetion` looks like a typo but must match the attribute
    # name ResNet18 defines — confirm before renaming it in either place.
    predictor = MLPHead(in_channels=online_network.projetion.net[-1].out_features,
                        **config['network']['projection_head']).to(device)

    # target encoder
    # NOTE(review): freshly initialised here; presumably BYOLTrainer copies the
    # online weights into it before training — confirm in trainer.byol_trainer.
    target_network = ResNet18(**config['network']).to(device)

    # Only the online network and predictor are optimised by SGD here.
    optimizer = torch.optim.SGD(list(online_network.parameters()) + list(predictor.parameters()),
                                **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          **config['trainer'])

    trainer.train(train_dataset)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()