tch-rs
How to clone a model in Rust?
In C++:

```cpp
auto model = torch::jit::load("your_model.pt");
input = input.clone();
```

How do I do this in Rust?
I think you could do the same as with any non-cloneable object in Rust, i.e. add an `Rc` or `Arc` layer for ref counting; nothing specific to `tch`. Maybe I'm missing something here?
```rust
let model = tch::CModule::load_on_device(model_path, device).unwrap();
let model: Rc<&tch::CModule> = Rc::new(&model);
let model2 = model.clone();
```
`Rc` will only clone a pointer to the same allocation.
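A minimal sketch of that point, using a plain `String` as a stand-in for the model:

```rust
use std::rc::Rc;

fn main() {
    let model = Rc::new(String::from("model weights"));
    // `clone` on an Rc only bumps the reference count; no data is copied.
    let model2 = Rc::clone(&model);
    assert!(Rc::ptr_eq(&model, &model2)); // both handles share one allocation
}
```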
My problem is that, for some unknown reason, calling `model.forward()` repeatedly with the same inputs produces different results. Therefore, I need a fast way to clone the model before each `forward()` call. In C++, `model.clone()` copies the contents. How about in Rust?
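One way to sanity-check this is to run the same input twice with gradients disabled and compare the outputs; a scripted model left in training mode (e.g. with active dropout) is a common cause of this kind of divergence. A rough sketch, where the input shape is an arbitrary assumption and `"your_model.pt"` is the placeholder path from the question:

```rust
use tch::{CModule, Device, Kind, Tensor};

fn main() -> Result<(), tch::TchError> {
    let model = CModule::load("your_model.pt")?;
    let input = Tensor::randn(&[1, 3, 224, 224], (Kind::Float, Device::Cpu));
    // Run the same input twice with gradients disabled.
    let out1 = tch::no_grad(|| model.forward_ts(&[input.shallow_clone()]))?;
    let out2 = tch::no_grad(|| model.forward_ts(&[input.shallow_clone()]))?;
    // `false` here means something stateful (dropout, batch norm in
    // training mode, in-place ops) changes the result between calls.
    println!("deterministic: {}", out1.allclose(&out2, 1e-5, 1e-8, false));
    Ok(())
}
```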
Below is an extract from cloneable.h ("torch/csrc/api/include/nn/cloneable.h"), but I don't know how to translate it to Rust.
```cpp
/// Performs a recursive "deep copy" of the `Module`, such that all parameters
/// and submodules in the cloned module are different from those in the
/// original module.
std::shared_ptr<Module> clone(
    const optional<Device>& device = nullopt) const override {
  NoGradGuard no_grad;
  const auto& self = static_cast<const Derived&>(*this);
  auto copy = std::make_shared<Derived>(self);
  copy->parameters_.clear();
  copy->buffers_.clear();
  copy->children_.clear();
  copy->reset();
  TORCH_CHECK(
      copy->parameters_.size() == parameters_.size(),
      "The cloned module does not have the same number of "
      "parameters as the original module after calling reset(). "
      "Are you sure you called register_parameter() inside reset() "
      "and not the constructor?");
  for (const auto& parameter : named_parameters(/*recurse=*/false)) {
    auto& tensor = *parameter;
    auto data = device && tensor.device() != *device
        ? tensor.to(*device)
        : autograd::Variable(tensor).clone();
    copy->parameters_[parameter.key()].set_data(data);
  }
  TORCH_CHECK(
      copy->buffers_.size() == buffers_.size(),
      "The cloned module does not have the same number of "
      "buffers as the original module after calling reset(). "
      "Are you sure you called register_buffer() inside reset() "
      "and not the constructor?");
  for (const auto& buffer : named_buffers(/*recurse=*/false)) {
    auto& tensor = *buffer;
    auto data = device && tensor.device() != *device
        ? tensor.to(*device)
        : autograd::Variable(tensor).clone();
    copy->buffers_[buffer.key()].set_data(data);
  }
  TORCH_CHECK(
      copy->children_.size() == children_.size(),
      "The cloned module does not have the same number of "
      "child modules as the original module after calling reset(). "
      "Are you sure you called register_module() inside reset() "
      "and not the constructor?");
  for (const auto& child : children_) {
    copy->children_[child.key()]->clone_(*child.value(), device);
  }
  return copy;
}
```
I believe you can do this:
```rust
use tch::{nn, nn::VarStore, Device, TchError};

fn clone_model(
    model_var_store: &VarStore,
) -> Result<(VarStore, impl nn::Module), TchError> {
    let mut cloned_var_store = VarStore::new(Device::cuda_if_available());
    // Build the same architecture on the new store first so the variable
    // names exist; `copy` only fills in variables that are already
    // registered in the destination store.
    let net = nn::linear(cloned_var_store.root(), 1337, 1337, Default::default());
    cloned_var_store.copy(model_var_store)?;
    Ok((cloned_var_store, net))
}
```
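A quick usage sketch under the same assumptions (the 1337-wide linear layer stands in for a real model):

```rust
use tch::{nn, nn::Module, nn::VarStore, Device, Kind, Tensor};

fn main() -> Result<(), tch::TchError> {
    let vs = VarStore::new(Device::cuda_if_available());
    let net = nn::linear(vs.root(), 1337, 1337, Default::default());
    // ... training would happen here ...

    let (_cloned_vs, cloned_net) = clone_model(&vs)?;
    // The clone owns separate storage, so outputs match until either
    // model's weights are updated again.
    let x = Tensor::randn(&[1, 1337], (Kind::Float, Device::cuda_if_available()));
    assert!(net
        .forward(&x)
        .allclose(&cloned_net.forward(&x), 1e-5, 1e-8, false));
    Ok(())
}
```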
In my use case, I need to clone the best trained model into a buffer model and save that buffer to disk when early stopping triggers. For a correct clone and save, I need to instantiate BOTH the best model and the buffer model:
This works:
```rust
// Instantiate the models.
let vs = nn::VarStore::new(Device::Cpu); // model for training
let net = Net(&vs.root());
// Buffer model used to clone the `vs` model and save it to disk.
let mut buffer_vs = nn::VarStore::new(Device::Cpu);
// This line is important for correct saving; without it `buffer_vs`
// stays empty after copying, even though `copy` does not panic.
let buffer_net = Net(&buffer_vs.root());

// ... do some training on `net` here ...

// Clone and save to disk.
buffer_vs.copy(&vs).expect("Failed to copy");
buffer_vs.save("path/to/save").expect("Failed to save");
```
I guess this is because the `.copy` method can only copy values between var stores whose structures already match, and instantiating the model on each VarStore ensures both stores contain the same variable names.
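For completeness, a small sketch of that pitfall, assuming (as the behaviour above suggests) that `copy` iterates over the destination's variables and therefore silently copies nothing into an empty store; the 8-unit linear layer stands in for `Net`:

```rust
use tch::{nn, nn::VarStore, Device};

fn main() -> Result<(), tch::TchError> {
    let vs = VarStore::new(Device::Cpu);
    let _net = nn::linear(vs.root(), 8, 8, Default::default());

    // Copying into an empty store succeeds but transfers nothing,
    // because there are no destination variables to fill in.
    let mut empty = VarStore::new(Device::Cpu);
    empty.copy(&vs)?;
    assert_eq!(empty.variables().len(), 0);

    // Registering the same architecture first creates matching
    // variable names, so the copy actually carries the weights over.
    let mut buffer = VarStore::new(Device::Cpu);
    let _buffer_net = nn::linear(buffer.root(), 8, 8, Default::default());
    buffer.copy(&vs)?;
    assert_eq!(buffer.variables().len(), vs.variables().len());
    Ok(())
}
```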