Torch-Pruning
Pruning with concatenation failing
Not sure if this is intended behavior, but based on the following test it looks like there might be an issue with pruning through torch.cat when the same tensor is concatenated with itself.
Code:
import torch
import torch.nn as nn
import torch_pruning as tp

class TestModule(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
            nn.GELU(),
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_dim * 2, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )

    def forward(self, x):
        x = self.block1(x)
        x = torch.cat([x, x], dim=1)  # the same tensor concatenated twice along channels
        x = self.block2(x)
        return x

model = TestModule(512)
# dummy_input must exist before the pruner traces the model
dummy_input = torch.randn(1, 512, 7, 7)
# ignored_layers is not shown in the report; assumed empty here
ignored_layers = []

pruner = tp.pruner.MagnitudePruner(
    model,
    dummy_input,
    importance=tp.importance.MagnitudeImportance(p=2),
    iterative_steps=6,
    ch_sparsity=0.75,
    ignored_layers=ignored_layers,
)

for step in range(6):
    pruner.step()
    model = model.eval()
    model(dummy_input)  # fails after the first step
Expected Behavior: After the first pruning step, block2 should have 896 input channels and 448 output channels.
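The expected numbers can be reproduced with a little arithmetic. This sketch assumes a linear schedule in which the target ratio at step k is ch_sparsity * k / iterative_steps, and that block1's output and block2's output are pruned at the same ratio; `expected_channels` is a hypothetical helper, not part of Torch-Pruning.

```python
# Assumption: linear schedule, target ratio at step k = ch_sparsity * k / iterative_steps
in_dim = 512
ch_sparsity = 0.75
iterative_steps = 6

def expected_channels(step):
    """Return (channels kept in block1/block2 outputs, block2 input channels)."""
    ratio = ch_sparsity * step / iterative_steps
    kept = int(in_dim * (1 - ratio))
    # block2 sees two concatenated copies of block1's output
    return kept, 2 * kept

kept, block2_in = expected_channels(1)
print(kept, block2_in)  # 448 896
```

At step 1 the ratio is 0.125, so 64 of 512 channels are removed, leaving 448, and block2's input drops from 1024 to 896.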
Actual Behavior:
RuntimeError: Given groups=1, weight of size [448, 960, 1, 1], expected input[1, 896, 7, 7] to have 960 channels, but got 896 channels instead
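Reading the shapes in the error, 960 = 1024 - 64, which suggests only one copy of the 64 removed channels was pruned from block2's first conv input, while 896 = 1024 - 128 would require removing them from both halves of the concatenation. This interpretation is inferred from the numbers alone. The sketch below shows the index mapping that a `torch.cat([x, x], dim=1)` consumer would need; `consumer_indices` and the choice of removed channels are hypothetical.

```python
def consumer_indices(pruned_idxs, in_dim, num_copies=2):
    """Map channel indices removed from block1's output to the indices that
    must be removed from block2's input when the same tensor is concatenated
    num_copies times along dim=1 (copy k occupies [k*in_dim, (k+1)*in_dim))."""
    return sorted(i + k * in_dim for k in range(num_copies) for i in pruned_idxs)

in_dim = 512
pruned = list(range(64))  # hypothetical: any 64 channels removed in step 1

both_halves = consumer_indices(pruned, in_dim)             # correct mapping
one_half = consumer_indices(pruned, in_dim, num_copies=1)  # only the first copy

print(2 * in_dim - len(both_halves))  # 896: matches the concatenated input
print(2 * in_dim - len(one_half))     # 960: the weight shape in the error
```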