benchmarking-gnns

Optimize train_CSL_graph_classification.py

Open · mzamini92 opened this issue 1 year ago · 0 comments

In both the training and evaluation loops, the running loss is accumulated with loss.detach().item(). The .detach() call is redundant there: .item() already returns a plain Python number that keeps no reference to the autograd graph, so loss.item() is sufficient.

In train_epoch_dense, the manual batch handling via (iter % batch_size) could be removed by relying on the batch_size parameter of the DataLoader instead. The DataLoader then yields complete batches, so the loop no longer has to count iterations to decide when to update.

In both train_epoch_sparse and train_epoch_dense, the placement of optimizer.zero_grad() is worth reviewing, but the call cannot simply be moved before the loop to save repeated calls. PyTorch adds each backward() result into the existing gradients, so zero_grad() has to run once per optimizer.step(); otherwise every update would include the gradients of all earlier batches.
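A minimal sketch of these points in plain PyTorch, assuming a generic classifier, loss function and DataLoader. The names train_epoch, batch_grad and accumulated_grad are hypothetical, not the repository's train_epoch_sparse/train_epoch_dense, and the random tensors only stand in for the CSL data. The loop accumulates the running loss with loss.item(), lets the DataLoader form the batches, and keeps zero_grad() next to each optimizer.step(). The two gradient helpers check that accumulating per-sample gradients, with each loss scaled by 1/N (the kind of accumulation a manual iter % batch_size loop performs), matches one backward pass over the whole batch with a mean-reduced loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_epoch(model, optimizer, loss_fn, data_loader, device="cpu"):
    """Hypothetical epoch loop: the DataLoader already yields mini-batches,
    so no manual `iter % batch_size` bookkeeping is needed."""
    model.train()
    epoch_loss, n_batches = 0.0, 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()       # once per step: calling it only before the
                                    # loop would sum gradients across all batches
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()   # .item() returns a plain float; no .detach() needed
        n_batches += 1
    return epoch_loss / max(n_batches, 1)


def batch_grad(model, loss_fn, x, y):
    """Gradient of the mean loss over the whole batch in one backward pass."""
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return [p.grad.clone() for p in model.parameters()]


def accumulated_grad(model, loss_fn, x, y):
    """Gradient accumulated one sample at a time; each loss is scaled by 1/N
    so the summed gradients equal the batch-mean gradient."""
    model.zero_grad()
    n = x.shape[0]
    for i in range(n):
        (loss_fn(model(x[i:i + 1]), y[i:i + 1]) / n).backward()
    return [p.grad.clone() for p in model.parameters()]


# Toy data standing in for the CSL graphs (hypothetical shapes).
torch.manual_seed(0)
x = torch.randn(8, 5)
y = torch.randint(0, 3, (8,))
model = nn.Linear(5, 3)
loss_fn = nn.CrossEntropyLoss()  # mean reduction by default

grads_match = all(
    torch.allclose(a, b, atol=1e-6)
    for a, b in zip(batch_grad(model, loss_fn, x, y),
                    accumulated_grad(model, loss_fn, x, y))
)

loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
avg_loss = train_epoch(model, optimizer, loss_fn, loader)
print(grads_match, round(avg_loss, 4))
```

Whether train_epoch_dense can switch to DataLoader batching may depend on whether its collate function can stack the dense tensors of several graphs into one batch; the sketch does not check that.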

mzamini92, Jun 22 '23 04:06