Fix `SaveNNMixin` to work on torch-1.13 and torch-2.0

Status: Open • Mr-Geekman opened this issue 1 year ago • 0 comments

🚀 Feature Request

It seems that the only thing preventing us from updating torch is saving models to disk. We should fix this by changing `SaveNNMixin`.

Switching to cloudpickle fixes the issue for torch-1.13 only; saving still fails on torch-2.0.

Proposal

The goal is to make `SaveNNMixin` work with newer torch versions.

We should try saving with `state_dict` instead of serializing the whole object with pickle/dill/cloudpickle. It will probably help to split the object's state into the torch `nn` part and the remaining components. The class `etna.pipeline.mixins.SaveModelPipelineMixin` can serve as a reference for how this can be done.
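A minimal sketch of the proposed split, assuming the torch module lives in a single attribute (here `net`) and that the object can rebuild an untrained module from its own parameters. `StateDictSaveMixin`, `_build_net`, `nn_attr`, `TinyModel` and the archive member names are hypothetical illustrations, not etna's actual API or the final design: the module's weights go through `torch.save(state_dict)`, while everything else goes through plain `pickle`.

```python
import io
import pickle
import tempfile
import zipfile
from pathlib import Path

import torch
from torch import nn


class StateDictSaveMixin:
    """Hypothetical mixin: saves the torch module through its state_dict
    and pickles only the remaining, non-torch attributes."""

    # Assumption: the wrapped torch module lives in this attribute.
    nn_attr = "net"

    def _build_net(self) -> nn.Module:
        # Hypothetical hook: subclasses recreate the architecture from their own params.
        raise NotImplementedError

    def save(self, path: Path) -> None:
        net = getattr(self, self.nn_attr)
        # Everything except the module goes through plain pickle.
        rest = {k: v for k, v in self.__dict__.items() if k != self.nn_attr}
        weights = io.BytesIO()
        torch.save(net.state_dict(), weights)
        with zipfile.ZipFile(path, "w") as archive:
            archive.writestr("object.pkl", pickle.dumps(rest))
            archive.writestr("net_state.pt", weights.getvalue())

    @classmethod
    def load(cls, path: Path):
        with zipfile.ZipFile(path, "r") as archive:
            rest = pickle.loads(archive.read("object.pkl"))
            state = torch.load(io.BytesIO(archive.read("net_state.pt")), weights_only=True)
        obj = cls.__new__(cls)  # skip __init__; the module is rebuilt explicitly below
        obj.__dict__.update(rest)
        # Rebuild the architecture, then restore the trained weights into it.
        net = obj._build_net()
        net.load_state_dict(state)
        setattr(obj, cls.nn_attr, net)
        return obj


class TinyModel(StateDictSaveMixin):
    """Hypothetical model used only to exercise the mixin."""

    def __init__(self, in_features: int):
        self.in_features = in_features
        self.net = self._build_net()

    def _build_net(self) -> nn.Module:
        return nn.Linear(self.in_features, 1)


model = TinyModel(in_features=4)
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.zip"
    model.save(path)
    restored = TinyModel.load(path)
```

The design choice here is that pickle never sees the `nn.Module` itself, so the saved archive presumably no longer depends on how a particular torch version pickles module internals. The trade-off is that every model must be able to reconstruct its network from its own (pickled) constructor parameters before the weights are loaded.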

Once this is done, the torch version requirement should be relaxed to allow versions 1.13 and 2.0.
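One way to sanity-check a relaxed requirement is to test the version specifier directly with the `packaging` library. The constraint `>=1.8,<2.1` below is a hypothetical example, not the project's actual bounds; it only illustrates confirming that both target releases satisfy whatever specifier is chosen.

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Hypothetical relaxed constraint; the real lower/upper bounds are up to the maintainers.
relaxed = SpecifierSet(">=1.8,<2.1")

# torch releases the issue asks to support.
targets = [Version("1.13.1"), Version("2.0.1")]
supported = {str(v): v in relaxed for v in targets}
print(supported)  # {'1.13.1': True, '2.0.1': True}
```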

Test cases

  • Fix existing tests
  • Manually check that `tests.test_models.nn` pass on both torch-1.13 and torch-2.0
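Alongside running the test suite by hand, a small round-trip check like the one below could be executed once under each torch version. It is a sketch that assumes saving goes through `state_dict`; `state_dict_roundtrip` is a hypothetical helper, and the toy `nn.Sequential` stands in for a real etna network. It reloads the weights into a freshly built module and compares outputs.

```python
import io

import torch
from torch import nn


def state_dict_roundtrip(module: nn.Module, factory) -> nn.Module:
    """Serialize module.state_dict() in memory and load it into a fresh module from factory()."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    buffer.seek(0)  # rewind before reading back
    restored = factory()
    restored.load_state_dict(torch.load(buffer, weights_only=True))
    return restored


torch.manual_seed(0)
original = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
restored = state_dict_roundtrip(
    original, lambda: nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
)

x = torch.randn(5, 3)
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))
print(f"torch {torch.__version__}: round trip preserves outputs = {same_output}")
```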

Additional context

No response

Mr-Geekman, May 25 '23 14:05