spatial-temporal-transformer
traffic flow prediction
Overview
The PyTorch implementation of the paper: Spatial-temporal Transformer Network with Self-supervised Learning for Traffic Flow Prediction.
Quick start
After downloading the repo and the necessary dataset, you first need to generate the training data and then launch training.
- clone this repo
git clone https://github.com/pengzhangzhi/spatial-temporal-transformer
- install packages
pip install -r requirements.txt
- Download the TaxiBJ dataset and put it in the path
spatial-temporal-transformer/data/TaxiBJ/. You only need to download BJ16_M32x32_T30_InOut.h5; the rest of the raw files are already in the folder spatial-temporal-transformer/data/TaxiBJ/.
- The TaxiNYC dataset is already in the repo; you do not need to download it.
- generate training data
  - generate TaxiNYC training data:
    python prepareDataNY.py -c TaxiNYC.json
  - generate TaxiBJ training data:
    python prepareDataNY.py -c TaxiBJ.json
  - NOTE: make sure the config file TaxiBJ.json is in the config folder.
- launch training:
python train.py -c TaxiBJ.json
This trains the model with the hyper-parameters in the TaxiBJ.json file.
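The two prerequisites above (the TaxiBJ file under data/TaxiBJ/ and the config file in the config folder) are easy to get wrong, so a small pre-flight check before running the data preparation and training commands may save a failed run. The following is only a sketch and is not part of the repo: the function name preflight is hypothetical, and it assumes that both data/ and config/ sit directly under the repository root and that the config file is plain JSON.

```python
import json
from pathlib import Path


def preflight(repo_root, config_name="TaxiBJ.json",
              data_file="BJ16_M32x32_T30_InOut.h5"):
    """Return a list of problems found; an empty list means the layout looks ready.

    Assumption (not confirmed by the repo): data/ and config/ are
    located directly under repo_root.
    """
    root = Path(repo_root)
    problems = []

    # The README asks for the TaxiBJ file to be placed in data/TaxiBJ/.
    data_path = root / "data" / "TaxiBJ" / data_file
    if not data_path.is_file():
        problems.append(f"missing dataset file: {data_path}")

    # The README asks for the config file to be in the config folder.
    config_path = root / "config" / config_name
    if not config_path.is_file():
        problems.append(f"missing config file: {config_path}")
    else:
        try:
            json.loads(config_path.read_text(encoding="utf-8"))
        except json.JSONDecodeError as exc:
            problems.append(f"config is not valid JSON: {config_path} ({exc})")

    return problems
```

Calling preflight(".") from the repository root returns a list of messages; if the list is empty, the dataset and config are where the steps above expect them. The check only looks at file locations and JSON syntax, not at the contents of the dataset or the hyper-parameters.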
Citation
If you use this code, please cite my paper.