pytorch-forecasting
Slower training on Linux under 0.10.1
🐛 Bug
Code: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html (the Stallion tutorial). With max_epochs=1 on CPU, training takes 9.6 s on Windows 10 but 382.79 s on Linux. Profiling with PyTorch Lightning's AdvancedProfiler shows that SingleDeviceStrategy.validation_step behaves very differently on the two platforms.
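AdvancedProfiler is built on Python's cProfile, so the reports below are ordinary pstats listings sorted by cumulative time. The following is a minimal standard-library sketch that produces the same kind of report for a single call. `profile_call` and `build_figures` are hypothetical names; `build_figures` only stands in for a slow function such as plot_prediction.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, limit=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile and return (result, report).

    The report uses the same layout as the AdvancedProfiler output:
    ncalls / tottime / percall / cumtime / percall, ordered by cumulative time.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Sort by cumulative time and keep only the top `limit` entries
    stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(limit)
    return result, stream.getvalue()


def build_figures(n):
    # Hypothetical stand-in for an expensive call such as plot_prediction
    return [sum(range(10_000)) for _ in range(n)]


result, report = profile_call(build_figures, 5)
print(report)
```

Reading the cumtime column top-down, as in the listings below, leads from the strategy hook to the function that actually consumes the time.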
Windows:
Profile stats for: [Strategy]SingleDeviceStrategy.validation_step
504226 function calls (490331 primitive calls) in 0.638 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 0.638 0.319 strategy.py:338(validation_step)
2 0.000 0.000 0.638 0.319 base_model.py:411(validation_step)
2 0.000 0.000 0.486 0.243 __init__.py:517(create_log)
2 0.000 0.000 0.458 0.229 base_model.py:429(create_log)
2 0.000 0.000 0.443 0.222 base_model.py:684(log_prediction)
2 0.000 0.000 0.288 0.144 __init__.py:679(plot_prediction)
2 0.000 0.000 0.159 0.079 figure.py:3189(tight_layout)
2 0.000 0.000 0.157 0.078 tight_layout.py:251(get_tight_layout_figure)
2 0.000 0.000 0.156 0.078 tight_layout.py:19(_auto_adjust_subplotpars)
2 0.000 0.000 0.155 0.078 writer.py:641(add_figure)
4 0.000 0.000 0.155 0.039 _base.py:4570(get_tightbbox)
958/14 0.004 0.000 0.152 0.011 module.py:1104(_call_impl)
...
Linux:
Profile stats for: [Strategy]SingleDeviceStrategy.validation_step
461785 function calls (447573 primitive calls) in 52.269 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
2 0.000 0.000 52.270 26.135 strategy.py:338(validation_step)
2 0.000 0.000 52.270 26.135 base_model.py:411(validation_step)
2 0.000 0.000 52.139 26.069 __init__.py:517(create_log)
2 0.000 0.000 52.119 26.059 base_model.py:429(create_log)
2 0.000 0.000 52.106 26.053 base_model.py:684(log_prediction)
2 0.000 0.000 45.007 22.504 __init__.py:679(plot_prediction)
2 0.000 0.000 44.801 22.400 base_model.py:725(plot_prediction)
2 0.000 0.000 44.781 22.391 pyplot.py:1321(subplots)
2 0.000 0.000 44.742 22.371 pyplot.py:686(figure)
2 0.000 0.000 44.742 22.371 pyplot.py:324(new_figure_manager)
2 0.000 0.000 39.713 19.857 backend_bases.py:3487(new_figure_manager)
2 0.000 0.000 39.709 19.854 _backend_tk.py:940(new_figure_manager_given_figure)
### 250 33.181 0.133 33.186 0.133 {method 'call' of '_tkinter.tkapp' objects}
...
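Within validation_step, most of the extra time on Linux is spent in plot_prediction: about 45 s cumulative over two calls on Linux versus about 0.29 s on Windows 10. On Linux, nearly all of that time goes into creating the figure manager in matplotlib's Tk backend (_backend_tk.py), including 250 calls into _tkinter that take about 33 s in total.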
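Because the Linux profile points at matplotlib's interactive Tk backend, one possible workaround (an assumption, not a confirmed fix) is to run the training script with the non-interactive Agg backend. matplotlib picks its backend from the `MPLBACKEND` environment variable when it is set. The sketch below uses a hypothetical helper, `run_with_backend`, that launches a Python script with that variable set.

```python
import os
import subprocess
import sys


def run_with_backend(args, backend="Agg"):
    """Run `python *args` with MPLBACKEND set to `backend`.

    With "Agg", matplotlib renders off-screen and never creates Tk
    figure managers, which is where the Linux profile spends its time.
    """
    env = dict(os.environ, MPLBACKEND=backend)
    return subprocess.run(
        [sys.executable, *args],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the script fails
    )
```

For example, `run_with_backend(["stallion.py"])`, where `stallion.py` is a hypothetical file holding the tutorial code. Inside a single process, calling `matplotlib.use("Agg")` before the first `import matplotlib.pyplot` selects the same backend.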
- PyTorch Lightning Version: 1.6.3
- PyTorch forecasting Version: 0.10.1
- PyTorch Version: 1.11.0
- Python version: 3.8.13
- OS: CentOS 7.6 and Windows 10
- How you installed PyTorch: pip
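The profile implicates matplotlib, whose version is not in the list above. A small standard-library sketch, using a hypothetical helper `environment_report`, that collects these fields together with the installed matplotlib version via `importlib.metadata` (available from Python 3.8):

```python
import platform
from importlib import metadata

# pip distribution names of the packages relevant to this issue
PACKAGES = ["pytorch-lightning", "pytorch-forecasting", "torch", "matplotlib"]


def environment_report(packages=PACKAGES):
    """Return a dict of Python/OS info plus installed package versions."""
    info = {
        "Python": platform.python_version(),
        "OS": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


# Print in the same "- key: value" style as the list above
for key, value in environment_report().items():
    print(f"- {key}: {value}")
```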