etna
Fix `SaveNNMixin` to work on torch-1.13 and torch-2.0
🚀 Feature Request
It seems that the only thing preventing us from updating `torch` is saving models to disk. We should fix this by changing `SaveNNMixin`.
Using `cloudpickle` fixes the issue for torch-1.13 only; it doesn't work on torch-2.0.
Proposal
The goal is to fix `SaveNNMixin` so that it works with new `torch` versions.
We should try to save with `state_dict` instead of serializing the whole object with `pickle`/`dill`/`cloudpickle`. It will probably be useful to split the object's state into the torch `nn` part and the other components. The class `etna.pipeline.mixins.SaveModelPipelineMixin` can serve as a reference for how this can be done.
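A minimal sketch of the `state_dict` approach. It assumes the model keeps its network in an attribute named `net` and can rebuild an untrained network from its other attributes via a method `_build_net()`. `StateDictSaveMixin`, `TinyNet`, `TinyModel`, `net`, `_build_net` and the archive member names are all hypothetical and are not part of etna or `SaveModelPipelineMixin`. The network weights and the remaining attributes are stored as separate members of one zip archive. To keep the example runnable without torch, `TinyNet` stands in for a `torch.nn.Module` and the state dict is serialized with `pickle`. With torch installed, `torch.save`/`torch.load` on an `io.BytesIO` buffer would be the more usual choice.

```python
import os
import pickle
import tempfile
import zipfile


class StateDictSaveMixin:
    """Hypothetical mixin: saves ``self.net`` through its state_dict, the rest with pickle.

    Assumes the host class
      * keeps its network in the attribute ``net`` (any object with
        ``state_dict()`` / ``load_state_dict()``, e.g. a ``torch.nn.Module``);
      * can rebuild an untrained ``net`` from its other attributes via ``_build_net()``.
    """

    def save(self, path: str) -> None:
        # Everything except the network itself (hyperparameters, scalers, ...).
        other_state = {k: v for k, v in self.__dict__.items() if k != "net"}
        with zipfile.ZipFile(path, "w") as archive:
            archive.writestr("object.pkl", pickle.dumps(other_state))
            # Only the weights are stored, never the module object.
            archive.writestr("net_state_dict.pkl", pickle.dumps(self.net.state_dict()))

    @classmethod
    def load(cls, path: str):
        with zipfile.ZipFile(path, "r") as archive:
            other_state = pickle.loads(archive.read("object.pkl"))
            net_state = pickle.loads(archive.read("net_state_dict.pkl"))
        obj = cls.__new__(cls)  # bypass __init__, restore attributes directly
        obj.__dict__.update(other_state)
        obj.net = obj._build_net()  # fresh network with the same architecture
        obj.net.load_state_dict(net_state)  # then copy the saved weights into it
        return obj


class TinyNet:
    """Stand-in for torch.nn.Module so the sketch runs without torch."""

    def __init__(self, size: int):
        self.weights = [0.0] * size

    def state_dict(self) -> dict:
        return {"weights": list(self.weights)}

    def load_state_dict(self, state: dict) -> None:
        self.weights = list(state["weights"])


class TinyModel(StateDictSaveMixin):
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.net = self._build_net()

    def _build_net(self) -> TinyNet:
        # Architecture depends only on attributes stored in object.pkl.
        return TinyNet(self.hidden_size)


model = TinyModel(hidden_size=3)
model.net.weights = [0.1, 0.2, 0.3]  # pretend the model was trained
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.zip")
    model.save(path)
    restored = TinyModel.load(path)
print(restored.hidden_size, restored.net.weights)  # 3 [0.1, 0.2, 0.3]
```

The design choice here is that the network class itself is never pickled: only its weights and the plain attributes are. This should make saved files less sensitive to changes in how a given torch version pickles module objects, at the cost of requiring each model to know how to rebuild its network.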
Once this is done, the `torch` requirement should be relaxed to include versions 1.13 and 2.0.
Test cases
- Fix existing tests
- Check manually that `tests.test_models.nn` pass on torch-1.13 and torch-2.0
Additional context
No response