BasicTS icon indicating copy to clipboard operation
BasicTS copied to clipboard

一些问题和优化建议

Open TensorPulse opened this issue 6 months ago • 30 comments

您好,作者,感谢提供如此完整的学习框架!本人在使用和移植基线的过程中遇到一些问题和不便的地方,在此提出来以便您参考优化。 声明:以下问题和建议仅代表个人看法,仅供参考 问题:利用pycham直接运行data_preparation显示找不到数据集文件,运行train时也一样,做如下修改就可以运行: OUTPUT_DIR = "../../../experiments/datasets/" + DATASET_NAME DATA_FILE_PATH = "../../../datasets/raw_data/{0}/{0}.npz".format(DATASET_NAME) GRAPH_FILE_PATH = "../../../datasets/raw_data/{0}/adj_{0}".format(DATASET_NAME) DISTANCE_FILE_PATH = "../../../datasets/raw_data/{0}/distance_{0}".format(DATASET_NAME) 优化建议: 1.数据集的归一化和反归一化:CFG.RESCALE:如果为True,表示既反归一化数据又将整个数据的标准化,如果为False,表示既不反归一化数据又将数据的每个通道标准化。可以拆解为两个变量,一个变量控制数据的标准化,一个变量控制数据和归一化和反归一化。 2. 模型训练结果表示不清:用模型名+epochs的方式所表达的直接信息不全,可做如下修改: CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( "checkpoints", CFG.MODEL.NAME, "_".join([CFG.DATASET_NAME, str(CFG.TRAIN.NUM_EPOCHS)]) ) 3.项目的可视化接口不足:可在tensorboard中增加一些指标或增加预测数据保存的接口 4.项目cfg文件中有许多隐藏接口,可以添加一个Simple_CFG将所有接口表达出来,例如: CFG.MODEL.SETUP_GRAPH = False CFG.TRAIN.FINETUNE_FROM CFG.RESCALE = True 5.在基线STGODE中,需要引入A_sp_hat, A_se_hat两个张量,发现即使将整个模型放入gpu中,这两个张量仍然存在于cpu中,直到后续.to(x.device)。这在模型的移植中不太便利,需要找到张量最终使用的地方。建议使用 from easytorch.device import get_device_type if get_device_type() == 'gpu': device = 'cuda' else: device = 'cpu' self.device = device 或者from easytorch.device import to_device
6.在test过程中没有进度条显示,可修改:

tqdm process bar

data_iter = tqdm(self.test_data_loader)

test loop

for iter_index, data in enumerate(data_iter):

TensorPulse avatar Aug 22 '24 07:08 TensorPulse