-
Notifications
You must be signed in to change notification settings - Fork 2
/
deeplabv3.py
119 lines (85 loc) · 3.21 KB
/
deeplabv3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Any, Optional, Sequence

import torch.nn as nn

from cvm import models

from ..ops import blocks
from ..utils import export, get_out_channels, load_from_local_or_url
from .heads import FCNHead, ClsHead
from .segmentation_model import SegmentationModel
class DeepLabHead(nn.Sequential):
    """DeepLabV3 decode head: ``blocks.ASPP`` -> ``blocks.Conv2dBlock`` -> 1x1 conv.

    Args:
        in_channels: Channels of the feature map fed into the head.
        out_channels: Channels produced by the ASPP block and the intermediate
            conv block.
        num_classes: Output channels of the final ``blocks.Conv2d1x1``.
        rates: Sequence passed as the third argument of ``blocks.ASPP``
            (presumably its atrous/dilation rates — confirm against
            ``blocks.ASPP``). Defaults to ``(12, 24, 36)``, the values that
            were previously hard-coded here.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 256,
        num_classes: int = 32,
        rates: Sequence[int] = (12, 24, 36),
    ):
        super().__init__(
            # Hand ASPP a fresh list: the default stays an immutable tuple,
            # and ASPP receives the same argument type as before.
            blocks.ASPP(in_channels, out_channels, list(rates)),
            blocks.Conv2dBlock(out_channels, out_channels),
            blocks.Conv2d1x1(out_channels, num_classes),
        )
@export
class DeepLabV3(SegmentationModel):
    """DeepLabV3 segmentation model.

    Defines nothing beyond ``SegmentationModel``; the concrete network is
    determined by the backbone and heads supplied at construction time
    (see ``create_deeplabv3``).
    """
@export
def create_deeplabv3(
    backbone: str = 'resnet50_v1',
    num_classes: int = 21,
    aux_loss: bool = False,
    cls_loss: bool = False,
    dropout_rate: float = 0.1,
    pretrained_backbone: bool = False,
    pretrained: bool = False,
    pth: Optional[str] = None,
    progress: bool = True,
    **kwargs: Any
):
    """Build a ``DeepLabV3`` model on top of a ``cvm.models`` backbone.

    Args:
        backbone: Name of a constructor in ``cvm.models``. It is called with
            ``dilations=[1, 1, 2, 4]``, and its ``.features`` module is used as
            the encoder (accessed as ``.stage4``, plus ``.stage3`` when
            ``aux_loss`` is True).
        num_classes: Output channels of the decode head and of the optional
            auxiliary / classification heads.
        aux_loss: If True, add an ``FCNHead`` on ``stage3`` and pass
            ``[3, 4]`` instead of ``[4]`` to ``DeepLabV3``.
        cls_loss: If True, add a ``ClsHead`` on ``stage4``.
        dropout_rate: Passed to the auxiliary ``FCNHead`` only.
        pretrained_backbone: Forwarded as ``pretrained`` to the backbone
            constructor. Forced to False when ``pretrained`` is True.
        pretrained: If True, load whole-model weights with
            ``load_from_local_or_url(model, pth, kwargs.get('url'), progress)``.
        pth: Checkpoint location handed to ``load_from_local_or_url``
            (presumably a local file path); only used when ``pretrained``.
        progress: Forwarded to ``load_from_local_or_url``.
        **kwargs: Forwarded to the backbone constructor. ``url`` is also read
            from here for whole-model loading.

    Returns:
        The assembled ``DeepLabV3`` model.

    Raises:
        KeyError: If ``backbone`` is not a name defined in ``cvm.models``.
    """
    if pretrained:
        # Whole-model weights are loaded at the end, so backbone weights are
        # not requested separately.
        pretrained_backbone = False

    try:
        backbone_fn = models.__dict__[backbone]
    except KeyError:
        raise KeyError(
            f"unknown backbone {backbone!r}: no such name in cvm.models"
        ) from None

    # Keep the name (``backbone``) and the built encoder (``features``) in
    # separate variables rather than rebinding the str parameter to a module.
    features = backbone_fn(
        pretrained=pretrained_backbone,
        # Assumed to be one dilation per backbone stage — confirm against
        # the backbone constructors.
        dilations=[1, 1, 2, 4],
        **kwargs
    ).features

    aux_head = (
        FCNHead(get_out_channels(features.stage3), None, num_classes, dropout_rate)
        if aux_loss else None
    )
    stage4_channels = get_out_channels(features.stage4)
    cls_head = ClsHead(stage4_channels, num_classes) if cls_loss else None
    decode_head = DeepLabHead(stage4_channels, num_classes=num_classes)

    # Presumably the backbone stages whose outputs DeepLabV3 consumes; stage 3
    # is only added when the auxiliary head (built on stage3) is enabled.
    stages = [3, 4] if aux_loss else [4]
    model = DeepLabV3(features, stages, decode_head, aux_head, cls_head)

    if pretrained:
        # NOTE(review): ``url`` is read here but also stays in ``kwargs``, so it
        # was forwarded to the backbone constructor above — confirm backbones
        # accept it.
        load_from_local_or_url(model, pth, kwargs.get('url', None), progress)
    return model
@export
def deeplabv3_resnet50_v1(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``resnet50_v1`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'resnet50_v1'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_mobilenet_v3_small(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``mobilenet_v3_small`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'mobilenet_v3_small'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_mobilenet_v3_large(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``mobilenet_v3_large`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'mobilenet_v3_large'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_regnet_x_400mf(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``regnet_x_400mf`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'regnet_x_400mf'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_mobilenet_v1_x1_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``mobilenet_v1_x1_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'mobilenet_v1_x1_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_sd_mobilenet_v1_x1_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``sd_mobilenet_v1_x1_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'sd_mobilenet_v1_x1_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_mobilenet_v2_x1_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``mobilenet_v2_x1_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'mobilenet_v2_x1_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_sd_mobilenet_v2_x1_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``sd_mobilenet_v2_x1_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'sd_mobilenet_v2_x1_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_shufflenet_v2_x2_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``shufflenet_v2_x2_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'shufflenet_v2_x2_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_sd_shufflenet_v2_x2_0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``sd_shufflenet_v2_x2_0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'sd_shufflenet_v2_x2_0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_efficientnet_b0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``efficientnet_b0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'efficientnet_b0'
    return create_deeplabv3(arch, *args, **kwargs)
@export
def deeplabv3_sd_efficientnet_b0(*args: Any, **kwargs: Any):
    """Create a DeepLabV3 model on the ``sd_efficientnet_b0`` backbone.

    All other arguments are forwarded to ``create_deeplabv3`` after the
    backbone name.
    """
    arch = 'sd_efficientnet_b0'
    return create_deeplabv3(arch, *args, **kwargs)