
[feature request] Way to load state dict for corrupt model (invalid external pointers)

Open sebffischer opened this issue 2 years ago • 1 comments

Hey, thank you very much for taking the time!

I have some questions regarding the serialization of torch models. In mlr3torch we have learners (R6 classes) with a field $model in which the network (the actual network, not the module generator), the optimizer and the loss function are stored. Because torch tensors are external pointers, they can be broken after a round trip through saveRDS and readRDS. This is a problem for parallelization with future.apply, so I have to address it somehow.

What I want to do is the following:

  1. Give the learner a method $serialize() that converts the torch tensors to raw vectors (using the function torch:::tensor_to_raw_vector_with_class). These are then simply stored in the learner, so that all of the learner's information can be saved in an rds file.
  2. Give the learner a method $unserialize() that converts the raw vectors back to torch tensors and then loads them as the state dict of the network, optimizer and loss function.

However, I stumbled across some problems:

  1. The functions that I am using are currently not exported (like tensor_to_raw_vector_with_class)
  2. One cannot simply call $load_state_dict() on a corrupt nn_module(), meaning an nn_module whose parameters and buffers are invalid pointers.

Would it be possible to export those functions and to provide a way to load a state dict into an nn_module with corrupt tensors? Or does this not make sense at all?

I would be willing to make a PR as well if you think this makes sense and is relatively easy to implement :)

Cheers,

Sebastian

sebffischer, Jul 29 '22 14:07

Hi @sebffischer! Thanks for reaching out!

I think it makes sense to make load_state_dict work on corrupted models. I can't remember exactly why it doesn't work.

For 1., I think it would be better to save the full state dict in a single raw value, e.g. with:

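# write the state dict into an in-memory raw connection
con <- rawConnection(raw(), open = "wr")
torch_save(module$state_dict(), con)
# r is a plain raw vector, so it survives saveRDS()/readRDS()
r <- rawConnectionValue(con)
close(con)

# read the state dict back from the raw vector
torch_load(rawConnection(r))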

I think we could implement a torch_serialize function that returns the raw object directly, as this code is starting to repeat in multiple projects.
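
A minimal sketch of what such a helper pair might look like, wrapping the snippet above. The names serialize_to_raw and unserialize_from_raw are hypothetical and not part of torch. The sketch also assumes the network can be rebuilt from its constructor, so the state dict is loaded into a fresh module rather than into the corrupt one:

```r
library(torch)

# Hypothetical helper (not part of torch): write any object that
# torch_save() accepts into a raw vector.
serialize_to_raw <- function(obj) {
  con <- rawConnection(raw(), open = "wr")
  on.exit(close(con), add = TRUE)
  torch_save(obj, con)
  rawConnectionValue(con)
}

# Hypothetical inverse: read the object back from a raw vector.
unserialize_from_raw <- function(r) {
  con <- rawConnection(r)
  on.exit(close(con), add = TRUE)
  torch_load(con)
}

net <- nn_linear(3, 1)
r <- serialize_to_raw(net$state_dict())

# A raw vector is ordinary R data, so it can go through saveRDS()/readRDS().
path <- tempfile(fileext = ".rds")
saveRDS(r, path)
r2 <- readRDS(path)

# Assumption: the architecture is known, so a fresh module can be
# constructed and the restored state dict loaded into it.
fresh <- nn_linear(3, 1)
fresh$load_state_dict(unserialize_from_raw(r2))
```

This sidesteps the corrupt-module problem rather than solving it: $load_state_dict() is only ever called on a module with valid tensors.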

Happy to review a PR!

dfalbel, Aug 01 '22 01:08