pytorch_geometric
Gradient does not propagate during network training
🐛 Describe the bug
Hello, I am currently working on fitting the radial distribution function of water. Starting from the SchNet source code at https://github.com/torchmd/mdgrad, I replaced the model they provide with the one from your project; no other changes were made. However, after printing the gradients, I found that they do not propagate when training with your model. Why is this?
```python
loss = loss_mse
optimizer.zero_grad()
loss.backward()
print("epoch {} | loss: {:.5f}".format(i, loss.item()))
optimizer.step()
for name, params in net.named_parameters():
    print('-->name:', name, '-->requires_grad:', params.requires_grad,
          '--weight:', torch.mean(params.data),
          '-->grad_value:', torch.mean(params.grad))
scheduler.step(loss)
```
The results of using mdgrad:
The result of using torch_geometric:
I can guarantee that, apart from replacing the network model and adjusting the inputs, no other modifications were made. Besides SchNet, I encountered the same problem when switching to DimeNetPlusPlus. This issue is very important to me, and I look forward to your reply!
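One diagnostic that may help narrow this down (not part of the original report, and it assumes `net` and `loss` are the objects from the training loop above) is to check whether the loss tensor is still attached to the autograd graph before `backward()`, and which parameters end up with no gradient at all. If `loss.grad_fn` is `None`, the graph was broken somewhere upstream (for example by `.detach()`, `.item()`, or a NumPy round-trip), so no gradient can reach the network parameters.

```python
# Diagnostic sketch; assumes `net` and `loss` come from the training loop above.
print("loss requires_grad:", loss.requires_grad)
print("loss grad_fn:", loss.grad_fn)  # None means the graph is broken upstream

# Parameters whose .grad is still None after backward() never received a gradient.
missing = [name for name, p in net.named_parameters() if p.grad is None]
print("parameters without gradients:", missing)
```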
Versions
Python version: 3.11.7
I am running:

```python
import torch

from torch_geometric.data import Data
from torch_geometric.nn import SchNet

use_atomref = True  # not defined in the original snippet; set here so it runs

data = Data(
    z=torch.randint(1, 10, (20, )),
    pos=torch.randn(20, 3),
)
model = SchNet(
    hidden_channels=16,
    num_filters=16,
    num_interactions=2,
    num_gaussians=10,
    cutoff=6.0,
    dipole=True,
    atomref=torch.randn(100, 1) if use_atomref else None,
)
out = model(data.z, data.pos)
out.mean().backward()

for name, params in model.named_parameters():
    if params is not None and params.grad is not None:
        print(name, params.grad.mean())
```
and receive:

```
embedding.weight tensor(-0.0044)
interactions.0.mlp.0.weight tensor(0.0598)
interactions.0.mlp.0.bias tensor(0.2427)
interactions.0.mlp.2.weight tensor(0.0053)
interactions.0.mlp.2.bias tensor(0.3005)
interactions.0.conv.lin1.weight tensor(0.0092)
interactions.0.conv.lin2.weight tensor(0.0024)
interactions.0.conv.lin2.bias tensor(0.1793)
interactions.0.lin.weight tensor(-0.0106)
interactions.0.lin.bias tensor(-0.1433)
interactions.1.mlp.0.weight tensor(0.0350)
interactions.1.mlp.0.bias tensor(0.1370)
interactions.1.mlp.2.weight tensor(0.0284)
interactions.1.mlp.2.bias tensor(1.2006)
interactions.1.conv.lin1.weight tensor(0.0710)
interactions.1.conv.lin2.weight tensor(-0.0016)
interactions.1.conv.lin2.bias tensor(0.0699)
interactions.1.lin.weight tensor(-0.0570)
interactions.1.lin.bias tensor(-0.2495)
lin1.weight tensor(-0.0696)
lin1.bias tensor(0.4007)
lin2.weight tensor(-1.2804)
lin2.bias tensor(-2.4694)
```
Can you confirm that this breaks for you?
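One more thing that may be worth checking on the RDF-fitting side (an assumption on my part, not something confirmed above): if the loss is built from quantities that depend on the atomic positions, the positions themselves need `requires_grad=True` before the forward pass; otherwise only the parameter gradients shown above exist. A minimal sketch on top of the reproduction script:

```python
# Sketch; assumes position gradients (forces) are needed for the RDF loss.
z = torch.randint(1, 10, (20, ))
pos = torch.randn(20, 3, requires_grad=True)

out = model(z, pos)
# Gradient of the summed output w.r.t. positions; create_graph=True keeps it
# differentiable so it can itself enter a loss term.
(grad_pos,) = torch.autograd.grad(out.sum(), pos, create_graph=True)
print(grad_pos.shape)  # torch.Size([20, 3])
```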