BasicTS
BasicTS copied to clipboard
一些问题和优化建议
您好,作者,感谢提供如此完整的学习框架!本人在使用和移植基线的过程中遇到一些问题和不便的地方,在此提出来以便您参考优化。
声明:以下问题和建议仅代表个人看法,仅供参考
问题:利用pycham直接运行data_preparation显示找不到数据集文件,运行train时也一样,做如下修改就可以运行:
OUTPUT_DIR = "../../../experiments/datasets/" + DATASET_NAME
DATA_FILE_PATH = "../../../datasets/raw_data/{0}/{0}.npz".format(DATASET_NAME)
GRAPH_FILE_PATH = "../../../datasets/raw_data/{0}/adj_{0}".format(DATASET_NAME)
DISTANCE_FILE_PATH = "../../../datasets/raw_data/{0}/distance_{0}".format(DATASET_NAME)
优化建议:
1.数据集的归一化和反归一化:CFG.RESCALE:如果为True,表示既反归一化数据又将整个数据的标准化,如果为False,表示既不反归一化数据又将数据的每个通道标准化。可以拆解为两个变量,一个变量控制数据的标准化,一个变量控制数据和归一化和反归一化。
2. 模型训练结果表示不清:用模型名+epochs的方式所表达的直接信息不全,可做如下修改:
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
"checkpoints",
CFG.MODEL.NAME,
"_".join([CFG.DATASET_NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
3.项目的可视化接口不足:可在tensorboard中增加一些指标或增加预测数据保存的接口
4.项目cfg文件中有许多隐藏接口,可以添加一个Simple_CFG将所有接口表达出来,例如:
CFG.MODEL.SETUP_GRAPH = False
CFG.TRAIN.FINETUNE_FROM
CFG.RESCALE = True
5.在基线STGODE中,需要引入A_sp_hat, A_se_hat两个张量,发现即使将整个模型放入gpu中,这两个张量仍然存在于cpu中,直到后续.to(x.device)。这在模型的移植中不太便利,需要找到张量最终使用的地方。建议使用
from easytorch.device import get_device_type
if get_device_type() == 'gpu':
device = 'cuda'
else:
device = 'cpu'
self.device = device
或者from easytorch.device import to_device
6.在test过程中没有进度条显示,可修改:
tqdm process bar
data_iter = tqdm(self.test_data_loader)
test loop
for iter_index, data in enumerate(data_iter):