[RA-L + ICRA22] Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction
GST
This is the implementation of the paper
Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction
Zhe Huang, Ruohua Li, Kazuki Shin, Katherine Driggs-Campbell
published in RA-L.
GST stands for Gumbel Social Transformer, the name of our model. All code was developed and tested on Ubuntu 18.04 with CUDA 10.2, Python 3.6.9, and PyTorch 1.7.1.
Citation
If you find this repo useful, please cite
@article{huang2022learning,
title={Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction},
author={Huang, Zhe and Li, Ruohua and Shin, Kazuki and Driggs-Campbell, Katherine},
journal={IEEE Robotics and Automation Letters},
year={2022},
volume={7},
number={2},
pages={1198-1205},
doi={10.1109/LRA.2021.3138547}
}
Setup
1. Create a Virtual Environment (Optional)
virtualenv -p /usr/bin/python3 myenv
source myenv/bin/activate
2. Install Packages
You can run either
pip install -r requirements.txt
or
pip install numpy
pip install scipy
pip install matplotlib
pip install tensorboardX
pip install torch==1.7.1
To check training curves with tensorboard --logdir results, install tensorflow by running
pip install tensorflow
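Because the code was tested with Python 3.6.9 and PyTorch 1.7.1, a quick version check can save debugging time after installation. The sketch below is only an illustration, and its helper names (base_version, installed_versions, mismatches) are hypothetical, not part of the repo. It strips the local build suffix (for example +cu110) that some PyTorch wheels append to torch.__version__.

```python
import importlib
import sys

# Versions this README reports as tested.
TESTED = {"python": "3.6.9", "torch": "1.7.1"}


def base_version(version):
    """Drop a local build suffix, e.g. '1.7.1+cu110' -> '1.7.1'."""
    return version.split("+", 1)[0]


def installed_versions():
    """Collect the running Python version and, if importable, torch's version."""
    found = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    try:
        torch = importlib.import_module("torch")
        found["torch"] = base_version(torch.__version__)
    except ImportError:
        found["torch"] = None  # torch is not installed in this environment
    return found


def mismatches(found, tested=TESTED):
    """Return {name: (found, tested)} for every version that differs."""
    return {name: (found.get(name), want)
            for name, want in tested.items()
            if found.get(name) != want}


if __name__ == "__main__":
    for name, (have, want) in mismatches(installed_versions()).items():
        print("{}: found {}, tested with {}".format(name, have, want))
```

A mismatch does not necessarily mean the code will fail. It only flags a difference from the tested setup.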
3. Create Folders and Dataset Files.
sh run/make_dirs.sh
sh run/create_datasets.sh
Training and Evaluation on Various Configurations
To train and evaluate a model with n=1, i.e., the target pedestrian pays attention to at most one partially observed pedestrian, run
sh run/train_sparse.sh
sh run/eval_sparse.sh
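One way to picture "at most n" neighbors, consistent with the "Gumbel" in the model's name, is to let each of n edge heads draw a single candidate with the Gumbel-max trick. The target then attends to the union of the draws, which has at most n distinct members. This is a hedged, framework-free sketch under that assumption, not the repo's PyTorch implementation. Training through a discrete choice like this typically requires a differentiable relaxation such as Gumbel-Softmax, which is not shown here. Index 0 stands in for the optional ghost pedestrian (see --ghost below), and a head that picks it selects no real neighbor. The function names and logits are hypothetical.

```python
import math
import random


def gumbel_noise(rng):
    """Sample standard Gumbel noise: -log(-log(U)), U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0; avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max(logits, rng):
    """Return argmax(logits + Gumbel noise), an exact sample from softmax(logits)."""
    scores = [x + gumbel_noise(rng) for x in logits]
    return max(range(len(scores)), key=scores.__getitem__)


def sample_sparse_neighbors(head_logits, rng, ghost_index=0):
    """Each edge head picks one candidate; return the set of real neighbors.

    head_logits holds one list of logits per edge head, indexed over candidates.
    A head that picks ghost_index selects nobody, and two heads may pick the
    same pedestrian, so the result has at most len(head_logits) members.
    """
    picks = {gumbel_max(logits, rng) for logits in head_logits}
    picks.discard(ghost_index)
    return picks


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical logits over [ghost, ped1, ped2, ped3] for n = 2 edge heads.
    heads = [[0.1, 2.0, 0.5, -1.0], [1.5, 0.2, 0.3, 0.0]]
    print(sample_sparse_neighbors(heads, rng))
```

With n=1, as in the configuration above, there is a single head, so the target attends to one partially observed pedestrian or, if the ghost is drawn, to none.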
To train and evaluate a model with n=1 and a temporal convolution network as the temporal component, run
sh run/train_sparse_tcn.sh
sh run/eval_sparse_tcn.sh
To train and evaluate a model with full connection, i.e., the target pedestrian pays attention to all partially observed pedestrians in the scene, run
sh run/train_full_connection.sh
sh run/eval_full_connection.sh
To train and evaluate a model in which the target pedestrian pays attention to all fully observed pedestrians in the scene, run
sh run/train_full_connection_fully_observed.sh
sh run/eval_full_connection_fully_observed.sh
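These configurations differ in which neighbors the target pedestrian may attend to. The sketch below is a rough illustration that assumes a hypothetical data layout: each pedestrian's detections over the observation window are stored as a boolean presence list. This is not necessarily how the repo stores them. Under that assumption, "fully observed" means detected at every observed time step, and the candidate set can be built as follows. The function names are illustrative.

```python
from typing import Dict, List, Set


def fully_observed(presence: Dict[str, List[bool]]) -> Set[str]:
    """Pedestrians detected at every time step of the observation window."""
    return {ped for ped, seen in presence.items() if seen and all(seen)}


def candidate_neighbors(presence: Dict[str, List[bool]], target: str,
                        only_full_period: bool = False) -> Set[str]:
    """Candidates the target may attend to.

    only_full_period=False: every other pedestrian detected at least once,
    so partially observed pedestrians are included.
    only_full_period=True: only other fully observed pedestrians.
    """
    if only_full_period:
        pool = fully_observed(presence)
    else:
        pool = {ped for ped, seen in presence.items() if any(seen)}
    pool.discard(target)  # the target is not its own neighbor
    return pool
```

For example, a pedestrian who first appears midway through the observation window is a candidate in the partially observed configurations but not in the fully observed one.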
Important Arguments for Building Customized Configurations
- --spatial_num_heads_edges: n, i.e., the upper bound on the number of pedestrians that the target pedestrian can pay attention to in the scene. When n=0, the model uses full connection, i.e., the target pedestrian pays attention to all pedestrians in the scene. Default is 4.
- --only_observe_full_period: If True, the target pedestrian only pays attention to fully observed pedestrians. Default is False.
- --temporal: Temporal component. lstm is a Masked LSTM, and temporal_convolution_net is a temporal convolution network. Default is lstm.
- --decode_style: Decoding style, which has to match --temporal: recursive matches lstm, and readout matches temporal_convolution_net. Default is recursive.
- --ghost: Adds a ghost pedestrian to the scene to encourage sparsity. When --spatial_num_heads_edges is set to 0, i.e., the target pedestrian pays attention to all pedestrians in the scene, --ghost has to be False. Default is False.
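The constraints above can be checked before a run starts. The sketch below is hypothetical: it mirrors the flag names, choices, and defaults described in this section using argparse, but the repo's own parser may define them differently (for example, in how boolean flags are parsed). Treat it as an illustration, not a drop-in replacement. The helpers str2bool, build_parser, and check_args are illustrative names.

```python
import argparse

# recursive decoding pairs with lstm; readout pairs with temporal_convolution_net
DECODE_FOR_TEMPORAL = {"lstm": "recursive", "temporal_convolution_net": "readout"}


def str2bool(value):
    """Parse 'True'/'False'-style strings (a common argparse idiom)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    """Hypothetical parser with the options and defaults listed in this README."""
    parser = argparse.ArgumentParser(description="Hypothetical GST options")
    parser.add_argument("--spatial_num_heads_edges", type=int, default=4,
                        help="n: max neighbors per target; 0 means full connection")
    parser.add_argument("--only_observe_full_period", type=str2bool, default=False)
    parser.add_argument("--temporal", choices=sorted(DECODE_FOR_TEMPORAL),
                        default="lstm")
    parser.add_argument("--decode_style", choices=["recursive", "readout"],
                        default="recursive")
    parser.add_argument("--ghost", type=str2bool, default=False)
    return parser


def check_args(args):
    """Return a list of messages describing invalid flag combinations."""
    errors = []
    expected = DECODE_FOR_TEMPORAL[args.temporal]
    if args.decode_style != expected:
        errors.append("--temporal %s requires --decode_style %s"
                      % (args.temporal, expected))
    if args.spatial_num_heads_edges == 0 and args.ghost:
        errors.append("--ghost must be False when --spatial_num_heads_edges is 0")
    return errors
```

In a training script, one might call check_args(build_parser().parse_args()) and exit with the returned messages if the list is non-empty.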
Credits
Part of the code is based on the following works and repos:
[1] Mohamed, Abduallah, et al. "Social-stgcnn: A social spatio-temporal graph convolutional neural network for human trajectory prediction." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020. [GitHub]
[2] PyTorch implementation of Multi-head Attention. [Modules] [Functional]
Contact
Please feel free to open an issue or send an email to [email protected].