tch-rs
Provide a `clone_to_device` method for `Module`
I'd like to copy a module to another device. In the C++ API this is done with the `clone` method, which takes an optional device.
While not something I currently need, I think it would also make sense to implement `Clone` for `Module` where it does a deep copy to the same device (and the same for implementing `Clone` for `Tensor`, but maybe that's a separate discussion).
Anyways, for C++ it looks like they have a separate `Cloneable` template class, so a separate trait for `clone_to_device` would be an option rather than implementing it directly on `Module`.
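For illustration, a separate trait along those lines might look like the sketch below. It is not `tch` code: `Device`, `Tensor`, and `Module` are minimal stand-ins so the example compiles on its own, and `CloneToDevice`, `ElementwiseAffine`, and `to_device_copy` are hypothetical names. Each module implements the trait by hand, deep-copying the tensors it owns.

```rust
// Minimal stand-ins for tch types; this is NOT the real tch API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

#[derive(Debug)]
pub struct Tensor {
    pub data: Vec<f32>,
    pub device: Device,
}

impl Tensor {
    pub fn new(data: Vec<f32>, device: Device) -> Self {
        Tensor { data, device }
    }

    /// Hypothetical deep copy: always allocates fresh storage on `device`.
    pub fn to_device_copy(&self, device: Device) -> Tensor {
        Tensor { data: self.data.clone(), device }
    }
}

/// Simplified stand-in for tch's `Module` trait.
pub trait Module {
    fn forward(&self, xs: &[f32]) -> Vec<f32>;
}

/// Hypothetical separate trait, in the spirit of C++'s `Cloneable`.
pub trait CloneToDevice: Sized {
    fn clone_to_device(&self, device: Device) -> Self;
}

/// Toy module computing `w * x + b` element-wise.
pub struct ElementwiseAffine {
    pub ws: Tensor,
    pub bs: Tensor,
}

impl Module for ElementwiseAffine {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter()
            .zip(&self.ws.data)
            .zip(&self.bs.data)
            .map(|((x, w), b)| x * w + b)
            .collect()
    }
}

// Implemented by hand: each module knows which tensors it owns.
impl CloneToDevice for ElementwiseAffine {
    fn clone_to_device(&self, device: Device) -> Self {
        ElementwiseAffine {
            ws: self.ws.to_device_copy(device),
            bs: self.bs.to_device_copy(device),
        }
    }
}
```

Keeping `clone_to_device` out of `Module` means modules that cannot be cloned meaningfully simply don't implement it.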
I'd be happy to try making a PR for this.
I've discovered the strategy you use in the DDPG example: construct a new module instance, then use the var store to copy the variables over. That should work for me.
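That strategy can be sketched with stand-in types: build a fresh module against a new var store on the target device, then copy the source store's values into it by variable name. `VarStore`, `Policy`, `copy_from`, and `clone_policy_to` below are simplified, hypothetical stand-ins modeled on the idea rather than on `tch`'s actual signatures, and the device is only a tag here. The store and the module share each variable through an `Rc<RefCell<Vec<f32>>>`, so copying into the store also updates the module.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

// Minimal stand-ins; NOT the real tch API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

/// A variable shared between the store and the module that uses it.
pub type Var = Rc<RefCell<Vec<f32>>>;

/// Stand-in for a var store: named variables that live on one device.
pub struct VarStore {
    pub device: Device,
    vars: HashMap<String, Var>,
}

impl VarStore {
    pub fn new(device: Device) -> Self {
        VarStore { device, vars: HashMap::new() }
    }

    /// Registers a zero-initialised variable under `name`.
    pub fn zeros(&mut self, name: &str, len: usize) -> Var {
        let v: Var = Rc::new(RefCell::new(vec![0.0; len]));
        self.vars.insert(name.to_string(), Rc::clone(&v));
        v
    }

    /// Copies each variable of `src` into the variable of the same name here.
    pub fn copy_from(&mut self, src: &VarStore) -> Result<(), String> {
        for (name, dst) in &self.vars {
            let src_var = src
                .vars
                .get(name)
                .ok_or_else(|| format!("missing variable {name}"))?;
            let src_data = src_var.borrow();
            let mut dst_data = dst.borrow_mut();
            if src_data.len() != dst_data.len() {
                return Err(format!("shape mismatch for {name}"));
            }
            dst_data.copy_from_slice(&src_data);
        }
        Ok(())
    }
}

/// Toy module whose weights are registered in a var store.
pub struct Policy {
    pub w: Var,
    pub b: Var,
}

impl Policy {
    pub fn new(vs: &mut VarStore) -> Self {
        Policy { w: vs.zeros("w", 2), b: vs.zeros("b", 2) }
    }
}

/// Hypothetical helper: fresh module on `device`, then copy the weights in.
pub fn clone_policy_to(src: &VarStore, device: Device) -> Result<(VarStore, Policy), String> {
    let mut vs = VarStore::new(device);
    let policy = Policy::new(&mut vs);
    vs.copy_from(src)?;
    Ok((vs, policy))
}
```

Copying by name is what lets this work without a generic clone: the new module registers the same variable names, so the values line up.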
Having looked at the C++ implementation of `Cloneable`, doing something analogous would require expanding the `Module` trait to include `parameters`, `buffers`, and `children`. Alternatively, it could be implemented on a per-`Module` basis using `Tensor::to` / a hypothetical `Tensor::clone`.
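As a rough illustration of the `Cloneable`-style route, suppose the trait exposed each module's parameters (recursing into children) and required a shallow `Clone`. A default method could then clone the struct and swap every parameter for a deep copy on the target device. Everything here is hypothetical: `ModuleWithParams`, `parameters_mut`, `deep_copy_to`, and the stand-in `Tensor`, whose derived `Clone` shares storage much like `shallow_clone`. Buffers are omitted for brevity.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Minimal stand-ins; NOT the real tch API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

/// Stand-in tensor. The derived `Clone` is shallow: it shares storage.
#[derive(Debug, Clone)]
pub struct Tensor {
    pub data: Rc<RefCell<Vec<f32>>>,
    pub device: Device,
}

impl Tensor {
    pub fn new(data: Vec<f32>, device: Device) -> Self {
        Tensor { data: Rc::new(RefCell::new(data)), device }
    }

    /// Hypothetical deep copy into fresh storage on `device`.
    pub fn deep_copy_to(&self, device: Device) -> Tensor {
        Tensor::new(self.data.borrow().clone(), device)
    }
}

/// Hypothetical expanded trait: modules expose their parameters.
pub trait ModuleWithParams: Clone {
    /// Every parameter of this module and of its children.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor>;

    /// Shallow-clone the struct, then replace each parameter with a deep copy.
    fn clone_to_device(&self, device: Device) -> Self {
        let mut copy = self.clone();
        for p in copy.parameters_mut() {
            *p = p.deep_copy_to(device);
        }
        copy
    }
}

#[derive(Debug, Clone)]
pub struct Linear {
    pub ws: Tensor,
    pub bs: Tensor,
}

impl ModuleWithParams for Linear {
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.ws, &mut self.bs]
    }
}

/// A module with two children; it forwards to their parameter lists.
#[derive(Debug, Clone)]
pub struct Mlp {
    pub l1: Linear,
    pub l2: Linear,
}

impl ModuleWithParams for Mlp {
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut ps = self.l1.parameters_mut();
        ps.extend(self.l2.parameters_mut());
        ps
    }
}
```

The catch is the one raised above: every module has to enumerate its parameters (and, in a fuller version, buffers and children), which is exactly the trait expansion that makes this a bigger change.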
Those are bigger changes than I'd want to make given the existence of the `VarStore` approach, so I'm no longer planning on making a PR for this at the moment.
You can see my `tch-tensor-like` macro, which derives `.shallow_clone()` on compound tensor structs.
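Roughly speaking, such a derive amounts to calling `shallow_clone()` on every field. The sketch below approximates that with a `macro_rules!` macro rather than a derive; `ShallowClone`, `impl_shallow_clone!`, `Transition`, `Batch`, and the stand-in `Tensor` are all hypothetical, and `tch-tensor-like`'s actual interface may differ.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Stand-in tensor with reference-counted storage; NOT the real tch type.
pub struct Tensor {
    pub data: Rc<RefCell<Vec<f32>>>,
}

/// Hypothetical trait for "new handle, same storage" copies.
pub trait ShallowClone {
    fn shallow_clone(&self) -> Self;
}

impl ShallowClone for Tensor {
    fn shallow_clone(&self) -> Self {
        Tensor { data: Rc::clone(&self.data) }
    }
}

/// Generates a field-by-field `ShallowClone` impl, roughly what a derive would emit.
macro_rules! impl_shallow_clone {
    ($ty:ident { $($field:ident),* $(,)? }) => {
        impl ShallowClone for $ty {
            fn shallow_clone(&self) -> Self {
                $ty { $($field: self.$field.shallow_clone()),* }
            }
        }
    };
}

pub struct Transition {
    pub obs: Tensor,
    pub action: Tensor,
    pub reward: Tensor,
}
impl_shallow_clone!(Transition { obs, action, reward });

// Nested compound structs work too, since every field implements the trait.
pub struct Batch {
    pub first: Transition,
    pub second: Transition,
}
impl_shallow_clone!(Batch { first, second });
```

Because only the `Rc` handles are copied, writes through the shallow copy are visible in the original, which is the point of `shallow_clone` as opposed to a deep copy.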
Closing this for now, as it has been a while and it feels like supporting this via some external crate such as `tch-tensor-like` is probably best.