-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_wrn.py
55 lines (48 loc) · 1.34 KB
/
model_wrn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import torch
from wrn1d import WideResNet1d
from wrn2d import WideResNet2d
from wrn3d import WideResNet3d
# Pin the global random generators at import time so that anything drawing
# from them (e.g. torch's default parameter initialisation) is reproducible
# from run to run.  The two generators are independent of each other.
torch.manual_seed(1)
np.random.seed(42)
def get_wrn(
    input_shape,
    output_dim,
    output_shape,
    in_channels,
    device=None,
    depth=16,
    widen_factor=4,
    dropRate=0.0,
):
    """Build the Wide ResNet variant matching the dimensionality of the input.

    The number of "spacetime" dimensions is the count of entries among
    ``input_shape[0]``, ``input_shape[2]`` and ``input_shape[3]`` that differ
    from 1 (index 1 is not inspected).  That count selects the model class:

    * 1 -> ``WideResNet1d``
    * 2 -> ``WideResNet2d``
    * 3 -> ``WideResNet3d``
    * 0 -> ``WideResNet1d`` with ``in_channels`` forced to 1 (channels-only input)

    Args:
        input_shape: Shape of one input; must be a 1-D sequence with at least
            4 entries.  Forwarded to the model as ``input_shape``.
        output_dim: Forwarded to the model as ``num_classes``.
        output_shape: Forwarded to the model as ``output_shape``.
        in_channels: Forwarded to the model as ``in_channels`` (overridden to 1
            in the channels-only case).
        device: Device the model is moved to.  When ``None``, ``'cuda'`` is
            used if ``torch.cuda.is_available()``, otherwise ``'cpu'``.
        depth: Forwarded unchanged to the model.
        widen_factor: Forwarded unchanged to the model.
        dropRate: Forwarded unchanged to the model.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        ValueError: If ``input_shape`` is not a 1-D sequence of length >= 4.
    """
    shape = np.asarray(input_shape)
    # Axes 0, 2 and 3 are inspected below; reject short or non-1-D shapes with a
    # readable message instead of letting numpy's fancy indexing fail opaquely.
    if shape.ndim != 1 or shape.size < 4:
        raise ValueError(
            "input_shape must be a 1-D sequence with at least 4 entries, "
            f"got {input_shape!r}"
        )
    spacetime_dims = int(np.count_nonzero(shape[[0, 2, 3]] != 1))

    kwargs = {
        'depth': depth,
        'num_classes': output_dim,
        'input_shape': input_shape,
        'widen_factor': widen_factor,
        'dropRate': dropRate,
        'in_channels': in_channels,
        'output_shape': output_shape,
    }
    # Three inspected axes means the count is always in 0..3, so every
    # possible value has an entry and no fallback branch is needed.
    model_by_dims = {
        0: WideResNet1d,  # channels only: handled as a single-channel 1-D model
        1: WideResNet1d,
        2: WideResNet2d,
        3: WideResNet3d,
    }
    if spacetime_dims == 0:
        kwargs['in_channels'] = 1

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using WRN of dimension {spacetime_dims}")
    model = model_by_dims[spacetime_dims](**kwargs).to(device)
    print(model)
    return model