A Unified Framework for Cross-Dataset Generalization of Vehicle Trajectory Prediction
UniTraj is a framework for vehicle trajectory prediction, designed by researchers from the VITA lab at EPFL. It provides a unified interface for training and evaluating different models on multiple datasets, and supports easy configuration and logging. Powered by Hydra, PyTorch Lightning, and WandB, the framework makes configuration, training, and logging straightforward. In this project, you will use UniTraj to train and evaluate a model we call PTR (predictive transformer) on the data we have given to you.
First start by cloning the repository:
git clone https://github.com/vita-epfl/unitraj-DLAV.git
cd unitraj-DLAV
Then make a virtual environment and install the required packages.
python3 -m venv venv
source venv/bin/activate
# Install MetaDrive Simulator
cd ~/ # Go to the folder you want to host these two repos.
git clone https://github.com/metadriverse/metadrive.git
cd metadrive
pip install -e .
# Install ScenarioNet
cd ~/ # Go to the folder you want to host these two repos.
git clone https://github.com/metadriverse/scenarionet.git
cd scenarionet
pip install -e .
Finally, install UniTraj and log in to WandB via:
cd unitraj-DLAV # Go to the folder you cloned the repo
pip install -r requirements.txt
pip install -e .
wandb login
If you don't have a WandB account, you can create one here. It is a free service for open-source projects, and you can use it to log your experiments and easily compare different models.
You can verify the installation of UniTraj by running the training script:
python train.py method=ptr
The incomplete PTR model will be trained on several data samples available in motionnet/data_samples.
There are three main components in UniTraj: dataset, model and config. The structure of the code is as follows:
motionnet
├── configs
│ ├── config.yaml
│ ├── method
│ │ ├── ptr.yaml
├── datasets
│ ├── base_dataset.py
│ ├── ptr_dataset.py
├── models
│ ├── ptr
│ ├── base_model
├── utils
There is a base config, dataset and model class, and each model has its own config, dataset and model class that inherit from the base class.
You can access the data here. For easier use on SCITAS, we have also provided the dataset on SCITAS at /work/vita/datasets/DLAV_unitraj. We provide a training set and a validation set, as well as three test sets of different difficulty levels: easy, medium, and hard.
You will be evaluating your model on the easy test set for the first milestone, and the medium and hard test sets for the second and third milestones, respectively.
Don't forget to put the path to the real data in the config file.
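For example, the data-path entries could look like the sketch below. This is an assumption, not the actual file contents: check motionnet/configs/config.yaml for the exact key names and value format, and note that the train/val subdirectory names here are placeholders.

```yaml
# Sketch only: verify key names and path format against configs/config.yaml.
train_data_path: ['/work/vita/datasets/DLAV_unitraj/train']  # subdirectory name is a placeholder
val_data_path: ['/work/vita/datasets/DLAV_unitraj/val']      # subdirectory name is a placeholder
```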
Your task is to complete the PTR model and train it on the data we have provided.
The model is a transformer-based model that takes the past trajectory of the vehicle and its surrounding agents, along with the map, and predicts the future trajectory.
This is the architecture of the encoder part of the model (the part you need to implement). Suppose we are given the past T time steps for M agents, each represented by a feature vector. The encoder proceeds as follows:
- Add positional encoding to the input features along the time-step dimension, to distinguish between different time steps.
- Apply temporal attention to capture the dependencies within each agent's trajectory, treating the agents separately.
- Apply spatial attention to capture the dependencies between the different agents at the same time step.

The two attention steps are repeated L times to capture the dependencies between the agents and across the time steps (see the sketch below).
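To make the tensor bookkeeping concrete, here is a minimal PyTorch sketch of one such encoder block. This is not the project's code: the [B, T, M, D] tensor layout, the sizes, the single shared nn.TransformerEncoderLayer, and the absence of padding masks are all simplifying assumptions.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_pe(T, D):
    # Standard sinusoidal positional encoding over T time steps.
    pos = torch.arange(T, dtype=torch.float32).unsqueeze(1)   # [T, 1]
    div = torch.exp(torch.arange(0, D, 2, dtype=torch.float32) * (-math.log(10000.0) / D))
    pe = torch.zeros(T, D)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe                                                 # [T, D]

B, T, M, D = 4, 21, 8, 128    # hypothetical batch / time / agent / feature sizes
x = torch.randn(B, T, M, D)   # assumed layout: [batch, time steps, agents, features]

# 1) Positional encoding, added along the time-step dimension only.
x = x + sinusoidal_pe(T, D).view(1, T, 1, D)

attn = nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True)

# 2) Temporal attention: fold the agent axis into the batch so each
#    (sample, agent) pair becomes an independent sequence of T steps.
h = x.permute(0, 2, 1, 3).reshape(B * M, T, D)
h = attn(h)
x = h.reshape(B, M, T, D).permute(0, 2, 1, 3)

# 3) Spatial attention: fold the time axis into the batch so the M agents
#    at each time step attend to one another.
h = x.reshape(B * T, M, D)
h = attn(h)
x = h.reshape(B, T, M, D)
```

In the real model, the temporal and spatial steps use their own layers, the pair is stacked L times, and padded agents and time steps have to be masked out; the hints in ptr_model.py describe the exact interfaces.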
The model is implemented in motionnet/models/ptr/ptr_model.py and its config is in motionnet/configs/method/ptr.yaml.
Take a look at the model and the config to understand the structure of the model and the hyperparameters.
You are asked to complete three parts of the model in motionnet/models/ptr/ptr_model.py:
- The temporal_attn_fn function, which computes the attention across the time steps of each agent's trajectory.
- The spatial_attn_fn function, which computes the attention between the different agents at the same time step.
- The encoder part of the model in the _forward function.
You can find the instructions and some hints in the file itself.
You can follow the steps in the easy Kaggle competition to submit your results and compare them with those of other students on the leaderboard. Here are the medium and hard competitions for the second and third milestones, respectively. We have developed a submission script for your convenience; run the following command to generate the submission file:
python generate_predictions.py method=ptr
Before running the above command, however, you need to put the path to the checkpoint of your trained model in the config file under ckpt_path. You can find the checkpoints of your trained model in the lightning_logs directory at the root of the project. For example, if you have trained your model for 10 epochs, you will find the checkpoint at lightning_logs/version_0/checkpoints/epoch=10-val/brier_fde=30.93.ckpt; put the path to this file in the config file.
Additionally, set val_data_path in the config file to the path of the test data you want to evaluate your model on: the easy test data for the first milestone, and the medium and hard test data for the second and third milestones, respectively.
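Putting the two settings together, a sketch of the relevant config entries might look as follows; the checkpoint path reuses the example above, and the test-set directory name is a placeholder you must replace with the actual one.

```yaml
# Sketch only: replace both paths with your own.
ckpt_path: 'lightning_logs/version_0/checkpoints/epoch=10-val/brier_fde=30.93.ckpt'
val_data_path: ['/work/vita/datasets/DLAV_unitraj/test_easy']  # directory name is a placeholder
```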
The script will generate a file called submission.csv in the root directory of the project. You can submit this file to the Kaggle competition. As this file can be quite large, we suggest compressing it before submitting.
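For example, with gzip on a Unix-like system (any common archive format that Kaggle accepts should work):
gzip -k submission.csv  # writes submission.csv.gz and keeps the original file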
In addition, the script will save some visualizations of the predictions in the visualizations directory at the root of the project. Take a look at them to understand how your model is performing. These visualizations are also needed for the report of the first milestone, so don't forget to include some of them in your report.