-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
121 lines (91 loc) · 3.43 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F
import os
from utils import update_model_config
class TripletModel(nn.Module):
    """
    Swin-transformer backbone with a custom classification head for the
    triplet classification task.

    Parameters:
    - CFG (object): Configuration object containing hyperparameters.
    - model_name (str): Name of the pre-trained model architecture.
    - pretrained (bool, optional): Whether to load pre-trained weights. Default is True.

    Attributes:
    - model (nn.Module): Backbone whose classifier head is replaced by a
      linear layer sized to ``CFG.target_size``.
    """

    def __init__(self, CFG, model_name, pretrained=True):
        super().__init__()
        # Build the timm backbone (optionally with pretrained weights).
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Optionally override the backbone weights with a local checkpoint.
        if CFG.local_weight:
            state = torch.load(
                f"{CFG.weight_dir}/swin_base_patch4_window7_224_22kto1k.pth"
            )
            self.model.load_state_dict(state["model"])
        # Swap the classifier head for one matching our custom target size.
        embed_dim = self.model.head.in_features
        self.model.head = nn.Linear(embed_dim, CFG.target_size)

    def forward(self, x):
        """
        Forward pass through the backbone and classification head.

        Parameters:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output logits with ``CFG.target_size`` entries.
        """
        return self.model(x)
def get_pretrained_model(fold, CFG):
    """
    Load a pretrained model or custom weights based on the specified configuration.

    Args:
        fold (int): The fold number.
        CFG (config object): Configuration object containing model settings.

    Returns:
        torch.nn.Module: The loaded model, moved to ``CFG.device``.

    Raises:
        Exception: If ``CFG.pretrained_model`` is set but ``CFG.exp`` is not
            one of the published experiment tags.
    """
    # Experiment tags for which hosted checkpoints exist.
    pretrained_models = [
        "SwinT",
        "SwinT+MultiT",
        "SwinT+SelfDv2",
        "SwinT+MultiT+SelfD",
        "+phase",
        "SwinLarge",
    ]
    if CFG.pretrained_model:
        # Update target size and model name to match the published checkpoint.
        update_model_config(CFG)
        # Validate the experiment tag BEFORE building the model, so an
        # invalid request fails fast instead of after model construction.
        if CFG.exp not in pretrained_models:
            raise Exception(
                f"Requested model: exp={CFG.exp} is not available, please select one of the available models:\n{pretrained_models}"
            )
    # Initialize the model; weights are loaded explicitly below, so
    # pretrained=False avoids a redundant download.
    model = TripletModel(CFG, model_name=CFG.model_name, pretrained=False).to(
        CFG.device
    )
    # Download pretrained weights or load custom weights.
    if CFG.pretrained_model:
        # Fetch the hosted checkpoint for this fold/experiment.
        checkpoint_url = (
            f"https://self-distillation-weights.s3.dkfz.de/fold{fold}_{CFG.exp}.pth"
        )
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.hub.load_state_dict_from_url(
            checkpoint_url, progress=True, map_location=CFG.device
        )
        model.load_state_dict(checkpoint["model"])
        print(f"fold {fold}: Pretrained Weights downloaded and loaded successfully")
    else:
        # Load your custom weights from the local checkpoints directory.
        weights_path = os.path.join(
            CFG.output_dir,
            f"checkpoints/fold{fold}_{CFG.model_name[:8]}_{CFG.target_size}_{CFG.exp}.pth",
        )
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        model.load_state_dict(
            torch.load(weights_path, map_location=CFG.device)["model"]
        )
        print(f"fold {fold}: Weights loaded successfully")
    return model