Crash in loss.backward() with an idtype=int32 DGL graph during GraphSAGE mini-batch CPU training – resolved with idtype=int64 or device=cuda
🐛 Bug
When training a GraphSAGE model with mini-batches on CPU using a DGL graph whose idtype is int32, execution aborts during loss.backward() without any error message. Switching the graph to idtype=int64, or training with device=cuda, resolves the issue.
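For reference, a minimal sketch of the two workarounds that avoid the crash for me (the calls mirror the repro script below; only the idtype and device changes are the point, and with the GPU variant the seed IDs and tensors have to live on the same device):

# Workaround 1: keep node/edge IDs as int64 (DGL's default idtype),
# i.e. drop the g.int() call in the repro, or cast the graph back:
g = g.long()
train_nids = g.ndata['train_mask'].nonzero().squeeze()  # leave seed IDs as int64

# Workaround 2: train on GPU instead of CPU (int32 idtype then works for me):
device = torch.device('cuda')
g = g.to(device)
model = model.to(device)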
To Reproduce
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl.dataloading import DataLoader, NeighborSampler
def create_synthetic_graph():
    num_nodes = 100
    num_edges = 500
    src_nodes = torch.randint(0, num_nodes, (num_edges,))
    dst_nodes = torch.randint(0, num_nodes, (num_edges,))
    mask = src_nodes != dst_nodes
    src_nodes = src_nodes[mask]
    dst_nodes = dst_nodes[mask]
    g = dgl.graph((src_nodes, dst_nodes))
    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)
    g.ndata['label'] = torch.randint(0, 5, (g.num_nodes(),))
    train_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)
    val_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)
    test_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)
    indices = torch.randperm(g.num_nodes())
    train_idx = indices[:int(0.6 * g.num_nodes())]
    val_idx = indices[int(0.6 * g.num_nodes()):int(0.8 * g.num_nodes())]
    test_idx = indices[int(0.8 * g.num_nodes()):]
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True
    g.ndata['train_mask'] = train_mask
    g.ndata['val_mask'] = val_mask
    g.ndata['test_mask'] = test_mask
    g = g.int()  # cast the graph to idtype int32 (the configuration that triggers the crash)
    return g
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, aggregator_type='mean'):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats,
            out_feats=hid_feats,
            aggregator_type=aggregator_type,
            activation=F.relu
        )
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats,
            out_feats=out_feats,
            aggregator_type=aggregator_type
        )

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x
def train(model, g, train_loader, optimizer, device):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    for input_nodes, output_nodes, blocks in train_loader:
        blocks = [b.to(device) for b in blocks]
        features = blocks[0].srcdata['feat']
        labels = blocks[-1].dstdata['label']
        logits = model(blocks, features)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        # if torch.isnan(loss).any() or torch.isinf(loss).any():
        #     print("Loss contains NaN/Inf!")
        print("ok1")
        loss.backward()  # on CPU with an int32-idtype graph, the process dies here with no traceback
        print("ok2")
        optimizer.step()
        total_loss += loss.item() * len(output_nodes)
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += len(output_nodes)
    return total_loss / total_samples, total_correct / total_samples
def evaluate(model, g, loader, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in loader:
            blocks = [b.to(device) for b in blocks]
            features = blocks[0].srcdata['feat']
            labels = blocks[-1].dstdata['label']
            logits = model(blocks, features)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += len(output_nodes)
    return total_correct / total_samples
def main():
    torch.manual_seed(42)
    g = create_synthetic_graph()
    device = "cpu"  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_feats = g.ndata['feat'].shape[1]
    hid_feats = 16
    out_feats = 5
    num_epochs = 50
    batch_size = 32
    model = GraphSAGE(in_feats, hid_feats, out_feats).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    train_nids = g.ndata['train_mask'].nonzero().squeeze()
    val_nids = g.ndata['val_mask'].nonzero().squeeze()
    test_nids = g.ndata['test_mask'].nonzero().squeeze()
    sampler = NeighborSampler([10, 10])  # two-layer sampling, 10 neighbors per layer
    train_loader = DataLoader(
        g,
        train_nids.to(torch.int32),  # seed IDs cast to int32 to match the graph's idtype
        sampler,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=0
    )
    val_loader = DataLoader(
        g,
        val_nids.to(torch.int32),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )
    test_loader = DataLoader(
        g,
        test_nids.to(torch.int32),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )
    best_val_acc = 0
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, g, train_loader, optimizer, device)
        val_acc = evaluate(model, g, val_loader, device)
        print(f'Epoch {epoch:02d}: '
              f'Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Acc: {val_acc:.4f}')
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
    model.load_state_dict(torch.load('best_model.pth'))
    test_acc = evaluate(model, g, test_loader, device)
    print(f'Test Accuracy: {test_acc:.4f}')


if __name__ == '__main__':
    main()
In theory the script should run to completion and print the final Test Accuracy, but instead the process is interrupted abnormally during loss.backward() and exits with no error message.
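Since there is no Python traceback, my guess is a native crash (e.g., a segfault) inside the backward pass. A possibly useful step for anyone reproducing this: enable the standard-library faulthandler before running main(), so the interpreter dumps the fault and the Python stack when it dies (equivalently, run the script with python -X faulthandler):

import faulthandler
faulthandler.enable()  # print a traceback on fatal signals such as SIGSEGV

if __name__ == '__main__':
    main()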
Environment
- DGL Version: 2.2.1
- Backend Library & Version: PyTorch 2.3.0
- OS: Linux
- How you installed DGL (conda, pip, source): conda
- Python version: 3.12.3