
Train only an output model, freezing the representation model.

Open RaulPPelaez opened this issue 10 months ago • 1 comment

Adds the --freeze-representation, --reset-output-model, and --overwrite-representation options to train.py.

  • Freeze representation: excludes the representation model weights from training.
  • Reset output model: calls reset_parameters() on the output model after loading it for training. Ignored if --load-model is not used.
  • Overwrite representation: takes a path to a checkpoint; if provided, the representation model weights are initialized from it.
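
Internally, freezing the representation amounts to disabling gradients on its parameters. A minimal sketch of what --freeze-representation could do (the helper name is hypothetical; the actual train.py wiring may differ):

```python
import torch.nn as nn


def freeze_representation(representation_model: nn.Module) -> None:
    """Exclude all parameters of the representation model from training."""
    for p in representation_model.parameters():
        p.requires_grad_(False)


# Stand-in module for illustration; in torchmd-net this would be the
# model's representation_model sub-module.
rep = nn.Linear(16, 16)
freeze_representation(rep)
assert all(not p.requires_grad for p in rep.parameters())
```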

This makes it possible to train many output modules while sharing a single representation model. The intended workflow is:

 $ torchmd-train --conf my_model1.yaml --log-dir model1 # Initial training for the representation model
  # Train the second model, but load the representation weights from the first one.
  # Note that there are no limitations on the output model here with respect to model1.
 $ torchmd-train --conf my_model2.yaml --log-dir model2 --freeze-representation --overwrite-representation model1/best.ckpt
 # Now you have two models that share the representation model
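
One way to sanity-check that the second run really reused the first run's representation weights is to compare the `representation_model` entries of the two checkpoints. A sketch, shown on toy state dicts (the key prefix follows the usual Lightning `state_dict` layout; paths and key names are assumptions):

```python
import torch


def shared_representation(sd1: dict, sd2: dict, tag: str = "representation_model") -> bool:
    """True if all representation weights match between two state dicts."""
    keys = [k for k in sd1 if tag in k]
    return len(keys) > 0 and all(torch.equal(sd1[k], sd2[k]) for k in keys)


# Toy state dicts standing in for torch.load("model1/best.ckpt")["state_dict"]
# and the same for model2 (layout is an assumption):
w = torch.randn(4, 4)
sd1 = {"model.representation_model.weight": w,
       "model.output_model.weight": torch.randn(4)}
sd2 = {"model.representation_model.weight": w.clone(),
       "model.output_model.weight": torch.randn(4)}
print(shared_representation(sd1, sd2))  # → True
```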

For inference we can take advantage of the shared representation model and trick torch into calling it only once. For this we can create a class similar to Ensemble. For prototyping we can simply make it similar to TorchMD_Net, e.g.:


    # Assumes: `import torch`, `from torch import Tensor`,
    # `from torch.autograd import grad`, and
    # `from typing import Dict, List, Optional`.
    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Optional[Tensor] = None,
        box: Optional[Tensor] = None,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
        extra_args: Optional[Dict[str, Tensor]] = None,
    ):
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        if self.derivative:
            pos.requires_grad_(True)
        # The shared representation model is evaluated only once.
        x, v, z, pos, batch = self.models[0].representation_model(
            z, pos, batch, box=box, q=q, s=s
        )
        y = []
        neg_dy = []
        for m in self.models:
            o = m.output_model
            x_o = o.pre_reduce(x, v, z, pos, batch)
            if self.prior_model is not None:
                for prior in self.prior_model:
                    x_o = prior.pre_reduce(x_o, z, pos, batch, extra_args)
            x_o = o.reduce(x_o, batch)
            y_o = o.post_reduce(x_o)
            if self.prior_model is not None:
                for prior in self.prior_model:
                    y_o = prior.post_reduce(y_o, z, pos, batch, box, extra_args)
            y.append(y_o)
            if self.derivative:
                grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y_o)]
                # retain_graph must be True here: the shared representation
                # graph is reused by every output head, so it cannot be freed
                # after the first backward pass.
                dy_o = grad(
                    [y_o],
                    [pos],
                    grad_outputs=grad_outputs,
                    create_graph=self.training,
                    retain_graph=True,
                )[0]
                assert dy_o is not None, "Autograd returned None for the force prediction."
                neg_dy.append(-dy_o)
        y = torch.stack(y)
        neg_dy = torch.stack(neg_dy) if self.derivative else torch.empty(0)
        y_mean = torch.mean(y, dim=0)
        neg_dy_mean = torch.mean(neg_dy, dim=0) if self.derivative else torch.empty(0)
        y_std = torch.std(y, dim=0)
        neg_dy_std = torch.std(neg_dy, dim=0) if self.derivative else torch.empty(0)

        if self.return_std:
            return y_mean, neg_dy_mean, y_std, neg_dy_std
        else:
            return y_mean, neg_dy_mean
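The pattern above, stripped to its essence, is "run the expensive representation once, then fan out to several cheap output heads". A self-contained toy version with plain torch modules (class and module names here are illustrative, not the torchmd-net API):

```python
import torch
import torch.nn as nn


class SharedRepEnsemble(nn.Module):
    """Toy ensemble: one shared representation, several output heads."""

    def __init__(self, representation: nn.Module, heads):
        super().__init__()
        self.representation = representation
        self.heads = nn.ModuleList(heads)

    def forward(self, x: torch.Tensor):
        h = self.representation(x)                  # computed only once
        ys = torch.stack([head(h) for head in self.heads])
        return ys.mean(dim=0), ys.std(dim=0)        # ensemble mean and spread


rep = nn.Linear(8, 8)
heads = [nn.Linear(8, 1) for _ in range(3)]
model = SharedRepEnsemble(rep, heads)
mean, std = model(torch.randn(5, 8))                # shapes: (5, 1) and (5, 1)
```

The standard deviation across heads comes for free and can serve as a rough uncertainty estimate, which is what the `return_std` flag in the prototype above exposes.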
RaulPPelaez avatar Apr 17 '24 09:04 RaulPPelaez

cc @stefdoerr

RaulPPelaez avatar Apr 17 '24 09:04 RaulPPelaez