Graph-WaveNet
Cleanup and Model Improvements
Changes that improve performance to 3.00 - 3.02 test MAE:
- learning rate decay (default = 0.97); see the sketch after this list
- skip connection if `cat_feat_gc=True`
- default `nhid=40` (instead of 32)
- default `clip=3` (instead of 5)
- replace 0s with the mean in `StandardScaler`; see the sketch after this list
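The learning rate decay and the zero-filling scaler are small enough to sketch. The snippet below is a rough illustration assuming a standard PyTorch training loop; the `StandardScaler` class shown here and the use of `ExponentialLR` are illustrative assumptions, not the repo's exact implementation.

```python
# Minimal sketch of two of the changes above: per-epoch LR decay and a scaler
# that fills 0s with the mean. Names are illustrative, not the repo's code.
import numpy as np
import torch


class StandardScaler:
    """Z-score scaler that treats 0s as missing and fills them with the mean."""

    def __init__(self, data: np.ndarray):
        mask = data != 0
        self.mean = data[mask].mean()
        self.std = data[mask].std()
        self.data = np.where(mask, data, self.mean)  # replace 0s before scaling

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse_transform(self, x):
        return x * self.std + self.mean


# Learning rate decay: shrink the LR by a factor of 0.97 after every epoch.
model = torch.nn.Linear(2, 1)                     # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

for epoch in range(100):
    # ... forward/backward passes and optimizer.step() for each batch ...
    scheduler.step()                              # lr <- lr * 0.97
```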
Changes that don't affect performance:
- only save one checkpoint
- add early stopping
- simplify the `do_graph_conv` command line args
- share parser code between `train.py` and `test.py`
- misc renamings
- store metrics in pandas and write them to disk nicely
- no `expid`: `--save` specifies a directory
- save args to `--save` for reproducibility
- `train.py` skips batches where all targets are 0: they will certainly have 0 loss
- check in `best_model.pth`, so there is no need to retrain to inspect the model
- calculate all three metrics in one function instead of three (see the sketch below)
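For the metrics consolidation, a single combined helper might look like the following. The name `calc_metrics`, its signature, and the masking convention (entries equal to `null_val` are ignored) are assumptions for illustration, not the repo's exact code.

```python
import numpy as np
import torch


def calc_metrics(preds: torch.Tensor, labels: torch.Tensor, null_val: float = 0.0):
    """Compute masked MAE, MAPE, and RMSE in one pass.

    Entries of `labels` equal to `null_val` are treated as missing and
    excluded from all three metrics.
    """
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = labels != null_val
    mask = mask.float()
    mask = mask / mask.mean()         # re-weight so masked entries drop out of the mean
    mask = torch.nan_to_num(mask)     # guard against 0/0 when everything is masked

    err = preds - labels
    mae = torch.nan_to_num(torch.abs(err) * mask).mean()
    mape = torch.nan_to_num(torch.abs(err / labels) * mask).mean()
    rmse = torch.sqrt(torch.nan_to_num((err ** 2) * mask).mean())
    return mae, mape, rmse
```

Calling this once per batch replaces three separate metric functions and avoids recomputing the mask three times.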
Metrics (test MAE): baseline 3.00 - 3.02, finetuning 2.99 - 3.00
The paper has more details on experiments that didn't work: http://arxiv.org/abs/1912.07390
Dear all,
Thanks for your great contribution. I am impressed by your wonderful efforts to improve Graph WaveNet. I will review the code as soon as possible.
Best, Zonghan
Dear all,
I think you have done a very good job. To fully respect your contribution, I would highly recommend keeping your own repository rather than merging it into this project. I will instead add a link referring to your repo.
Best regards, Zonghan