-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
168 lines (145 loc) · 6.71 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
from torch.utils.data import Dataset
import PIL.Image
from torchvision import transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import random
from PIL import ImageFile
# Ask PIL to decode image files whose data is truncated instead of raising
# on them (Pillow's documented LOAD_TRUNCATED_IMAGES switch).
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Seed the global torch and NumPy RNGs as an import-time side effect.
# The random panorama rotation in VIGORDataset.__getitem__ draws from the
# global NumPy RNG seeded here.
# NOTE(review): with multi-process data loading each worker may start from
# the same NumPy state -- confirm whether per-worker reseeding is needed.
torch.manual_seed(17)
np.random.seed(0)
# Per-channel statistics used to normalize both views (the widely used
# ImageNet mean / std values).
_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
_NORMALIZE_STD = (0.229, 0.224, 0.225)


def _build_image_transform(target_size):
    """Return the resize -> tensor -> normalize pipeline for ``target_size``.

    ``target_size`` is forwarded unchanged to ``transforms.Resize``.  Fresh
    mean/std lists are created per pipeline so the two pipelines share no
    mutable state.
    """
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORMALIZE_MEAN), std=list(_NORMALIZE_STD)),
    ]
    return transforms.Compose(steps)


# Default preprocessing for ground panoramas (resized to [320, 640]).
transform_grd = _build_image_transform([320, 640])
# Default preprocessing for satellite tiles (resized to [512, 512]).
transform_sat = _build_image_transform([512, 512])
class VIGORDataset(Dataset):
    """Dataset pairing VIGOR ground panoramas with their satellite tiles.

    The split files are read from ``<root>/<label_root>/<city>/`` for every
    city in ``city_list``:

    * ``satellite_list.txt`` -- one satellite file name per line.
    * ``pano_label_balanced.txt`` -- one panorama per line, space separated:
      the panorama file name followed by four ``(satellite name, delta_0,
      delta_1)`` triples.

    Args:
        root: Dataset root directory.
        label_root: Sub-directory of ``root`` holding the split files.
        transform: Optional ``(ground_transform, satellite_transform)`` pair.
            Both must be set (here or by assigning ``grdimage_transform`` /
            ``satimage_transform`` afterwards) before items are fetched.
        load_gt: If truthy, ``__getitem__`` uses ``gt_delta`` for the target
            location, otherwise ``pred_delta`` (initialised to zeros).
        known_ori: If truthy, panoramas are not rotated; otherwise each item
            is rolled by a random (or predefined) fraction of its width.
    """

    def __init__(self, root='./datasets/VIGOR', label_root='splits_new', transform=None, load_gt=True, known_ori=True):
        self.root = root
        self.label_root = label_root
        self.known_ori = known_ori
        # Always define both attributes so a missing transform is reported
        # clearly in __getitem__ instead of surfacing as an AttributeError.
        if transform is not None:
            self.grdimage_transform = transform[0]
            self.satimage_transform = transform[1]
        else:
            self.grdimage_transform = None
            self.satimage_transform = None
        self.city_list = ['SanFrancisco', 'Chicago']

        # load sat list
        self.sat_list = []
        self.sat_index_dict = {}  # satellite file name -> global index
        idx = 0
        for city in self.city_list:
            sat_list_fname = os.path.join(self.root, self.label_root, city, 'satellite_list.txt')
            with open(sat_list_fname, 'r') as file:
                for line in file:
                    sat_name = line.rstrip('\n')
                    if not sat_name:
                        # Ignore blank lines rather than registering an
                        # empty file name as a satellite image.
                        continue
                    self.sat_list.append(os.path.join(self.root, city, 'satellite', sat_name))
                    self.sat_index_dict[sat_name] = idx
                    idx += 1
            print('InputData::__init__: load', sat_list_fname, idx)
        self.sat_list = np.array(self.sat_list)
        self.sat_data_size = len(self.sat_list)
        print('Sat loaded, data size:{}'.format(self.sat_data_size))

        # load grd list
        self.grd_list = []
        self.grd_city_list = []  # city of each panorama, parallel to grd_list
        self.label = []
        self.sat_cover_dict = {}  # first satellite index -> panorama indices
        self.gt_delta = []
        idx = 0
        for city in self.city_list:
            # load grd panorama list
            label_fname = os.path.join(self.root, self.label_root, city, 'pano_label_balanced.txt')
            with open(label_fname, 'r') as file:
                for line in file:
                    if not line.strip():
                        continue  # tolerate blank / trailing empty lines
                    data = line.rstrip('\n').split(' ')
                    # Fields 1, 4, 7, 10 name the four associated satellite
                    # tiles; each is followed by its two offset values.
                    label = np.array([self.sat_index_dict[data[i]] for i in (1, 4, 7, 10)]).astype(int)
                    delta = np.array([data[2:4], data[5:7], data[8:10], data[11:13]]).astype(float)
                    self.grd_list.append(os.path.join(self.root, city, 'panorama', data[0]))
                    self.grd_city_list.append(city)
                    self.label.append(label)
                    self.gt_delta.append(delta)
                    self.sat_cover_dict.setdefault(label[0], []).append(idx)
                    idx += 1
            print('InputData::__init__: load ', label_fname, idx)
        self.data_size = len(self.grd_list)
        print('Grd loaded, data size:{}'.format(self.data_size))
        self.label = np.array(self.label)
        self.gt_delta = np.array(self.gt_delta)
        self.pred_delta = np.zeros(np.shape(self.gt_delta))
        self.load_gt = load_gt
        # Optional per-sample rotations (fraction of a full turn); when None
        # and known_ori is false, rotations are drawn at random per access.
        self.predefined_random_rot = None

    def __len__(self):
        """Return the number of ground panoramas."""
        return self.data_size

    def __getitem__(self, idx):
        """Load one (panorama, satellite) pair with its localisation targets.

        Returns:
            A tuple ``(grd, sat, gt, gt_with_ori, orientation, city)``:

            * ``grd`` -- transformed panorama, rolled along dim 2 by the
              sampled rotation.
            * ``sat`` -- transformed satellite tile of the first label.
            * ``gt`` -- ``(1, H, W)`` float32 Gaussian (sigma 4) peaking at
              the target location on the transformed satellite tile.
            * ``gt_with_ori`` -- ``(20, H, W)`` float32; the same Gaussian
              split linearly between the two neighbouring 18-degree
              orientation bins.
            * ``orientation`` -- ``(2, H, W)`` tensor filled with the cosine
              (channel 0) and sine (channel 1) of the orientation angle.
            * ``city`` -- name of the city the panorama belongs to.

        Raises:
            RuntimeError: If the ground or satellite transform is not set.
        """
        if self.grdimage_transform is None or self.satimage_transform is None:
            raise RuntimeError(
                'VIGORDataset needs transform=(ground_transform, satellite_transform) '
                'before items can be loaded'
            )

        # full ground panorama
        try:
            # The context manager closes the file handle once the pixel data
            # has been copied out by convert().
            with PIL.Image.open(self.grd_list[idx]) as grd_file:
                grd = grd_file.convert('RGB')
        except Exception:
            # Best effort: substitute a black panorama for unreadable files,
            # but let KeyboardInterrupt / SystemExit propagate.
            print('unreadable image')
            # PIL sizes are (width, height): 640 wide by 320 high.
            grd = PIL.Image.new('RGB', (640, 320))
        grd = self.grdimage_transform(grd)

        # generate a random rotation (as a fraction of a full turn)
        if self.known_ori:
            rotation = 0
        elif self.predefined_random_rot is None:
            rotation = np.random.uniform(low=0.0, high=1.0)
        else:
            rotation = self.predefined_random_rot[idx]
        shift = torch.round(torch.as_tensor(rotation) * grd.size()[2]).int().item()
        grd = torch.roll(grd, shift, dims=2)
        orientation_angle = rotation * 360  # 0 means heading North, counter-clockwise increasing
        if orientation_angle < 0:
            orientation_angle += 360

        # satellite
        pos_index = 0
        with PIL.Image.open(self.sat_list[self.label[idx][pos_index]]) as sat_file:
            sat = sat_file.convert('RGB')
        delta_source = self.gt_delta if self.load_gt else self.pred_delta
        row_offset, col_offset = delta_source[idx, pos_index]  # delta = [delta_lat, delta_lon]
        width_raw, height_raw = sat.size
        sat = self.satimage_transform(sat)
        _, height, width = sat.size()
        # Rescale the offsets from raw-image pixels to transformed-image pixels.
        row_offset = np.round(row_offset / height_raw * height)
        col_offset = np.round(col_offset / width_raw * width)

        # groundtruth location on the satellite map
        # Gaussian GT
        x, y = np.meshgrid(
            np.linspace(-width / 2 + col_offset, width / 2 + col_offset, width),
            np.linspace(-height / 2 - row_offset, height / 2 - row_offset, height),
        )
        d = np.sqrt(x * x + y * y)
        sigma, mu = 4, 0.0
        gaussian = np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))  # computed once, reused below

        gt = np.zeros([1, height, width], dtype=np.float32)
        gt[0, :, :] = gaussian
        gt = torch.tensor(gt)

        # Spread the Gaussian over the two orientation bins (18 degrees each)
        # adjacent to the angle; bins are indexed in reverse and wrap around.
        num_bins, bin_width = 20, 18
        index = int(orientation_angle // bin_width)
        ratio = (orientation_angle % bin_width) / bin_width
        upper_bin = (num_bins - index) % num_bins
        lower_bin = (upper_bin - 1) % num_bins
        gt_with_ori = np.zeros([num_bins, height, width], dtype=np.float32)
        gt_with_ori[upper_bin, :, :] = gaussian * (1 - ratio)
        gt_with_ori[lower_bin, :, :] = gaussian * ratio
        gt_with_ori = torch.tensor(gt_with_ori)

        orientation = torch.full([2, height, width], np.cos(orientation_angle * np.pi / 180))
        orientation[1, :, :] = np.sin(orientation_angle * np.pi / 180)

        # Use the city recorded at load time; substring matching on the path
        # breaks when the dataset root itself contains a city name.
        city = self.grd_city_list[idx]
        return grd, sat, gt, gt_with_ori, orientation, city