forked from ServiceNow/HighRes-net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVGGFeatureExtractor.py
74 lines (62 loc) · 2.68 KB
/
VGGFeatureExtractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from torchvision.models import vgg19
from torchvision import transforms
import torch
import numpy as np
# ImageNet normalization statistics, kept here for reference:
#   mean = [0.485, 0.456, 0.406]
#   std  = [0.229, 0.224, 0.225]
# VGGFeatureExtractor is built with torchvision's `weights=` parameter of vgg19.
class VGGFeatureExtractor(torch.nn.Module):
    """Truncated pretrained VGG-19 that returns selected intermediate feature maps.

    The network is ``vgg19(weights=VGG19_Weights.DEFAULT).features`` cut right
    after the last index listed in ``layers``.  ``forward`` accepts a single
    image array, preprocesses it, and returns the activations of every layer
    whose index appears in ``layers``.

    Attributes:
        preprocess: torchvision transform pipeline (ToPILImage -> ToTensor ->
            Normalize with the mean/std published with the weights).
        vgg: the truncated ``features`` sequential module.
        hardware: ``'cuda'`` or ``'cpu'``, chosen once at construction from
            ``torch.cuda.is_available()``.  Kept for backward compatibility;
            ``forward`` prefers the device the weights actually live on.
    """

    # Indices (as strings, i.e. the keys of ``Sequential._modules``) of the
    # ``features`` layers whose outputs ``forward`` collects.
    layers = ['0', '5', '10', '19', '28']

    def __init__(self):
        """Build the preprocessing pipeline and the truncated VGG-19."""
        super().__init__()
        from torchvision.models import VGG19_Weights

        weights = VGG19_Weights.DEFAULT
        # Build the weight preset once and read both statistics from it.
        preset = weights.transforms()
        self.preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=preset.mean, std=preset.std),
        ])

        # Keep everything up to and including the last requested layer.
        last_layer = int(VGGFeatureExtractor.layers[-1])
        self.vgg = vgg19(weights=weights).features[:last_layer + 1]
        # NOTE: an earlier variant built ``vgg19().features[:29]`` and loaded
        # "./vgg19_conv_layers.pth" via ``load_state_dict`` instead of
        # downloading the weights.

        if torch.cuda.is_available():
            self.hardware = 'cuda'
            print("VGG using CUDA")
        else:
            self.hardware = 'cpu'

    def _feature_device(self):
        """Return the device the input must be moved to before running ``vgg``.

        The constructor does not move the weights, and callers may relocate
        the module later (``.to(...)``, ``.cpu()``, ``.cuda()``), so the device
        of the first parameter is authoritative.  ``self.hardware`` is only
        used when ``vgg`` has no parameters at all.
        """
        first_param = next(self.vgg.parameters(), None)
        if first_param is None:
            return self.hardware
        return first_param.device

    def forward(self, x):
        """Run ``x`` through the truncated VGG and collect the selected features.

        Args:
            x: image array accepted by ``convert_grayscale_to_input_tensor``
                (shape (H, W), (H, W, 1) or (H, W, 3)).

        Returns:
            dict mapping each layer index in ``layers`` (as a string) to the
            tensor produced by that layer, in network order.

        Raises:
            ValueError: if ``x`` has an unsupported shape.
        """
        x = self.convert_grayscale_to_input_tensor(x).to(self._feature_device())
        outputs = {}
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in VGGFeatureExtractor.layers:
                # NOTE(review): torchvision's VGG appears to use in-place ReLU,
                # so a stored tensor may be rectified by the following layer
                # after it is saved here -- confirm if pre-activation values
                # are expected.
                outputs[name] = x
        return outputs

    def convert_grayscale_to_input_tensor(self, x):
        """Convert an image array to a 4-D input tensor of shape [1, 3, H, W].

        Single-channel inputs are replicated to three channels; three-channel
        inputs are used as they are.  The result is passed through
        ``self.preprocess`` and given a leading batch dimension.

        Args:
            x: array of shape (H, W), (H, W, 1) or (H, W, 3).  Assumed to be a
                numpy array in a dtype/value range that ``ToPILImage``
                accepts -- TODO confirm against callers.

        Returns:
            Tensor of shape [1, 3, H, W].

        Raises:
            ValueError: if ``x`` has any other shape.
        """
        if x.ndim == 2:
            # Grayscale (H, W): replicate into (H, W, 3).
            x = np.stack([x, x, x], axis=-1)
        elif x.ndim == 3 and x.shape[2] == 1:
            # Single channel (H, W, 1): replicate into (H, W, 3).
            x = np.concatenate([x, x, x], axis=2)
        elif not (x.ndim == 3 and x.shape[2] == 3):
            raise ValueError("Input image must have shape (H, W), (H, W, 1), or (H, W, 3)")
        return self.preprocess(x).unsqueeze(0)  # Shape: [1, 3, H, W]

    @staticmethod
    def init_VGG_for_perceptual_loss():
        """Create a frozen, eval-mode extractor on CUDA if available, else CPU.

        Returns:
            VGGFeatureExtractor with every parameter's ``requires_grad`` set
            to False.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        feature_extractor = VGGFeatureExtractor().to(device)
        feature_extractor.eval()
        # The weights are fixed; disable gradient tracking for them.
        for param in feature_extractor.parameters():
            param.requires_grad = False
        return feature_extractor