Uninitialized weights: `extra_msa_stack.blocks.3.msa_att_row` and related parameters in OpenFold Multimer
Hi, I've been running folding jobs with OpenFold Multimer on custom inputs and have found some strange discrepancies between the OpenFold and AlphaFold outputs. While tracking this down, I checked that all the weights are being loaded into the model properly, and found the following:
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

# Build the multimer model and load the released JAX parameters into it
cfg = model_config("model_1_multimer_v3")
model = AlphaFold(cfg)
import_jax_weights_(model, "<my_path>/params_model_1_multimer_v3.npz", version="model_1_multimer_v3")

# Report every parameter whose entries sum to zero (a proxy for "never overwritten by the import")
for name, param in model.named_parameters():
    if torch.sum(param) == 0:
        print(f"Uninitialized weight: {name}, shape: {param.shape}")
=== OUTPUT ===
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_row.layer_norm_m.bias, shape: torch.Size([64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_row.layer_norm_z.bias, shape: torch.Size([128])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_row.mha.linear_o.weight, shape: torch.Size([64, 64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_row.mha.linear_o.bias, shape: torch.Size([64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_row.mha.linear_g.weight, shape: torch.Size([64, 64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_transition.layer_norm.bias, shape: torch.Size([64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_transition.linear_1.bias, shape: torch.Size([256])
Uninitialized weight: extra_msa_stack.blocks.3.msa_transition.linear_2.weight, shape: torch.Size([64, 256])
Uninitialized weight: extra_msa_stack.blocks.3.msa_transition.linear_2.bias, shape: torch.Size([64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_col.layer_norm_m.bias, shape: torch.Size([64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_g.weight, shape: torch.Size([64, 64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_o.weight, shape: torch.Size([64, 64])
Uninitialized weight: extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_o.bias, shape: torch.Size([64])
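To make the pattern in that output explicit, the flagged names can be tallied by stack, block index and submodule. The sketch below is a hypothetical helper (`group_by_block` and `FLAGGED` are illustrative names, not part of OpenFold); it only parses the names printed above and shows that every flagged parameter lives in block 3 of the extra MSA stack, spread across row attention, the transition and column (global) attention.

```python
import re
from collections import Counter

# Parameter names copied verbatim from the output above
FLAGGED = [
    "extra_msa_stack.blocks.3.msa_att_row.layer_norm_m.bias",
    "extra_msa_stack.blocks.3.msa_att_row.layer_norm_z.bias",
    "extra_msa_stack.blocks.3.msa_att_row.mha.linear_o.weight",
    "extra_msa_stack.blocks.3.msa_att_row.mha.linear_o.bias",
    "extra_msa_stack.blocks.3.msa_att_row.mha.linear_g.weight",
    "extra_msa_stack.blocks.3.msa_transition.layer_norm.bias",
    "extra_msa_stack.blocks.3.msa_transition.linear_1.bias",
    "extra_msa_stack.blocks.3.msa_transition.linear_2.weight",
    "extra_msa_stack.blocks.3.msa_transition.linear_2.bias",
    "extra_msa_stack.blocks.3.msa_att_col.layer_norm_m.bias",
    "extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_g.weight",
    "extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_o.weight",
    "extra_msa_stack.blocks.3.msa_att_col.global_attention.linear_o.bias",
]

# Matches "<stack>.blocks.<index>.<submodule>." at the start of a parameter name
_BLOCK_RE = re.compile(r"^(?P<stack>[a-z_]+)\.blocks\.(?P<block>\d+)\.(?P<sub>[a-z_]+)\.")


def group_by_block(names):
    """Count flagged parameters per (stack, block index, submodule)."""
    counts = Counter()
    for name in names:
        m = _BLOCK_RE.match(name)
        if m is None:
            # Names outside a ".blocks.<i>." stack are kept under a catch-all key
            counts[("<other>", -1, name)] += 1
            continue
        counts[(m["stack"], int(m["block"]), m["sub"])] += 1
    return counts


for key, n in sorted(group_by_block(FLAGGED).items()):
    print(key, n)
```

Feeding it the full list printed by the loop (rather than this copied subset) would show at a glance whether any other stack or block is affected. Note that `torch.sum(param) == 0` can also be true for a non-zero tensor whose entries cancel out, so `torch.all(param == 0)` is the stricter test for an all-zero parameter.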
Before I go down other rabbit holes, I wanted to check whether this is intentional. The same check reports no such weights when I import the monomer weights.