torchmd-net
Train only an output model, freezing the representation model.
Adds the `--freeze-representation`, `--reset-output-model` and `--overwrite-representation` flags to `train.py`.
- `--freeze-representation`: excludes the representation model weights from training.
- `--reset-output-model`: calls `reset_parameters()` on the output model after loading it for training. Ignored if `--load-model` is not used.
- `--overwrite-representation`: takes a path to a checkpoint; if present, the representation model weights are initialized from it.
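A minimal sketch of what `--freeze-representation` boils down to, using a hypothetical two-part module in place of a real TorchMD_Net model (the `ModuleDict` layout here is illustrative only):

```python
import torch
from torch import nn

# Hypothetical stand-in for a TorchMD_Net-style model: a representation
# module followed by an output head.
model = nn.ModuleDict({
    "representation_model": nn.Linear(8, 16),
    "output_model": nn.Linear(16, 1),
})

# Freezing the representation means disabling gradients on its parameters,
# so the optimizer never updates them.
for p in model["representation_model"].parameters():
    p.requires_grad_(False)

# Only the output head remains trainable.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
```

Because frozen parameters still live in the module, the checkpoint keeps the full state dict; only the gradient flow changes.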
This allows training many output modules while sharing a single representation model. The intended workflow looks like this:
```shell
# Initial training for the representation model
$ torchmd-train --conf my_model1.yaml --log-dir model1
# Train the second model, but load the representation weights from the first one.
# Note that there are no limitations on the output model here with respect to model1.
$ torchmd-train --conf my_model2.yaml --log-dir model2 --freeze-representation --overwrite-representation model1/best.ckpt
# Now you have two models that share the representation model
```
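The effect of `--overwrite-representation` can be sketched as filtering a checkpoint's state dict down to the representation keys before loading. This is a toy illustration, not the actual `train.py` code; the two-layer `make_model` and the in-memory buffer are assumptions:

```python
import io
import torch
from torch import nn

def make_model():
    # Toy two-part model; a real checkpoint stores a full TorchMD_Net
    # state_dict, this only illustrates the key filtering.
    return nn.ModuleDict({
        "representation_model": nn.Linear(4, 4),
        "output_model": nn.Linear(4, 1),
    })

# Pretend this is model1/best.ckpt from the first training run.
model1 = make_model()
buf = io.BytesIO()
torch.save({"state_dict": model1.state_dict()}, buf)
buf.seek(0)

# Overwrite representation: keep model2's own output head, but take the
# representation weights from the model1 checkpoint.
model2 = make_model()
ckpt = torch.load(buf)
rep_weights = {k: v for k, v in ckpt["state_dict"].items()
               if k.startswith("representation_model.")}
model2.load_state_dict(rep_weights, strict=False)
```

Loading with `strict=False` is what lets the output-model keys stay untouched while only the representation keys are replaced.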
For inference we can take advantage of the shared representation model and trick torch into calling it only once. For this we can create a class similar to `Ensemble`. For prototyping, we can simply make its `forward` similar to `TorchMD_Net`'s:
```python
def forward(
    self,
    z: Tensor,
    pos: Tensor,
    batch: Optional[Tensor] = None,
    box: Optional[Tensor] = None,
    q: Optional[Tensor] = None,
    s: Optional[Tensor] = None,
    extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Tensor]:
    assert z.dim() == 1 and z.dtype == torch.long
    batch = torch.zeros_like(z) if batch is None else batch
    if self.derivative:
        pos.requires_grad_(True)
    # The representation model is shared, so call it only once for all members.
    x, v, z, pos, batch = self.models[0].representation_model(
        z, pos, batch, box=box, q=q, s=s
    )
    y = []
    neg_dy = []
    for i, m in enumerate(self.models):
        o = m.output_model
        x_o = o.pre_reduce(x, v, z, pos, batch)
        if self.prior_model is not None:
            for prior in self.prior_model:
                x_o = prior.pre_reduce(x_o, z, pos, batch, extra_args)
        x_o = o.reduce(x_o, batch)
        y_o = o.post_reduce(x_o)
        if self.prior_model is not None:
            for prior in self.prior_model:
                y_o = prior.post_reduce(y_o, z, pos, batch, box, extra_args)
        y.append(y_o)
        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y_o)]
            # Keep the shared graph alive until the last member has been
            # differentiated, otherwise the second grad() call would fail.
            retain = self.training or i < len(self.models) - 1
            dy_o = grad(
                [y_o],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=self.training,
                retain_graph=retain,
            )[0]
            assert dy_o is not None, "Autograd returned None for the force prediction."
            neg_dy.append(-dy_o)
    y = torch.stack(y)
    neg_dy = torch.stack(neg_dy) if self.derivative else torch.empty(0)
    y_mean = torch.mean(y, dim=0)
    neg_dy_mean = torch.mean(neg_dy, dim=0) if self.derivative else torch.empty(0)
    y_std = torch.std(y, dim=0)
    neg_dy_std = torch.std(neg_dy, dim=0) if self.derivative else torch.empty(0)
    if self.return_std:
        return y_mean, neg_dy_mean, y_std, neg_dy_std
    else:
        return y_mean, neg_dy_mean
```
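As a small sanity check, independent of the model code above, the stack-then-reduce tail of the forward behaves like this with three heads predicting the constants 1, 2 and 3:

```python
import torch

# Three output heads, each predicting a constant for 3 samples.
y = torch.stack([torch.full((3, 1), float(i)) for i in (1, 2, 3)])
y_mean = torch.mean(y, dim=0)  # ensemble mean: 2.0 everywhere
y_std = torch.std(y, dim=0)    # unbiased std of [1, 2, 3]: 1.0 everywhere
```

Note that `torch.std` uses the unbiased (n-1) estimator by default, which matters for small ensembles like this one.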
cc @stefdoerr