UniST
UniST copied to clipboard
Official implementation for "UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction" (KDD 2024)
UniST
A pytorch implementation for the paper: UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction.
Yuan Yuan, Jingtao Ding, Jie Feng, Depeng Jin, Yong Li
The repo currently includes code implementations for the following tasks:
Short-term Prediction: We provide all scripts for the reproduction of short-term prediction results in this repo.
Long-term Prediction: We provide all scripts for the reproduction of long-term prediction results in this repo.
Few-shot Prediction: UniST can generalize well to scenarios with limited training data, making it to be data-efficient.
Zero-shot Prediction: UniST is demonstrated to generalize well on unseen spatio-temporal scenarios, making it a nice alternative as the fundamental backbone of the foundation spatio-temporal model.
π Updates
π’: News (2024.06) Introduction of our work in ιεδ½, ζΆη©Ίζ’η΄’δΉζ , ζΆεΊδΊΊ are available.
π’: News (2024.05) UniST has been accepted to KDD 2024.
Introduction
π By capturing the underlying commonalities across multiple spatio-temporal scenarios, UniST breaks the conventional practice that train separate models for different datasets, and has demonstrated superior performance and powerful generalization capability across diverse urban scenarios.
Overall Architecture
π The training of UniST consists of two stages: (i) large-scale spatio-temporal pre-training, and (ii) spatio-temporal knowledge-guided prompt tuning.
The pseudo-code of UniST is as simple as the following:
β Foundation models for spatio-temporal prediction
Model | Data Format | Data Scalability | Few-shot | Zero-shot | Computation Cost | Memory Cost |
---|---|---|---|---|---|---|
PromptST [1] | Grid | β | β | β | Low | Low |
GPT-ST [2] | Graph | β | β | β | Low | Low |
STEP [3] | Graph | β | β | β | Low | Low |
ST-SSL [4] | Graph | β | β | β | Low | Low |
TrafficBERT [5] | Grid/Graph | β | β | β | Low | Low |
TFM [6] | Graph | β | β | β | Low | Low |
UrbanGPT [7] | Grid | β(a) | β(a) | β(a) | High | High |
STG-LLM [8] | Graph | β | β | β | High | High |
UniST | Grid/Graph | β | β | β | Low | Low |
(a). Still restricted in the same city.
[1] PromptST: Prompt-Enhanced Spatio-Temporal Multi-Attribute Prediction, CIKM 2023
[2] GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks, NIPS 2023
[3] Pre-training enhanced spatial-temporal graph neural network for multivariate time series forecasting, KDD 2022
[4] Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction, AAAI 2023
[5] TrafficBERT: Pre-trained model with large-scale data for long-range traffic flow forecasting, Expert Systems with Applications
[6] Building transportation foundation model via generative graph transformer, ITSC 2023
[7] UrbanGPT: Spatio-Temporal Large Language Models, KDD 2024
[8] How can large language models understand spatial-temporal data?, arXiv 2024
Data
We use multiple datasets to demonstrate the UniST, which span various cities and domains. To access the datasets, please refer to data readme.
βοΈ Installation
Environment
- Tested OS: Linux
- Python >= 3.9
- torch == 2.0.0
- Tensorboard
Dependencies:
- Install Pytorch with the correct CUDA version.
- Use the
pip install -r requirements.txt
command to install all of the Python modules and packages used in this project.
π Model Training
Please first navigate to the src
directory by using the cd command: cd src
Then please create a folder named experiments
to record the training process: mkdir experiments
Stage-1: Pre-training
We provide the scripts under the folder ./scripts/pretrain.sh
. You can train UniST with the Cellular dataset as the following examples:
python main.py --device_id 3 --machine machine --dataset Crowd --task short --size middle --mask_strategy_random 'batch' --lr 3e-4 --used_data 'single' --prompt_ST 0
Once your model is trained, you will find the logs recording the training process in the ./logs/
directory. The folder will be named as the Pretrain_Dataset_<dataset>_task_<task>
. In the ./experiments/Pretrain_Dataset_<dataset>_task_<task>/model_save/
, you will find the trained model named model_best.pkl
.
In our experiments, we leverage multiple datasets to enhance UniST.
If you need to use multiple datasets, please use an asterisk (*) to separate the datasets, e.g., --dataset Crowd*Cellular*TaxiNYC*TaxiBike*TrafficSH
.
Stage-2: Prompt-tuning
We provide the scripts under the folder ./scripts/prompt_tuning.sh
. You can fine-tune UniST with the Cellular dataset as the following examples:
python main.py --device_id 2 --machine machine --task short --size middle --prompt_ST 1 --pred_len 6 --his_len 6 --num_memory_spatial 512 --num_memory_temporal 512 --prompt_content 's_p_c' --dataset Crowd --lr 3e-4 --used_data 'single' --file_load_path pretrained_model_path
There are some new parameters to specify:
-
his_len
specifies the input sequence length. -
pred_len
specifies the prediction horizon. -
file_load_path
specifies the save path of the pre-trained model, the default is./experiments/Dataset_<dataset>_task_<task>/model_save/model_best.pkl
-
num_memory_spatial
andnum_memory_temporal
specify the number of embeddings in the memory pools. -
prompt_ST
specifies whether perform prompt-tuning: 0 for no prompt and 1 for prompt-tuning. -
prompt_content
specifies the type of prompt, which can be selected from ['s_p_c','s','c','p','s_c','s_p','p_c'].
Once your model is trained, you will find the logs recording the training process in the ./logs/
directory. The folder will be named as the Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>
. In the ./experiments/Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/model_save/
, you will find the fine-tuned model named model_best.pkl
.
The evaluation results of the testing set can be obtained from ./experiments/Prompt_Mode_finetuning_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/result.txt
.
Model Weights
We provide downloads of model weights on xxx. Coming soon.
π Citation
If you find this repo helpful, please cite our paper.
@article{yuan2024unist,
title={UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction},
author={Yuan, Yuan and Ding, Jingtao and Feng, Jie and Jin, Depeng and Li, Yong},
journal={arXiv preprint arXiv:2402.11838},
year={2024}
}
πβ Acknowledgement
We appreciate the following GitHub repos a lot for their valuable code and efforts.
- Spatio-temporal prediction benchmark: https://github.com/chengtan9907/OpenSTL
- Spatio-temporal data: https://github.com/aptx1231/NYC-Dataset
- MAE: https://github.com/facebookresearch/mae
- PatchTST: https://github.com/PatchTST/PatchTST
- iTransformer: https://github.com/thuml/iTransformer
π§ Contact
If you have any questions or want to use the code, feel free to contact:
- Yuan Yuan ([email protected])