How does one update one model from another model?

Open hovinen opened this issue 1 year ago • 0 comments

I have two models with identical structure, one of which is meant to be updated periodically from the other. Prior to #854 I did so using TensorCollection::iter_tensors as follows:

    fn update_model(&mut self) {
        struct Updater;
        impl TensorVisitor<f32, Cpu> for Updater {
            type Viewer = (ViewTensorRef, ViewTensorRef);
            type Err = <Cpu as HasErr>::Err;
            type E2 = f32;
            type D2 = Cpu;

            fn visit<S: dfdx::shapes::Shape>(
                &mut self,
                _: TensorOptions<S, f32, Cpu>,
                (model, model_training): <Self::Viewer as TensorViewer>::View<
                    '_,
                    Tensor<S, f32, Cpu>,
                >,
            ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
                let mut model = model.clone();
                model.axpy(1.0 - TAU, model_training, TAU);
                Ok(Some(model))
            }
        }
        self.model = TensorCollection::iter_tensors(&mut RecursiveWalker {
            m: (&self.model, &self.model_training),
            f: &mut Updater,
        })
        .unwrap()
        .unwrap();
    }

This functionality was removed in #854, and it's not clear to me what functionality replaces it.

How does one update a model now?
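
For reference, the per-tensor operation that Updater applies is a soft update: each tensor of the target model becomes (1 - TAU) times its old value plus TAU times the matching tensor of the training model. This assumes axpy(a, other, b) computes self * a + other * b. Below is a minimal sketch of that arithmetic in plain Rust, independent of dfdx. The names soft_update and soft_update_all, the Vec<f32> buffers, and the value of TAU are all hypothetical stand-ins, not part of the dfdx API.

```rust
/// Hypothetical blending factor; the real value of TAU is not given in this issue.
const TAU: f32 = 0.05;

/// Soft update of one parameter buffer:
/// target = (1 - tau) * target + tau * source.
fn soft_update(target: &mut [f32], source: &[f32], tau: f32) {
    assert_eq!(target.len(), source.len(), "buffers must have the same length");
    for (t, s) in target.iter_mut().zip(source) {
        *t = (1.0 - tau) * *t + tau * *s;
    }
}

/// Applies the soft update to every pair of parameter buffers.
/// Each Vec<f32> stands in for one tensor of a model (an assumption of this sketch).
fn soft_update_all(target: &mut [Vec<f32>], source: &[Vec<f32>], tau: f32) {
    assert_eq!(target.len(), source.len(), "models must have the same number of tensors");
    for (t, s) in target.iter_mut().zip(source) {
        soft_update(t, s, tau);
    }
}
```

Calling soft_update_all(&mut model, &model_training, TAU) on two such lists of buffers would perform the same blend that the visitor above performed tensor by tensor. What I am missing is the dfdx way to walk both models in lockstep.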

hovinen, Jan 05 '24 16:01