-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HighD dataset #37
base: master
Are you sure you want to change the base?
HighD dataset #37
Changes from all commits
cc3c36a
1175009
70b49f9
d5cb7d3
3778999
cfc0936
e6cbd12
40198d9
655248c
a0074ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,106 @@ | ||
import argparse | ||
import os | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-map', type=str, default='i80', choices={'ai', 'i80', 'us101', 'lanker', 'peach'}) | ||
parser.add_argument('-map', type=str, default='i80', choices={'ai', 'i80', 'us101', 'lanker', 'peach', 'highD'}) | ||
opt = parser.parse_args() | ||
|
||
path = './traffic-data/xy-trajectories/{}/'.format(opt.map) | ||
trajectories_path = './traffic-data/state-action-cost/data_{}_v0'.format(opt.map) | ||
time_slots = [d[0].split("/")[-1] for d in os.walk(trajectories_path) if d[0] != trajectories_path] | ||
path = f'./traffic-data/xy-trajectories/{opt.map}/' | ||
trajectories_path = f'./traffic-data/state-action-cost/data_{opt.map}_v0' | ||
_, time_slots, _ = next(os.walk(trajectories_path)) | ||
|
||
df = dict() | ||
for ts in time_slots: | ||
df[ts] = pd.read_table(path + ts + '.txt', sep='\s+', header=None, names=( | ||
'Vehicle ID', | ||
'Frame ID', | ||
'Total Frames', | ||
'Global Time', | ||
'Local X', | ||
'Local Y', | ||
'Global X', | ||
'Global Y', | ||
'Vehicle Length', | ||
'Vehicle Width', | ||
'Vehicle Class', | ||
'Vehicle Velocity', | ||
'Vehicle Acceleration', | ||
'Lane Identification', | ||
'Preceding Vehicle', | ||
'Following Vehicle', | ||
'Spacing', | ||
'Headway' | ||
)) | ||
if opt.map == 'highD': | ||
# Track dataframes | ||
dtypes_dict = { | ||
'Frame ID': np.int64, | ||
'Vehicle ID': np.int64, | ||
'Local X': np.float64, | ||
'Local Y': np.float64, | ||
'Vehicle Length': np.float64, | ||
'Vehicle Width': np.float64, | ||
'Vehicle Velocity X': np.float64, | ||
'Vehicle Velocity Y': np.float64, | ||
'Vehicle Acceleration X': np.float64, | ||
'Vehicle Acceleration Y': np.float64, | ||
'Front Sight Distance': np.float64, | ||
'Back Sight Distance': np.float64, | ||
'Spacing': np.float64, | ||
'Headway': np.float64, | ||
'Time to Collision': np.float64, | ||
'Preceding Velocity X': np.float64, | ||
'Preceding Vehicle': np.int64, | ||
'Following Vehicle': np.int64, | ||
'Left Preceding ID': np.int64, | ||
'Left Alongside ID': np.int64, | ||
'Left Following ID': np.int64, | ||
'Right Preceding ID': np.int64, | ||
'Right Alongside ID': np.int64, | ||
'Right Following ID': np.int64, | ||
'Lane Identification': np.int64, | ||
} | ||
for ts in time_slots: | ||
df[ts] = pd.read_csv(os.path.join(path, f'{ts}_tracks.csv'), | ||
header=0, | ||
names=( | ||
'Frame ID', | ||
'Vehicle ID', | ||
'Local X', | ||
'Local Y', | ||
'Vehicle Length', | ||
'Vehicle Width', | ||
'Vehicle Velocity', | ||
'Vehicle Velocity Y', | ||
'Vehicle Acceleration', | ||
'Vehicle Acceleration Y', | ||
'Front Sight Distance', | ||
'Back Sight Distance', | ||
'Spacing', | ||
'Headway', | ||
'Time to Collision', | ||
'Preceding Velocity X', | ||
'Preceding Vehicle', | ||
'Following Vehicle', | ||
'Left Preceding ID', | ||
'Left Alongside ID', | ||
'Left Following ID', | ||
'Right Preceding ID', | ||
'Right Alongside ID', | ||
'Right Following ID', | ||
'Lane Identification' | ||
), | ||
dtype=dtypes_dict) | ||
else: | ||
for ts in time_slots: | ||
df[ts] = pd.read_table(path + ts + '.txt', sep='\s+', header=None, names=( | ||
'Vehicle ID', | ||
'Frame ID', | ||
'Total Frames', | ||
'Global Time', | ||
'Local X', | ||
'Local Y', | ||
'Global X', | ||
'Global Y', | ||
'Vehicle Length', | ||
'Vehicle Width', | ||
'Vehicle Class', | ||
'Vehicle Velocity', | ||
'Vehicle Acceleration', | ||
'Lane Identification', | ||
'Preceding Vehicle', | ||
'Following Vehicle', | ||
'Spacing', | ||
'Headway' | ||
)) | ||
|
||
car_sizes = dict() | ||
for ts in time_slots: | ||
d = df[ts] | ||
car = lambda i: d[d['Vehicle ID'] == i] | ||
def car(i): return d[d['Vehicle ID'] == i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LOL There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PEP8's fault not mine... |
||
car_sizes[ts] = dict() | ||
cars = set(d['Vehicle ID']) | ||
for c in cars: | ||
|
@@ -47,4 +109,5 @@ | |
car_sizes[ts][c] = size | ||
print(c) | ||
|
||
torch.save(car_sizes, 'traffic-data/state-action-cost/data_{}_v0/car_sizes.pth'.format(opt.map)) | ||
torch.save(car_sizes, f'traffic-data/state-action-cost/data_{opt.map}_v0/car_sizes.pth') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
import argparse, pdb | ||
import gym | ||
import numpy as np | ||
import argparse | ||
import os | ||
import pickle | ||
import random | ||
|
||
import gym | ||
import numpy as np | ||
import torch | ||
import scipy.misc | ||
from gym.envs.registration import register | ||
|
||
parser = argparse.ArgumentParser() | ||
|
@@ -19,8 +18,10 @@ | |
parser.add_argument('-data_dir', type=str, default='traffic-data/state-action-cost/') | ||
parser.add_argument('-fps', type=int, default=30) | ||
parser.add_argument('-time_slot', type=int, default=0) | ||
parser.add_argument('-map', type=str, default='i80', choices={'ai', 'i80', 'us101', 'lanker', 'peach'}) | ||
parser.add_argument('-map', type=str, default='i80', choices={'ai', 'i80', 'us101', 'lanker', 'peach', 'highD'}) | ||
parser.add_argument('-delta_t', type=float, default=0.1) | ||
parser.add_argument('-recording', type=str, default="01", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. I think this is an artifact of me trying to keep this new script as similar to map_i80 as possible There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use |
||
help='Use this argument with highD maps to choose from recordings \'01\' to \'60\'') | ||
opt = parser.parse_args() | ||
|
||
opt.state_image = (opt.state_image == 1) | ||
|
@@ -43,6 +44,10 @@ | |
delta_t=opt.delta_t, | ||
) | ||
|
||
# HighD dataset will use recordings IDs rather than time slots | ||
if opt.map == 'highD': | ||
kwargs['rec'] = opt.recording | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And this would be unnecessary. |
||
|
||
register( | ||
id='Traffic-v0', | ||
entry_point='traffic_gym:Simulator', | ||
|
@@ -55,30 +60,37 @@ | |
kwargs=kwargs | ||
) | ||
|
||
gym.envs.registration.register( | ||
register( | ||
id='US-101-v0', | ||
entry_point='map_us101:US101', | ||
kwargs=kwargs, | ||
) | ||
|
||
gym.envs.registration.register( | ||
register( | ||
id='Lankershim-v0', | ||
entry_point='map_lanker:Lankershim', | ||
kwargs=kwargs, | ||
) | ||
|
||
gym.envs.registration.register( | ||
register( | ||
id='Peachtree-v0', | ||
entry_point='map_peach:Peachtree', | ||
kwargs=kwargs, | ||
) | ||
|
||
register( | ||
id='HighD-v0', | ||
entry_point='map_highD:HighD', | ||
kwargs=kwargs | ||
) | ||
|
||
env_names = { | ||
'ai': 'Traffic-v0', | ||
'i80': 'I-80-v0', | ||
'us101': 'US-101-v0', | ||
'lanker': 'Lankershim-v0', | ||
'peach': 'Peachtree-v0', | ||
'highD': 'HighD-v0', | ||
} | ||
|
||
print('Building the environment (loading data, if any)') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, I think you could have used an ordered dict, and then you could have simply dumped the keys 😜
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Do you want me to change this, or should we leave it for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want to change it, I can wait before merging.