torch-pruning icon indicating copy to clipboard operation
torch-pruning copied to clipboard

Pruning methods for PyTorch with an optimizer-like interface.

torch-pruning

Pruning methods for PyTorch with an optimizer-like interface.

usage

import torch
import pruning

# Placeholders: substitute your own model and data pipeline here.
net = # an arbitrary pytorch nn.Module instance
dataloader = # some pytorch dataloader instance

# Plain SGD over all model parameters; the positional 0.01 is the learning rate.
optimizer = torch.optim.SGD(net.parameters(), 0.01, weight_decay=1e-5)
# Init pruning method in the same way as optimizer: it receives the same
# parameter iterable. The positional 0.1 is presumably the fraction of
# weights removed by each step() call -- TODO confirm against the library.
# NOTE: this assignment rebinds the name `pruning` from the imported module
# to the pruner instance, so the module itself is no longer reachable as
# `pruning` after this line.
pruning = pruning.MagnitudePruning(net.parameters(), 0.1, local=True,
                                   exclude_biases=True)

# Save initial parameters for later: w_0 is the snapshot that is handed to
# pruning.rewind() once the network has been trained and pruned.
# NOTE(review): train() below also reads a global `criterion` (the loss
# function) that this snippet never defines -- define one before running.
w_0 = pruning.clone_params()

def train(net, dataloader, n_epochs=1):
    """Run a standard training loop over ``dataloader`` for ``n_epochs`` epochs.

    The function relies on the module-level ``pruning``, ``criterion`` and
    ``optimizer`` objects rather than taking them as arguments.

    Args:
        net: Callable model; invoked as ``net(inputs)`` on every batch.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        n_epochs: Number of full passes over ``dataloader``.
    """
    for _ in range(n_epochs):
        for inputs, targets in dataloader:
            # Force the *pruned* weights back to zero before each forward
            # pass (presumably because the previous optimizer step may have
            # moved them away from zero).
            pruning.zero_params()
            predictions = net(inputs)
            batch_loss = criterion(predictions, targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
    

# Train the dense (not yet pruned) network for 100 epochs
train(net, dataloader, n_epochs=100)
# Do prune!
pruning.step()
# Rewind parameters to their values at init, i.e. to the w_0 snapshot that
# was cloned before the first training run
pruning.rewind(w_0)
# Train the pruned model; train() calls pruning.zero_params() before every
# forward pass
train(net, dataloader, n_epochs=100)

# To go further, repeat the train -> prune -> rewind cycle in a for loop,
# and then check if you have found a winning ticket
# ...