Optimize train_CSL_graph_classification.py
In both the training and evaluation loops, there are unnecessary calls to `loss.detach().item()`. The loss can be accumulated without an explicit `detach()`: `.item()` already returns a plain Python number with no autograd graph attached, so `loss.item()` is sufficient, and a detached tensor can still be produced later if it is ever needed.
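A minimal before/after sketch (the tensor here is just a stand-in for a real loss):

```python
import torch

loss = torch.tensor(0.25, requires_grad=True) * 2  # stand-in for a real loss
epoch_loss = 0.0

# Before: the detach() is redundant
epoch_loss += loss.detach().item()

# After: .item() already returns a plain Python float with no autograd
# graph attached, so the explicit detach() buys nothing
epoch_loss += loss.item()
```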
In the `train_epoch_dense` function, you can remove the manual batch handling using `(iter % batch_size)` and instead rely on the `batch_size` parameter of the `data_loader`. The DataLoader automatically handles the batch iteration for you.
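A sketch of what that could look like; `dataset.train` and `dataset.collate` follow the naming used elsewhere in this repo but are assumptions here:

```python
from torch.utils.data import DataLoader

def make_train_loader(dataset, batch_size):
    # Let the DataLoader group samples into batches directly instead of
    # emulating batches inside the training loop with (iter % batch_size).
    return DataLoader(
        dataset.train,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=dataset.collate,  # assumed custom collate for graph batches
    )
```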
In both the `train_epoch_sparse` and `train_epoch_dense` functions, the initial `optimizer.zero_grad()` call can be hoisted out of the loop, just before it starts, so the gradients are cleared once up front rather than with an extra call at the top of every iteration. Note, however, that `zero_grad()` must still run immediately after every `optimizer.step()`; otherwise one batch's gradients accumulate into the next and training is silently wrong.
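A minimal sketch of the adjusted loop, assuming the usual `(batch_graphs, batch_labels)` loader output and the `model.loss(...)` convention; the forward call is simplified relative to the actual script:

```python
def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0.0
    optimizer.zero_grad()          # hoisted: clear gradients once up front
    for batch_graphs, batch_labels in data_loader:
        batch_labels = batch_labels.to(device)
        scores = model(batch_graphs)            # forward signature simplified
        loss = model.loss(scores, batch_labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()      # still required after every step so one
                                   # batch's gradients don't leak into the next
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)
```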