dfdx
How does one update one model from another model?
I have two models with identical structure, one of which is meant to be updated periodically from the other. Before #854, I did this using `TensorCollection::iter_tensors`, as follows:
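```rust
fn update_model(&mut self) {
    // Pre-#854 API; assumes these items are in scope via the dfdx prelude.
    use dfdx::prelude::*;

    struct Updater;
    // Device made consistent: the tensors visited below live on `Cuda`.
    impl TensorVisitor<f32, Cuda> for Updater {
        type Viewer = (ViewTensorRef, ViewTensorRef);
        type Err = <Cuda as HasErr>::Err;
        type E2 = f32;
        type D2 = Cuda;

        fn visit<S: dfdx::shapes::Shape>(
            &mut self,
            _: TensorOptions<S, f32, Cuda>,
            (model, model_training): <Self::Viewer as TensorViewer>::View<'_, Tensor<S, f32, Cuda>>,
        ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
            // model = (1 - TAU) * model + TAU * model_training
            // (`TAU` is a constant defined elsewhere.)
            let mut model = model.clone();
            model.axpy(1.0 - TAU, model_training, TAU);
            Ok(Some(model))
        }
    }

    // Walk both models in lockstep and rebuild `self.model` from the results.
    self.model = TensorCollection::iter_tensors(&mut RecursiveWalker {
        m: (&self.model, &self.model_training),
        f: &mut Updater,
    })
    .unwrap()
    .unwrap();
}
```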
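For reference, the per-tensor update in the snippet above is a soft update: each parameter of `model` becomes `(1 - TAU) * model + TAU * model_training`. Below is a minimal, library-independent sketch of that arithmetic, assuming each model can be viewed as a list of flattened parameter tensors in matching order. `soft_update` and the local `axpy` are hypothetical helpers for illustration, not dfdx APIs; whatever replaces `iter_tensors` would need to provide this kind of paired walk over both models.

```rust
/// Element-wise `a = alpha * a + beta * b`, the same arithmetic the
/// `axpy(alpha, b, beta)` call in the snippet above performs per tensor.
fn axpy(a: &mut [f32], alpha: f32, b: &[f32], beta: f32) {
    assert_eq!(a.len(), b.len(), "tensors must have the same shape");
    for (x, &y) in a.iter_mut().zip(b) {
        *x = alpha * *x + beta * y;
    }
}

/// Hypothetical helper: soft-update every parameter tensor of `target`
/// toward the matching tensor of `source`. Each model is represented as a
/// list of flattened tensors in the same order, standing in for walking two
/// structurally identical models in lockstep.
fn soft_update(target: &mut [Vec<f32>], source: &[Vec<f32>], tau: f32) {
    assert_eq!(target.len(), source.len(), "models must have the same structure");
    for (t, s) in target.iter_mut().zip(source) {
        axpy(t, 1.0 - tau, s, tau);
    }
}

fn main() {
    let mut target: Vec<Vec<f32>> = vec![vec![0.0, 0.0], vec![1.0]];
    let source: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0]];
    soft_update(&mut target, &source, 0.5);
    println!("{:?}", target); // [[0.5, 1.0], [2.0]]
}
```

With `tau = 0.5` each target parameter moves halfway toward its source counterpart; a small `tau` gives the slow "target network" tracking the original snippet is after.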
This functionality was removed in #854, and it's not clear to me what replaces it.
How does one update one model from another now?