-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathdataset.py
More file actions
346 lines (285 loc) · 13.8 KB
/
dataset.py
File metadata and controls
346 lines (285 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import gzip
import logging
import os
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import yaml
import torch
from tqdm import tqdm
from navsim.common.dataclasses import AgentInput
from navsim.common.dataloader import SceneLoader
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
logger = logging.getLogger(__name__)
def load_feature_target_from_pickle(path: Path) -> Dict[str, torch.Tensor]:
    """
    Read a gzip-compressed pickled feature/target dictionary from disk.

    NOTE: unpickling can execute arbitrary code, so only load cache files produced by a trusted run.
    :param path: location of the ``.gz`` cache file
    :return: the deserialized dictionary
    """
    with gzip.open(path, "rb") as stream:
        return pickle.load(stream)
def dump_feature_target_to_pickle(path: Path, data_dict: Dict[str, torch.Tensor]) -> None:
    """
    Helper function to save a feature/target dictionary as a gzip-compressed pickle.

    The data is first written to a temporary sibling file and then moved onto ``path`` with
    ``os.replace``. An interrupted write therefore never leaves a truncated ``.gz`` file at ``path``.
    This matters because cache validation in this module only checks that the file exists.
    :param path: destination file path
    :param data_dict: dictionary to serialize
    """
    path = Path(path)
    # Include the pid so concurrent writers of the same path do not share a temporary file.
    tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
    try:
        # Use compresslevel = 1 to compress the size but also has fast write and read.
        with gzip.open(tmp_path, "wb", compresslevel=1) as f:
            pickle.dump(data_dict, f)
        # Atomic on POSIX and Windows when source and destination are on the same filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a partial temporary file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
class CacheOnlyDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper for feature/target datasets loaded from cache only.

    Besides the log folders in ``cache_path``, synthetic-data caches are added when the
    ``SYN_IDX`` environment variable is set (see ``_load_valid_caches``).
    """

    def __init__(
        self,
        cache_path: str,
        feature_builders: List[AbstractFeatureBuilder],
        target_builders: List[AbstractTargetBuilder],
        log_names: Optional[List[str]] = None,
        split: str = "train",
    ):
        """
        Initializes the dataset module.
        :param cache_path: directory to cache folder
        :param feature_builders: list of feature builders
        :param target_builders: list of target builders
        :param log_names: optional list of log folders (relative to cache_path) to consider;
            folders that do not exist are skipped. Defaults to None, i.e. every log folder in cache_path.
        :param split: split type, 'train' or 'val'; selects the '<split>_logs' entry of the
            synthetic-data log split yaml files
        """
        super().__init__()
        assert Path(cache_path).is_dir(), f"Cache path {cache_path} does not exist!"

        self._cache_path = Path(cache_path)
        if log_names is not None:
            self.log_names = [Path(log_name) for log_name in log_names if (self._cache_path / log_name).is_dir()]
        else:
            # Only sub-directories are log folders; a stray file here would make iterdir() fail later.
            self.log_names = [log_path for log_path in self._cache_path.iterdir() if log_path.is_dir()]

        self._feature_builders = feature_builders
        self._target_builders = target_builders
        self.split = split  # split type, 'train' or 'val'

        self._valid_cache_paths: Dict[str, Path] = self._load_valid_caches(
            cache_path=self._cache_path,
            feature_builders=self._feature_builders,
            target_builders=self._target_builders,
            log_names=self.log_names,
        )
        self.tokens = list(self._valid_cache_paths.keys())

    def __len__(self) -> int:
        """
        :return: number of samples to load
        """
        return len(self.tokens)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], str]:
        """
        Loads and returns the feature dict, target dict and token of one sample.
        :param idx: index of sample to load.
        :return: tuple of feature dictionary, target dictionary and sample token
        """
        return self._load_scene_with_token(self.tokens[idx])

    @staticmethod
    def _find_valid_tokens(log_path: Path, builders: List) -> Dict[str, Path]:
        """
        Collects the token folders of one log folder that hold a cache file for every builder.
        :param log_path: log folder whose entries are token folders
        :param builders: feature and target builders whose '<unique_name>.gz' files must all exist
        :return: dictionary of tokens and token folder paths
        """
        return {
            token_path.name: token_path
            for token_path in log_path.iterdir()
            if all((token_path / (builder.get_unique_name() + ".gz")).is_file() for builder in builders)
        }

    def _load_valid_caches(
        self,
        cache_path: Path,
        feature_builders: List[AbstractFeatureBuilder],
        target_builders: List[AbstractTargetBuilder],
        log_names: List[Path],
    ) -> Dict[str, Path]:
        """
        Helper method to load valid cache paths.

        If the environment variable SYN_IDX is set, synthetic-data caches for versions 0..SYN_IDX are
        added too. For each version idx:
        - the log names are read from the '<split>_logs' entry of
          $NAVSIM_DEVKIT_ROOT/navsim/planning/script/config/training/default_log_split_synthetic_reation_<SYN_GT>_v1.0-<idx>.yaml
        - the caches are read from $NAVSIM_EXP_ROOT/cache/navtrain_reaction_<SYN_GT>_v1.0-<idx>_cache
        Existing synthetic log folders are appended to self.log_names as absolute paths; missing ones are skipped
        with a warning.
        :param cache_path: directory of training cache folder
        :param feature_builders: list of feature builders
        :param target_builders: list of target builders
        :param log_names: list of log paths to load
        :return: dictionary of tokens and sample paths as keys / values; for duplicate tokens the later entry wins
        :raises KeyError: if SYN_IDX is set but NAVSIM_DEVKIT_ROOT or NAVSIM_EXP_ROOT is not
        """
        builders = feature_builders + target_builders
        valid_cache_paths: Dict[str, Path] = {}
        for log_name in tqdm(log_names, desc="Loading Valid Caches"):
            valid_cache_paths.update(self._find_valid_tokens(cache_path / log_name, builders))

        # for syn data
        syn_idx = os.environ.get("SYN_IDX", None)
        if syn_idx is None:
            logger.info("SYN_IDX is not set, using default log data.")
            return valid_cache_paths

        syn_gt = os.environ.get("SYN_GT", None)
        # Resolved only once synthetic data is requested, so these variables are not required otherwise;
        # os.environ[...] raises a KeyError naming the missing variable instead of Path(None)'s TypeError.
        yaml_base = Path(os.environ["NAVSIM_DEVKIT_ROOT"]) / "navsim/planning/script/config/training"
        cache_base = Path(os.environ["NAVSIM_EXP_ROOT"]) / "cache"

        for idx in range(int(syn_idx) + 1):
            version = f"{syn_gt}_v1.0-{idx}"

            # 1) read the synthetic log names of this split
            yaml_file = yaml_base / f"default_log_split_synthetic_reation_{version}.yaml"
            with open(yaml_file, "r") as f:
                syn_log_names = yaml.safe_load(f)[f"{self.split}_logs"]

            # 2) add the valid token caches of every synthetic log
            syn_cache_path = cache_base / f"navtrain_reaction_{version}_cache"
            for log_name in tqdm(syn_log_names, desc=f"Split {self.split} Loading Syn Caches for {version}"):
                log_path = syn_cache_path / log_name
                if not log_path.is_dir():
                    logger.warning("Skipping missing synthetic cache log folder %s", log_path)
                    continue
                # one Path per log keeps self.log_names a flat list of paths
                self.log_names.append(log_path)
                valid_cache_paths.update(self._find_valid_tokens(log_path, builders))
        return valid_cache_paths

    @staticmethod
    def _load_builder_outputs(token_path: Path, builders: List) -> Dict[str, torch.Tensor]:
        """
        Loads and merges the cached '<unique_name>.gz' dictionaries of the given builders.
        :param token_path: token folder containing the cache files
        :param builders: builders whose cache files are loaded, in order (later keys overwrite earlier ones)
        :return: merged dictionary
        """
        merged: Dict[str, torch.Tensor] = {}
        for builder in builders:
            merged.update(load_feature_target_from_pickle(token_path / (builder.get_unique_name() + ".gz")))
        return merged

    def _load_scene_with_token(self, token: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], str]:
        """
        Helper method to load sample tensors given token
        :param token: unique string identifier of sample
        :return: tuple of feature dictionary, target dictionary and the token itself
        """
        token_path = self._valid_cache_paths[token]
        features = self._load_builder_outputs(token_path, self._feature_builders)
        targets = self._load_builder_outputs(token_path, self._target_builders)
        return (features, targets, token)
class Dataset(torch.utils.data.Dataset):
    """Dataset returning features/targets either from a cache folder or computed on-the-fly from scenes."""

    def __init__(
        self,
        scene_loader: SceneLoader,
        feature_builders: List[AbstractFeatureBuilder],
        target_builders: List[AbstractTargetBuilder],
        cache_path: Optional[str] = None,
        force_cache_computation: bool = False,
        append_token_to_batch: bool = False,
        agent_input_only: bool = False,
        is_training: bool = True
    ):
        """
        Initializes the dataset module.
        :param scene_loader: loader providing the scene tokens and scenes
        :param feature_builders: list of feature builders
        :param target_builders: list of target builders
        :param cache_path: optional cache folder; when given, uncached tokens are cached during construction
            and samples are afterwards loaded from the cache
        :param force_cache_computation: if True, recompute the cache for all tokens instead of only missing ones
        :param append_token_to_batch: stored on the instance; not read within this class
        :param agent_input_only: if True, __getitem__ only computes features from the agent input and
            returns a dummy target dict
        :param is_training: if False, targets are not computed in the on-the-fly (no cache) path
        """
        super().__init__()
        self.is_training = is_training
        self.append_token_to_batch = append_token_to_batch
        self.agent_input_only = agent_input_only

        self._scene_loader = scene_loader
        self._feature_builders = feature_builders
        self._target_builders = target_builders

        self._cache_path: Optional[Path] = Path(cache_path) if cache_path else None
        self._force_cache_computation = force_cache_computation
        self._valid_cache_paths: Dict[str, Path] = self._load_valid_caches(
            self._cache_path, feature_builders, target_builders
        )

        if self._cache_path is not None:
            self.cache_dataset()

    @staticmethod
    def _load_valid_caches(
        cache_path: Optional[Path],
        feature_builders: List[AbstractFeatureBuilder],
        target_builders: List[AbstractTargetBuilder],
    ) -> Dict[str, Path]:
        """
        Helper method to load valid cache paths.
        A token folder '<cache_path>/<log>/<token>' is valid if it holds a '<unique_name>.gz' file for every builder.
        :param cache_path: directory of training cache folder; None or a missing folder yields an empty dict
        :param feature_builders: list of feature builders
        :param target_builders: list of target builders
        :return: dictionary of tokens and sample paths as keys / values
        """
        valid_cache_paths: Dict[str, Path] = {}
        if cache_path is None or not cache_path.is_dir():
            return valid_cache_paths

        builders = feature_builders + target_builders
        for log_path in cache_path.iterdir():
            if not log_path.is_dir():
                # stray files in the cache root are not log folders (iterdir() on them would fail)
                continue
            for token_path in log_path.iterdir():
                if all((token_path / (builder.get_unique_name() + ".gz")).is_file() for builder in builders):
                    valid_cache_paths[token_path.name] = token_path
        return valid_cache_paths

    def _cache_scene_with_token(self, token: str) -> None:
        """
        Helper function to compute feature / targets and save in cache.
        Files are written to '<cache_path>/<log_name>/<initial_token>/<unique_name>.gz'.
        :param token: unique identifier of scene to cache
        """
        scene = self._scene_loader.get_scene_from_token(token)
        agent_input = scene.get_agent_input()

        metadata = scene.scene_metadata
        token_path = self._cache_path / metadata.log_name / metadata.initial_token
        os.makedirs(token_path, exist_ok=True)

        for builder in self._feature_builders:
            data_dict_path = token_path / (builder.get_unique_name() + ".gz")
            dump_feature_target_to_pickle(data_dict_path, builder.compute_features(agent_input))

        for builder in self._target_builders:
            data_dict_path = token_path / (builder.get_unique_name() + ".gz")
            dump_feature_target_to_pickle(data_dict_path, builder.compute_targets(scene))

        self._valid_cache_paths[token] = token_path

    @staticmethod
    def _load_builder_outputs(token_path: Path, builders: List) -> Dict[str, torch.Tensor]:
        """
        Loads and merges the cached '<unique_name>.gz' dictionaries of the given builders.
        :param token_path: token folder containing the cache files
        :param builders: builders whose cache files are loaded, in order (later keys overwrite earlier ones)
        :return: merged dictionary
        """
        merged: Dict[str, torch.Tensor] = {}
        for builder in builders:
            merged.update(load_feature_target_from_pickle(token_path / (builder.get_unique_name() + ".gz")))
        return merged

    def _load_scene_with_token(self, token: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Helper function to load feature / targets from cache.
        :param token: unique identifier of scene to load
        :return: tuple of feature and target dictionaries
        """
        token_path = self._valid_cache_paths[token]
        features = self._load_builder_outputs(token_path, self._feature_builders)
        targets = self._load_builder_outputs(token_path, self._target_builders)
        return (features, targets)

    def cache_dataset(self) -> None:
        """Caches complete dataset into cache folder."""
        assert self._cache_path is not None, "Dataset did not receive a cache path!"
        os.makedirs(self._cache_path, exist_ok=True)

        # determine tokens to cache
        if self._force_cache_computation:
            tokens_to_cache = list(self._scene_loader.tokens)
        else:
            tokens_to_cache = list(set(self._scene_loader.tokens) - set(self._valid_cache_paths.keys()))

        logger.info(
            f"\nStarting caching of {len(tokens_to_cache)} tokens.\n"
            "Note: Caching tokens within the training loader is slow. Only use it with a small number of tokens.\n"
            "You can cache large numbers of tokens using the `run_dataset_caching.py` python script.\n"
        )

        for token in tqdm(tokens_to_cache, desc="Caching Dataset"):
            self._cache_scene_with_token(token)

    def __len__(self) -> int:
        """
        :return: number of samples to load
        """
        return len(self._scene_loader)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], str]:
        """
        Get features or targets either from cache or computed on-the-fly.
        :param idx: index of sample to load.
        :return: tuple of feature dictionary, target dictionary and sample token
        """
        token = self._scene_loader.tokens[idx]
        features: Dict[str, torch.Tensor] = {}
        targets: Dict[str, torch.Tensor] = {}

        if self.agent_input_only:
            agent_input = AgentInput.from_scene_dict_list(
                self._scene_loader.scene_frames_dicts[token],
                self._scene_loader._sensor_blobs_path,
                num_history_frames=self._scene_loader._scene_filter.num_history_frames,
                sensor_config=self._scene_loader._sensor_config,
            )
            for builder in self._feature_builders:
                features.update(builder.compute_features(agent_input))
            # targets are not computed in this mode; a placeholder keeps the return shape uniform
            return features, {'dummy': torch.zeros(1)}, token

        if self._cache_path is not None:
            assert (
                token in self._valid_cache_paths
            ), f"The token {token} has not been cached yet, please call cache_dataset first!"
            features, targets = self._load_scene_with_token(token)
        else:
            scene = self._scene_loader.get_scene_from_token(token)
            agent_input = scene.get_agent_input()
            for builder in self._feature_builders:
                features.update(builder.compute_features(agent_input))
            if self.is_training:
                for builder in self._target_builders:
                    targets.update(builder.compute_targets(scene))

        return (features, targets, token)