-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
57 lines (47 loc) · 2.1 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import timm
from timm import create_model
def model_chest_xray(model_name, init_weight, num_class, in_chans):
    """Create a pre-built model using timm's ``create_model`` factory.

    Args:
        model_name (str): timm model name, passed straight to ``create_model``.
        init_weight (bool): if True, initialize the model with timm's
            pre-trained weights (forwarded as ``pretrained``).
        num_class (int): number of output classes (forwarded as ``num_classes``).
        in_chans (int): number of channels of the input data.

    Returns:
        The model object returned by ``timm.create_model``
        (a ``torch.nn.Module``).
    """
    print("Creating model...")
    return create_model(
        model_name=model_name,
        pretrained=init_weight,
        num_classes=num_class,
        in_chans=in_chans,
    )
# Short model names accepted by ``build_model`` mapped to the timm model
# identifiers actually instantiated.
_TIMM_MODEL_NAMES = {
    "resnet18": "resnet18",
    "resnet50": "resnet50",
    "swin_tiny": "swin_tiny_patch4_window7_224",
    "swin_base": "swin_base_patch4_window7_224",
}


def build_model(args):
    """Build a timm model from a command-line style ``args`` object.

    Args:
        args: namespace-like object providing
            ``model_name`` (one of ``resnet18``, ``resnet50``, ``swin_tiny``,
            ``swin_base``), ``init`` (the string ``"ImageNet"`` loads
            pre-trained weights; any other value does not), ``num_classes``
            and ``in_chans``.

    Returns:
        The model returned by ``timm.create_model``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported name.
            ``ValueError`` subclasses ``Exception``, so callers that catch
            ``Exception`` keep working.
    """
    pretrained = args.init == "ImageNet"
    print(f"Creating model {args.model_name} with {args.init} weights.....")

    timm_name = _TIMM_MODEL_NAMES.get(args.model_name)
    if timm_name is None:
        print(f"Not implemented for {args.model_name} model.")
        raise ValueError(
            "Please provide correct model name to build! "
            f"Got {args.model_name!r}; supported: {sorted(_TIMM_MODEL_NAMES)}"
        )

    return create_model(
        model_name=timm_name,
        pretrained=pretrained,
        num_classes=args.num_classes,
        in_chans=args.in_chans,
    )