
Pruning with concatenation failing

WAvery4 opened this issue 1 year ago · 10 comments

I'm not sure whether this is intended behavior, but the following test suggests there's an issue with pruning through concatenation.

Code:

import torch
import torch.nn as nn
import torch_pruning as tp

class TestModule(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
            nn.GELU(),
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_dim * 2, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )

    def forward(self, x):
        x = self.block1(x)
        x = torch.cat([x, x], dim=1)  # same tensor concatenated with itself
        x = self.block2(x)
        return x

model = TestModule(512)
dummy_input = torch.randn(1, 512, 7, 7)
ignored_layers = []  # nothing excluded in this repro

pruner = tp.pruner.MagnitudePruner(
    model,
    dummy_input,
    importance=tp.importance.MagnitudeImportance(p=2),
    iterative_steps=6,
    ch_sparsity=0.75,
    ignored_layers=ignored_layers
)

for step in range(6):
    pruner.step()

    model = model.eval()
    model(dummy_input)

Expected Behavior: After the first pruning step, block2 should have 896 input channels and 448 output channels.
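Where those numbers come from (my arithmetic, assuming each step prunes ch_sparsity / iterative_steps of the channels):

```python
in_dim = 512
ch_sparsity, iterative_steps = 0.75, 6
per_step = ch_sparsity / iterative_steps       # 0.125 pruned in step 1
block1_out = int(in_dim * (1 - per_step))      # block1 keeps 448 channels
cat_channels = 2 * block1_out                  # cat([x, x], dim=1) -> 896
block2_out = int(in_dim * (1 - per_step))      # block2 keeps 448 outputs
print(block1_out, cat_channels, block2_out)    # 448 896 448
```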

Actual Behavior:

RuntimeError: Given groups=1, weight of size [448, 960, 1, 1], expected input[1, 896, 7, 7] to have 960 channels, but got 896 channels instead
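If I'm reading the error right, the weight shape hints at what went wrong: block2's input was pruned from 1024 to 960, i.e. only 64 channels were removed, even though both branches of the cat come from the same pruned tensor, so 64 + 64 = 128 channels should have been removed, leaving 896. It looks like the pruned indices are propagated through only one branch when the same tensor appears twice in torch.cat. A sketch of the mismatch (my arithmetic, not a claim about the library internals):

```python
cat_in = 1024                               # block2 input channels before pruning
pruned_per_branch = 512 - 448               # 64 channels removed from block1's output
actual = cat_in - pruned_per_branch         # 960: pruning applied to one branch only
expected = cat_in - 2 * pruned_per_branch   # 896: pruning applied to both branches
print(actual, expected)                     # 960 896
```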

WAvery4 · Apr 10 '23