Lottery-Ticket-Hypothesis-in-Pytorch

Freeze pruned weights method not efficient

Open guoyuntu opened this issue 4 years ago • 1 comments

In 'main.py', lines 257-262, the author uses the following code to freeze the pruned weights:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data.cpu().numpy()
            grad_tensor = p.grad.data.cpu().numpy()
            grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
            p.grad.data = torch.from_numpy(grad_tensor).to(device)

This copies each weight tensor and its gradient from the GPU to the CPU and back, which puts a heavy burden on CPU-GPU I/O. I recommend performing the freezing operation directly on the GPU; the following code does that:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data
            grad_tensor = p.grad
            grad_tensor = torch.where(tensor.abs() < EPS, torch.zeros_like(grad_tensor), grad_tensor)
            p.grad.data = grad_tensor
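
For anyone who wants to try this outside 'main.py', here is a minimal self-contained sketch of the same idea. It is not the repository's code: the helper name `freeze_pruned_grads`, the value of `EPS`, and the tiny `nn.Linear` demo model are assumptions made for illustration. It also writes into `p.grad` in place with `masked_fill_` rather than building a new tensor with `torch.where`, which is a small variation on the snippet above.

```python
import torch
import torch.nn as nn

EPS = 1e-6  # assumed threshold; the value used in main.py is not stated in this issue


def freeze_pruned_grads(model: nn.Module, eps: float = EPS) -> None:
    """Zero the gradients of near-zero (pruned) weights on the parameter's own device."""
    for name, p in model.named_parameters():
        if "weight" in name and p.grad is not None:
            # Boolean mask of pruned weights; it lives on the same device as p.
            pruned = p.detach().abs() < eps
            # In-place write, so no host round trip and no extra gradient tensor.
            p.grad.masked_fill_(pruned, 0.0)


# Tiny demo on a hypothetical model.
torch.manual_seed(0)
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight[0].zero_()  # pretend the first output row was pruned

loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
freeze_pruned_grads(model)  # call between backward() and optimizer.step()
```

The mask is recomputed from the current weights on every call, the same as in the snippets above, so it relies on pruned weights staying below the threshold.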

guoyuntu, Mar 11 '21 09:03

With a batch size of 200 on MNIST + LeNet-5, training went from 12 seconds per epoch to 6 with your changes. Thank you!

bainro, Dec 24 '21 07:12