
How to adapt the models to grid data and the correspondence between data and models?

Open · aptx1231 opened this issue 1 year ago · 3 comments

For details, see https://bigscity-libcity-docs.readthedocs.io/en/latest/user_guide/data/dataset_for_task.html

Note 5. Some models, such as STResNet, ACFM, and ASTGCN, require three inputs: CLOSENESS, PERIOD, and TREND. For a fair comparison, we implemented generalized versions of these models that take only the CLOSENESS input: STResNetCommon, ACFMCommon, and ASTGCNCommon.

aptx1231 · Oct 11 '23 08:10

Here is how to generalize models designed for point-based data to grid-based data.

(1) If the model uses TrafficStatePointDataset as its dataset class (e.g., AGCRN, ASTGCNCommon, CCRNN), you can set dataset_class to TrafficStateGridDataset directly, either in task_config.json or in a custom configuration file passed via --config_file. Then set the TrafficStateGridDataset parameter use_row_column to false.

(2) If the model uses a subclass of TrafficStatePointDataset (e.g., ASTGCNDataset, CONVGCNDataset, STG2SeqDataset), edit that dataset class's source file so the class inherits from TrafficStateGridDataset instead of TrafficStatePointDataset. Then set use_row_column to False in its __init__() method.
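To make use_row_column more concrete: as we understand TrafficStateGridDataset, true keeps grid data as a 4-D array (time, rows, columns, features), while false flattens it to 3-D (time, rows × columns, features) so that every grid cell acts as one node, which is the layout point-based models consume. The NumPy sketch below illustrates only that reshape. The function names are hypothetical, and the row-major cell-to-node numbering is an assumption, not necessarily the order LibCity uses (that depends on how cells are listed in the dataset's .geo file).

```python
import numpy as np


def grid_to_nodes(grid: np.ndarray) -> np.ndarray:
    """Flatten grid data of shape (T, rows, cols, F) into (T, rows * cols, F).

    Assumes row-major numbering: cell (i, j) becomes node i * cols + j.
    """
    if grid.ndim != 4:
        raise ValueError(f"expected a 4-D array (T, rows, cols, F), got shape {grid.shape}")
    t, rows, cols, f = grid.shape
    # A C-order reshape merges the row and column axes into a single node axis.
    return grid.reshape(t, rows * cols, f)


def node_index(i: int, j: int, cols: int) -> int:
    """Node id of grid cell (i, j) under the row-major assumption."""
    return i * cols + j


if __name__ == "__main__":
    # Toy data: 2 time steps on a 3 x 4 grid with 2 features (e.g. inflow/outflow).
    grid = np.arange(2 * 3 * 4 * 2).reshape(2, 3, 4, 2)
    nodes = grid_to_nodes(grid)
    print(grid.shape, "->", nodes.shape)  # (2, 3, 4, 2) -> (2, 12, 2)
    # The features of cell (1, 2) end up at node 1 * 4 + 2 = 6.
    assert (nodes[:, node_index(1, 2, 4)] == grid[:, 1, 2]).all()
```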

Example (1):

Before modification:

task_config.json:

```json
"RNN": {
    "dataset_class": "TrafficStatePointDataset",
    ...
}
```

TrafficStateGridDataset.json:

```json
{
    "use_row_column": true
}
```

After modification:

task_config.json:

```json
"RNN": {
    "dataset_class": "TrafficStateGridDataset",
    ...
}
```

TrafficStateGridDataset.json:

```json
{
    "use_row_column": false
}
```
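Case (1) also allows putting these two settings in a custom configuration file passed via --config_file instead of editing the shipped JSON files. The Python sketch below writes such a file with the same two values as Example (1). The file name rnn_grid.json and the helper write_grid_config are hypothetical; only the keys dataset_class and use_row_column come from the examples.

```python
import json
from pathlib import Path


def write_grid_config(path: Path) -> dict:
    """Write a custom config that runs a point-based model on grid data.

    The keys mirror Example (1): switch the dataset class and flatten the grid.
    """
    config = {
        "dataset_class": "TrafficStateGridDataset",
        "use_row_column": False,  # serialized by json.dumps as JSON false
    }
    path.write_text(json.dumps(config, indent=4))
    return config


if __name__ == "__main__":
    # Hypothetical file name; it only has to match the --config_file value.
    cfg_path = Path("rnn_grid.json")
    write_grid_config(cfg_path)
    print(cfg_path.read_text())
```

A run might then look like `python run_model.py --task traffic_state_pred --model RNN --dataset NYCBike20140409 --config_file rnn_grid`. This assumes run_model.py appends .json to the --config_file value and that the file sits in the working directory; the dataset name is only an example of a grid dataset.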

Example (2):

Before modification:

task_config.json:

```json
"STG2Seq": {
    "dataset_class": "STG2SeqDataset",
    ...
}
```

STG2SeqDataset.json:

```json
{
    "use_row_column": false
}
```

stg2seq_dataset.py:

```python
from libcity.data.dataset import TrafficStatePointDataset


class STG2SeqDataset(TrafficStatePointDataset):
    def __init__(self, config):
        super().__init__(config)
        pass
```

After modification:

task_config.json:

```json
"STG2Seq": {
    "dataset_class": "STG2SeqDataset",
    ...
}
```

STG2SeqDataset.json:

```json
{
    "use_row_column": false
}
```

stg2seq_dataset.py:

```python
from libcity.data.dataset import TrafficStateGridDataset


class STG2SeqDataset(TrafficStateGridDataset):
    def __init__(self, config):
        super().__init__(config)
        # Treat each grid cell as a node instead of keeping the row/column layout.
        self.use_row_column = False
        pass
```




Additional processing may be required for some special models.
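One example of such extra processing (our assumption; the comment does not say which models or steps it means) is a graph-based model that expects an adjacency matrix over the flattened grid cells. A common choice is to link each cell to its up to 8 surrounding cells. LibCity may already build a matrix like this for grid datasets; the sketch below is a standalone illustration, not LibCity's code, and the function name grid_adjacency is hypothetical. It uses the same row-major numbering assumption (cell (i, j) is node i * cols + j).

```python
import numpy as np


def grid_adjacency(rows: int, cols: int) -> np.ndarray:
    """Binary adjacency over rows * cols cells, each linked to its (up to) 8 neighbours.

    Cells are numbered row-major (cell (i, j) -> node i * cols + j); no self-loops.
    """
    n = rows * cols
    adj = np.zeros((n, n), dtype=np.float32)
    for i in range(rows):
        for j in range(cols):
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    if di == 0 and dj == 0:
                        continue  # skip the cell itself
                    ni, nj = i + di, j + dj
                    if 0 <= ni < rows and 0 <= nj < cols:
                        adj[i * cols + j, ni * cols + nj] = 1.0
    return adj


if __name__ == "__main__":
    adj = grid_adjacency(3, 3)
    # Neighbour counts per cell: corners 3, edges 5, centre 8.
    print(adj.sum(axis=1).reshape(3, 3))
```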
