DATM
CIFAR100 significantly increases memory usage
When running DATM.py on CIFAR10 and CIFAR100, memory usage barely grows on CIFAR10, but on CIFAR100 it grows significantly every time the following code is executed. What could be the reason for this?
for step in range(args.syn_steps):
    ...
    grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0]
    ...
Maybe because the surrogate model used for CIFAR100 has 10x more parameters in its fully connected layer (100 output classes instead of 10), and since create_graph=True keeps the computation graph of every inner step alive, the memory retained per step scales with that parameter count 🤔.
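If it helps, here is a minimal, self-contained sketch of that kind of inner loop (not the repo's actual code; feat_dim, the stand-in features/labels, and the single-FC-layer student are assumptions made for illustration). It shows why the graph retained at each step with create_graph=True grows with the FC parameter count:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: the student's final FC layer maps a feat_dim feature to
# num_classes outputs, so CIFAR100 (100 classes) has 10x more FC parameters
# than CIFAR10 (10 classes).
feat_dim, num_classes, syn_steps, lr = 2048, 100, 20, 0.01

fc = nn.Linear(feat_dim, num_classes)
# Flattened parameter vector, updated functionally like in trajectory matching.
student_params = [torch.cat([p.reshape(-1) for p in fc.parameters()])]

features = torch.randn(64, feat_dim)            # stand-in synthetic features
labels = torch.randint(0, num_classes, (64,))   # stand-in labels

for step in range(syn_steps):
    w = student_params[-1]
    weight = w[: feat_dim * num_classes].reshape(num_classes, feat_dim)
    bias = w[feat_dim * num_classes:]
    ce_loss = nn.functional.cross_entropy(features @ weight.t() + bias, labels)

    # create_graph=True keeps this step's graph alive so an outer loss can
    # later backprop through the whole unrolled trajectory; the memory kept
    # per step is proportional to the parameter count, so 10x more FC
    # parameters means roughly 10x more memory accumulated here.
    grad = torch.autograd.grad(ce_loss, student_params[-1], create_graph=True)[0]
    student_params.append(student_params[-1] - lr * grad)
```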