
Error during loss.backward() with an idtype=int32 DGL graph in GraphSAGE mini-batch CPU training – resolved with idtype=int64 or device=cuda

Open yfismine opened this issue 7 months ago • 1 comment

🐛 Bug

When training a GraphSAGE model with mini-batches on the CPU using a DGL graph of idtype=int32, execution is abnormally interrupted during loss.backward(). Switching to idtype=int64 or running on device=cuda resolves the issue.
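
A minimal sketch of the two workarounds described above (assuming g is the graph returned by create_synthetic_graph() in the script below; with the first workaround, the .to(torch.int32) casts on the seed-node tensors would also need to be dropped or changed to int64):

# Workaround 1: keep node/edge IDs as int64 instead of casting the graph to int32.
g = g.long()   # or simply remove the `g = g.int()` call in create_synthetic_graph()

# Workaround 2: run on the GPU instead of the CPU.
device = torch.device('cuda')
g = g.to(device)
model = model.to(device)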

To Reproduce

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl.dataloading import DataLoader, NeighborSampler


def create_synthetic_graph():
    num_nodes = 100
    num_edges = 500
    src_nodes = torch.randint(0, num_nodes, (num_edges,))
    dst_nodes = torch.randint(0, num_nodes, (num_edges,))

    mask = src_nodes != dst_nodes
    src_nodes = src_nodes[mask]
    dst_nodes = dst_nodes[mask]

    g = dgl.graph((src_nodes, dst_nodes))

    g.ndata['feat'] = torch.randn(g.num_nodes(), 10)

    g.ndata['label'] = torch.randint(0, 5, (g.num_nodes(),))

    train_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)
    val_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)
    test_mask = torch.zeros(g.num_nodes(), dtype=torch.bool)

    indices = torch.randperm(g.num_nodes())
    train_idx = indices[:int(0.6 * g.num_nodes())]
    val_idx = indices[int(0.6 * g.num_nodes()):int(0.8 * g.num_nodes())]
    test_idx = indices[int(0.8 * g.num_nodes()):]

    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True

    g.ndata['train_mask'] = train_mask
    g.ndata['val_mask'] = val_mask
    g.ndata['test_mask'] = test_mask
    g = g.int()  # cast node/edge IDs to idtype=int32; this is what triggers the crash on CPU
    return g

class GraphSAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, aggregator_type='mean'):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats,
            out_feats=hid_feats,
            aggregator_type=aggregator_type,
            activation=F.relu
        )
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats,
            out_feats=out_feats,
            aggregator_type=aggregator_type
        )

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x


def train(model, g, train_loader, optimizer, device):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for input_nodes, output_nodes, blocks in train_loader:
        blocks = [b.to(device) for b in blocks]
        features = blocks[0].srcdata['feat']
        labels = blocks[-1].dstdata['label']
        logits = model(blocks, features)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        # if torch.isnan(loss).any() or torch.isinf(loss).any():
        #     print("Loss contains NaN/Inf!")
        print("ok1")       # printed before the crash
        loss.backward()    # execution is interrupted here on CPU with idtype=int32
        print("ok2")       # never reached
        optimizer.step()

        total_loss += loss.item() * len(output_nodes)
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += len(output_nodes)

    return total_loss / total_samples, total_correct / total_samples

def evaluate(model, g, loader, device):
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for input_nodes, output_nodes, blocks in loader:
            blocks = [b.to(device) for b in blocks]
            features = blocks[0].srcdata['feat']
            labels = blocks[-1].dstdata['label']

            logits = model(blocks, features)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += len(output_nodes)

    return total_correct / total_samples

def main():
    torch.manual_seed(42)

    g = create_synthetic_graph()

    device = "cpu"  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_feats = g.ndata['feat'].shape[1]
    hid_feats = 16
    out_feats = 5
    num_epochs = 50
    batch_size = 32

    model = GraphSAGE(in_feats, hid_feats, out_feats).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    train_nids = g.ndata['train_mask'].nonzero().squeeze()
    val_nids = g.ndata['val_mask'].nonzero().squeeze()
    test_nids = g.ndata['test_mask'].nonzero().squeeze()

    sampler = NeighborSampler([10, 10])  # two-layer sampling, 10 neighbors per layer
    train_loader = DataLoader(
        g,
        train_nids.to(torch.int32),
        sampler,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=0
    )

    val_loader = DataLoader(
        g,
        val_nids.to(torch.int32),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )

    test_loader = DataLoader(
        g,
        test_nids.to(torch.int32),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )

    best_val_acc = 0
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, g, train_loader, optimizer, device)
        val_acc = evaluate(model, g, val_loader, device)

        print(f'Epoch {epoch:02d}: '
              f'Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Acc: {val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

    model.load_state_dict(torch.load('best_model.pth'))
    test_acc = evaluate(model, g, test_loader, device)
    print(f'Test Accuracy: {test_acc:.4f}')

if __name__ == '__main__':
    main()

In theory, the script should run to completion and print the final Test Accuracy line, but the run is abnormally interrupted at loss.backward() and no error message is shown.
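
Since the process exits with no Python traceback, the crash most likely occurs in the native (C++) backend. A generic way to confirm this (not part of the original run) is to enable the standard library's faulthandler before training, so a low-level traceback is printed if the interpreter crashes:

import faulthandler
faulthandler.enable()   # dump a native-level traceback if the process dies on a fatal signal

# Alternatively, run the unmodified script with:
#   python -X faulthandler <script.py>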

[screenshot of the console output from the interrupted run]

Environment

  • DGL Version: 2.2.1
  • Backend Library & Version: PyTorch 2.3.0
  • OS: Linux
  • How you installed DGL (conda, pip, source): conda
  • Python version: 3.12.3

yfismine · May 23 '25 01:05

This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you

github-actions[bot] · Jun 23 '25 01:06