mcts-stl-planning
Online Signal Temporal Logic (STL) Monte-Carlo Tree Search for Guided Imitation Learning
Follow The Rules: Online Signal Temporal Logic Tree Search for Guided Imitation Learning in Stochastic Domains
This repository contains the code for the paper submitted to ICRA 2023.
Jasmine Jerry Aloor*, Jay Patrikar*, Parv Kapoor, Jean Oh and Sebastian Scherer.
*equal contribution
Brief Overview [Video]

Seamlessly integrating rules in Learning-from-Demonstrations (LfD) policies is a critical requirement to enable the real-world deployment of AI agents. Recently Signal Temporal Logic (STL) has been shown to be an effective language for encoding rules as spatio-temporal constraints. This work uses Monte Carlo Tree Search (MCTS) as a means of integrating STL specification into a vanilla LfD policy to improve constraint satisfaction. We propose augmenting the MCTS heuristic with STL robustness values to bias the tree search towards branches with higher constraint satisfaction. While the domain-independent method can be applied to integrate STL rules online into any pre-trained LfD algorithm, we choose goal-conditioned Generative Adversarial Imitation Learning as the offline LfD policy. We apply the proposed method to the domain of planning trajectories for General Aviation aircraft around a non-towered airfield. Results using the simulator trained on real-world data showcase 60% improved performance over baseline LfD methods that do not use STL heuristics.
Installation
Environment Setup
First, we'll create a conda environment to hold the dependencies.
conda create --name stlmcts --file requirements.txt
conda activate stlmcts
Data Setup
The network uses the TrajAir dataset.
cd dataset
wget https://kilthub.cmu.edu/ndownloader/articles/14866251/versions/1
Unzip dataset files in place as shown in the folder tree below.
mcts-stl-planning/
├─ dataset/
│ ├─ 111_days/
│ │ ├─ processed_data/
│ │ │ ├─ test/
│ │ │ ├─ train/
├─ episodes/
├─ gym/
├─ images/
├─ mcts/
├─ model/
├─ rtamt/
├─ saved_models/
├─ costmap.py
├─ play.py
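Before training or running the planner, it can help to confirm that the dataset was unzipped into the layout shown above. The following is a minimal sketch, not part of the repository: the helper name `missing_dirs` and the `EXPECTED_DIRS` list are hypothetical and only mirror the data directories in the folder tree.

```python
from pathlib import Path

# Directories taken from the folder tree above (111_days subset only).
EXPECTED_DIRS = [
    "dataset/111_days/processed_data/test",
    "dataset/111_days/processed_data/train",
]


def missing_dirs(repo_root, expected=EXPECTED_DIRS):
    """Return the expected directories that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in expected if not (root / d).is_dir()]


if __name__ == "__main__":
    # Run from the repository root (mcts-stl-planning/).
    missing = missing_dirs(".")
    print("Missing directories:", missing if missing else "none")
```

An empty result means the train and test folders are where the tree above places them.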
MCTS Parameters
The MCTS is implemented as a recursive function where each iteration ends with a new leaf that corresponds to an action in the trajectory library. The algorithm can be run on any dataset. For example, to test the Behavior Cloning (BC) algorithm on 111_days, use:
python play.py --dataset_name 111_days --algo BC --tcn_channel_size 256
To test the goalGAIL algorithm (--algo GAIL) on 111_days, use:
python play.py --dataset_name 111_days --algo GAIL --tcn_channel_size 512
- --checkpoint: Argument to set checkpoint for MCTS (default = /episodes/)
- --load_episodes: Argument to load episodes (default = False)
- --algo: Baseline algorithm for network (default = BC)
- --numMCTS: Number of MCTS trees (default = 50)
- --cpuct: Argument to balance exploration and exploitation (default = 1)
- --huct: Weight of heuristic (default = 4000)
- --parallel: Argument to set parallel execution (default = False)
- --num_process: Number of processes (default = 1000)
- --numEpisodeSteps: Number of steps in an episode (default = 30)
- --maxlenOfQueue: Maximum length of queue (default = 25600)
- --numEps: Maximum number of episodes (default = 100)
- --numEpsTest: Maximum number of episodes during testing (default = 100)
- --numIters: Number of iterations (default = 1) !! needs review
- --plot: To plot the trees (default = False)
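To illustrate how a recursive tree search can combine an exploration weight (as in --cpuct) with a heuristic weight (as in --huct), here is a minimal sketch on a toy problem. It is not the repository's implementation. The `Node` class, the `score` formula (a PUCT-style exploration term plus an additive heuristic bonus), and the `toy_evaluate` callback are all assumptions made for illustration. The toy run uses huct = 1.0 because the heuristic values here are already scaled to [0, 1]. The right weight depends on the scale of the robustness values.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Node:
    prior: float = 1.0       # policy probability of the action leading here
    heuristic: float = 0.0   # e.g. an STL robustness value scaled to [0, 1]
    visits: int = 0
    value_sum: float = 0.0
    children: dict = field(default_factory=dict)  # action index -> Node

    def q(self):
        return self.value_sum / self.visits if self.visits else 0.0


def score(parent, child, cpuct, huct):
    """Assumed selection score: mean value + exploration + heuristic bonus."""
    explore = cpuct * child.prior * math.sqrt(parent.visits) / (1 + child.visits)
    bonus = huct * child.heuristic / (1 + child.visits)
    return child.q() + explore + bonus


def search(node, path, evaluate, max_depth, cpuct=1.0, huct=1.0):
    """One recursive iteration: descend by score, add new leaves, back up."""
    if node.children and len(path) < max_depth:
        # Selection: follow the child with the highest score.
        a = max(node.children,
                key=lambda k: score(node, node.children[k], cpuct, huct))
        value = search(node.children[a], path + (a,), evaluate,
                       max_depth, cpuct, huct)
    else:
        # Evaluation, then expansion: one new leaf per library action.
        value, priors, heuristics = evaluate(path)
        if len(path) < max_depth:
            node.children = {
                a: Node(prior=p, heuristic=h)
                for a, (p, h) in enumerate(zip(priors, heuristics))
            }
    # Backup: update statistics on the way out of the recursion.
    node.visits += 1
    node.value_sum += value
    return value


def toy_evaluate(path):
    """Hypothetical stand-in for the policy network and the STL monitor."""
    priors = [1 / 3, 1 / 3, 1 / 3]   # uniform policy over 3 library actions
    heuristics = [0.0, 0.2, 1.0]     # pretend action 2 best satisfies the rules
    return 0.0, priors, heuristics


root = Node()
for _ in range(50):
    search(root, (), toy_evaluate, max_depth=3)

counts = {a: child.visits for a, child in root.children.items()}
best_action = max(counts, key=counts.get)
print(counts, best_action)
```

With a flat value estimate and a uniform policy, the heuristic bonus alone steers most visits towards action 2, which is the biasing effect the heuristic weight is meant to provide.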
Initialization Network
Model Training
For training data, we can choose between the 4 training subsets labelled 7days1, 7days2, 7days3, 7days4 or the entire dataset 111_days. For example, to train with 7days1 use:
python train.py --dataset_name 7days1
Training will use GPUs if available.
Optional arguments can be given as follows:
- --dataset_folder: Sets the working directory for data. Default is the current working directory (default = /dataset/).
- --dataset_name: Sets the data block to use (default = 111_days).
- --models_folder: Sets the directory for saved models. Default is the saved_models directory (default = /saved_models/).
- --model_weights: Sets the model weights to be used (default = model_111_days_4.pt). !! needs review
- --obs: Observation length (default = 11).
- --preds: Prediction length (default = 120).
- --preds_step: Prediction steps (default = 5).
- --delim: Delimiter used in data (default =).
- --use_trajair: Option to use the trajairnet model (default = False). !! needs review
- --algo: Baseline algorithm for network (default = BC).
- --total_epochs: Total number of passes over the entire training data set (default = 10).
- --model_pth: Path to save the models (default = /saved_models/).
TCN Network Arguments
- --input_channels: The number of input channels (x, y, z) (default = 3).
- --tcn_kernels: The size of the kernel to use in each convolutional layer (default = 4).
- --tcn_channel_size: The number of hidden units to use (default = 512).
- --tcn_layers: The number of layers to use (default = 2).
- --mlp_layer: The number of hidden units in the MLP decoder (default, BC = 91, goalGAIL = ).
Model Testing
- --dataset_folder: Sets the working directory for data. Default is the current working directory (default = /dataset/).
- --dataset_name: Sets the data block to use (default = 7days1).
- --obs: Observation length (default = 11).
- --preds: Prediction length (default = 120).
- --preds_step: Prediction steps (default = 10).
- --delim: Delimiter used in data (default =).
- --model_dir: Path to load the models (default = /saved_models/).
- --epoch: Epoch to load the model.
STL Library: RTAMT
RTAMT is the library used for monitoring Signal Temporal Logic (STL) specifications. To install it, follow the package's installation procedure.
If the antlr4 dependency throws an error, follow the conda installation for antlr4 instead.
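For readers new to STL, the robustness value that a monitor reports can be pictured as a signed margin: positive when a rule is satisfied, negative when it is violated, with the magnitude indicating by how much. The sketch below computes it by hand for three basic operators using the standard min/max semantics over a finite discrete trace. It deliberately does not use the RTAMT API. The function names and the example rules (a minimum altitude and a distance-to-runway bound) are hypothetical and are not the specifications used in the paper.

```python
def rob_always_gt(signal, threshold):
    """Robustness of 'always (x > threshold)': the worst-case margin."""
    return min(x - threshold for x in signal)


def rob_eventually_lt(signal, threshold):
    """Robustness of 'eventually (x < threshold)': the best-case margin."""
    return max(threshold - x for x in signal)


def rob_and(*robustness_values):
    """Robustness of a conjunction: the weakest of its parts."""
    return min(robustness_values)


# Hypothetical trace: altitude (m) and distance to runway (km) per time step.
altitude = [300, 280, 260, 240]
distance = [3.0, 2.0, 1.0, 0.25]

r_alt = rob_always_gt(altitude, 200)        # 40: stays 40 m above the floor
r_dist = rob_eventually_lt(distance, 0.5)   # 0.25: gets within 0.5 km
r_both = rob_and(r_alt, r_dist)             # 0.25: limited by the weaker rule

# A trace that dips below the floor yields a negative robustness.
r_violation = rob_always_gt([300, 180], 200)  # -20

print(r_alt, r_dist, r_both, r_violation)
```

Because a conjunction takes the minimum, signals measured in different units are usually rescaled before they are combined, so that one rule does not dominate the others purely because of its scale.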
TrajAir Dataset
More information about the TrajAir dataset is available at link.
Cite
If you have any questions, please contact [email protected] or open an issue on this repo.
If you find this repository useful for your research, please cite the following paper:
@misc{https://doi.org/10.48550/arxiv.2209.13737,
doi = {10.48550/ARXIV.2209.13737},
url = {https://arxiv.org/abs/2209.13737},
author = {Aloor, Jasmine Jerry and Patrikar, Jay and Kapoor, Parv and Oh, Jean and Scherer, Sebastian},
keywords = {Robotics (cs.RO), FOS: Computer and information sciences},
title = {Follow The Rules: Online Signal Temporal Logic Tree Search for Guided Imitation Learning in Stochastic Domains},
publisher = {arXiv},
year = {2022},
copyright = {arXiv.org perpetual, non-exclusive license}
}