GearNet
The pre-trained GearNet-Edge model for Fold Classification
Thank you for your amazing work! I found that for the Fold Classification task, the GearNet-Edge model is implemented with the GearNetIEConv class rather than the GearNet class, and the two differ in some details (e.g., the extra input embedding and the IEConv layers). Given this, could you provide a GearNet-Edge model pre-trained with Multiview Contrast that is based on the GearNetIEConv implementation for Fold Classification, rather than the GearNet-based one for the EC task? Thank you.
Hi, the config file for GearNet-Edge-IEConv on Fold is config/Fold3D/gearnet_edge_ieconv.yaml. The pre-trained checkpoints of GearNet-Edge can be found at https://zenodo.org/record/7723075.
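If you want to inspect what a downloaded checkpoint contains, here is a minimal sketch; it assumes the .pth file is a torchdrug-style checkpoint with the weights under a "model" key, and falls back to treating the file as a bare state dict otherwise:

```python
# Peek at the parameter names and shapes stored in a checkpoint.
# The "model" key is an assumption about the checkpoint layout.
import torch

checkpoint = torch.load("fold_mc_gearnet_edge_ieconv.pth", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)  # unwrap if wrapped
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```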
Thank you. It seems that fold_mc_gearnet_edge_ieconv.pth contains the encoder and decoder parameters after finetuning. I would like to run some experiments of my own: specifically, I would like to get the pre-trained GearNet-Edge-IEConv encoder from before finetuning, along with the finetuning configuration file and the exact command that was used (e.g., how many GPUs and what batch size were actually used in finetuning), and then run the finetuning myself. Would it be convenient for you to provide these? Thank you very much.
I see. The original pre-trained checkpoints were deleted from my cluster. I've pre-trained a new GearNet-Edge-IEConv recently. You can download the checkpoint from this link and have a try. Please ping me if there is any problem with the checkpoint.
For finetuning, just use the following command:
python script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0] --ckpt <path_to_your_model>
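If you want to finetune on multiple GPUs instead, passing a GPU list through a distributed launch should work; this is a sketch inferred from the list syntax of the --gpus flag above, not a command taken from the thread, and the per-GPU batch size would be set in the yaml config:

python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0,1,2,3] --ckpt <path_to_your_model>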
Thank you very much. I will have a try.
Hi, it seems that the checkpoint in the above link does not match the model in the official https://zenodo.org/record/7723075 release: the hidden dimensions of the layers are different.
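A quick way to see where two checkpoints disagree is to diff their parameter shapes. A minimal sketch, assuming both files are torchdrug-style checkpoints with the weights under a "model" key (the first file name is a placeholder):

```python
# Report parameters whose shapes differ between two checkpoints.
# The "model" key and the placeholder file name are assumptions.
import torch

def load_state_dict(path):
    ckpt = torch.load(path, map_location="cpu")
    return ckpt.get("model", ckpt)  # fall back to a bare state dict

a = load_state_dict("new_pretrained_checkpoint.pth")
b = load_state_dict("fold_mc_gearnet_edge_ieconv.pth")
for key in sorted(set(a) | set(b)):
    shape_a = tuple(a[key].shape) if key in a else None
    shape_b = tuple(b[key].shape) if key in b else None
    if shape_a != shape_b:
        print(f"{key}: {shape_a} vs {shape_b}")
```

In any case, I guess the model in https://zenodo.org/record/7723075 is based on the following newer implementation of GearNet-Edge-IEConv (with the extra input embedding, etc.):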
```python
from collections.abc import Sequence

import torch
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R

from gearnet import layer  # project-local GeometricRelationalGraphConv and IEConvLayer


@R.register("models.GearNetIEConv")
class GearNetIEConv(nn.Module, core.Configurable):

    def __init__(self, input_dim, embedding_dim, hidden_dims, num_relation, edge_input_dim=None,
                 batch_norm=False, activation="relu", concat_hidden=False, short_cut=True,
                 readout="sum", dropout=0, num_angle_bin=None, layer_norm=False, use_ieconv=False):
        super(GearNetIEConv, self).__init__()
        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        self.dims = [embedding_dim if embedding_dim > 0 else input_dim] + list(hidden_dims)
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.concat_hidden = concat_hidden
        self.short_cut = short_cut
        self.num_angle_bin = num_angle_bin
        self.layer_norm = layer_norm
        self.use_ieconv = use_ieconv

        # extra input embedding applied before message passing
        if embedding_dim > 0:
            self.linear = nn.Linear(input_dim, embedding_dim)
            self.embedding_batch_norm = nn.BatchNorm1d(embedding_dim)

        self.layers = nn.ModuleList()
        self.ieconvs = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            # note that these layers are from gearnet.layer instead of torchdrug.layers
            self.layers.append(layer.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                                  None, batch_norm, activation))
            if use_ieconv:
                self.ieconvs.append(layer.IEConvLayer(self.dims[i], self.dims[i] // 4,
                                                      self.dims[i + 1], edge_input_dim=14, kernel_hidden_dim=32))
        if num_angle_bin:
            # edge message passing over the spatial line graph
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layer.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))
        if layer_norm:
            self.layer_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.layer_norms.append(nn.LayerNorm(self.dims[i + 1]))

        self.dropout = nn.Dropout(dropout)

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def get_ieconv_edge_feature(self, graph):
        # build a local orthonormal frame per node from consecutive positions
        u = torch.ones_like(graph.node_position)
        u[1:] = graph.node_position[1:] - graph.node_position[:-1]
        u = F.normalize(u, dim=-1)
        b = torch.ones_like(graph.node_position)
        b[:-1] = u[:-1] - u[1:]
        b = F.normalize(b, dim=-1)
        n = torch.ones_like(graph.node_position)
        n[:-1] = torch.cross(u[:-1], u[1:])
        n = F.normalize(n, dim=-1)
        local_frame = torch.stack([b, n, torch.cross(b, n)], dim=-1)

        # express relative positions and orientations in the local frames
        node_in, node_out = graph.edge_list.t()[:2]
        t = graph.node_position[node_out] - graph.node_position[node_in]
        t = torch.einsum('ijk, ij->ik', local_frame[node_in], t)
        r = torch.sum(local_frame[node_in] * local_frame[node_out], dim=1)
        delta = torch.abs(graph.atom2residue[node_in] - graph.atom2residue[node_out]).float() / 6
        delta = delta.unsqueeze(-1)

        return torch.cat([
            t, r, delta,
            1 - 2 * t.abs(), 1 - 2 * r.abs(), 1 - 2 * delta.abs()
        ], dim=-1)

    def forward(self, graph, input, all_loss=None, metric=None):
        hiddens = []
        layer_input = input
        if self.embedding_dim > 0:
            layer_input = self.linear(layer_input)
            layer_input = self.embedding_batch_norm(layer_input)
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_hidden = line_graph.node_feature.float()
        else:
            edge_hidden = None
        ieconv_edge_feature = self.get_ieconv_edge_feature(graph)

        for i in range(len(self.layers)):
            # edge message passing
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_hidden)
            hidden = self.layers[i](graph, layer_input, edge_hidden)
            # ieconv layer
            if self.use_ieconv:
                hidden = hidden + self.ieconvs[i](graph, layer_input, ieconv_edge_feature)
            hidden = self.dropout(hidden)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            if self.layer_norm:
                hidden = self.layer_norms[i](hidden)
            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }
```
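For reference, here is a minimal sketch of instantiating this encoder. The dimension values below are illustrative assumptions, not values read from the official config; check config/Fold3D/gearnet_edge_ieconv.yaml for the actual hyperparameters:

```python
# Illustrative instantiation; every numeric value here is an assumption,
# not taken from config/Fold3D/gearnet_edge_ieconv.yaml.
model = GearNetIEConv(
    input_dim=21,           # one-hot residue types (assumed)
    embedding_dim=512,      # the extra input embedding (assumed)
    hidden_dims=[512] * 6,  # six message-passing layers (assumed)
    num_relation=7,         # relation types in the protein graph (assumed)
    edge_input_dim=59,      # line-graph node feature dimension (assumed)
    num_angle_bin=8,        # enables edge message passing (assumed)
    batch_norm=True,
    concat_hidden=True,
    short_cut=True,
    use_ieconv=True,
    readout="sum",
)
print(model.output_dim)  # sum(hidden_dims), since concat_hidden=True
```

Comparing the shapes in model.state_dict() against the checkpoint (e.g., with the diff sketch above) should confirm whether the Zenodo weights were trained with this implementation.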