Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

Regularization test fails when densenet121 is replaced with a UNet in tests/test_regularization.py

Open wangty537 opened this issue 1 year ago • 1 comments

This is the code. It runs successfully on DenseNet but raises an error on UNet. How can I solve this? Looking forward to your reply, thank you.

import sys, os



import torch
from torchvision.models import densenet121 as entry
import torch_pruning as tp
from torch import nn
import torch.nn.functional as F

import torch
import torch.nn as nn


class UNetSeeInDark(nn.Module):
    """U-Net style encoder/decoder built from plain 3x3 convolutions.

    The encoder has five stages (32, 64, 128, 256, 512 channels), each made
    of two 3x3 convolutions followed by a leaky ReLU; the first four stages
    are followed by a 2x2 max-pool.  The decoder mirrors the encoder with
    four stride-2 transposed convolutions, concatenating each up-sampled
    tensor with the matching encoder feature map along the channel axis
    before two more 3x3 convolutions.  A final 1x1 convolution maps the 32
    decoder channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.

    Note:
        Four 2x down-samplings are undone by four 2x up-samplings and the
        results are concatenated with encoder features, so the input height
        and width should be multiples of 16; otherwise ``torch.cat`` in
        ``forward`` fails on mismatched spatial sizes.
    """

    def __init__(self, in_channels=4, out_channels=4):
        super().__init__()

        # ---- Encoder -------------------------------------------------
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # ---- Bottleneck ----------------------------------------------
        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # ---- Decoder -------------------------------------------------
        # Each convN_1 takes twice the up-sampled channel count because the
        # up-sampled tensor is concatenated with the encoder skip feature.
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        # 1x1 projection to the requested number of output channels.
        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the encoder/decoder and return the 1x1-projected feature map.

        Args:
            x: Input tensor with ``in_channels`` channels on dim 1.

        Returns:
            The output of ``conv10_1``; no activation or pixel-shuffle is
            applied to it.
        """
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        # Each encoder stage uses its own pooling module.  The pools are
        # identical and parameter-free, so the output matches what a single
        # shared pool would produce, but every registered module is now
        # actually invoked exactly once per forward pass.
        # NOTE(review): module-level tracing tools may mishandle a module
        # that is called several times per forward — confirm for your tool.
        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        # Decoder: up-sample, concatenate the skip feature on the channel
        # axis, then refine with two convolutions.
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        return self.conv10_1(conv9)

    def _initialize_weights(self):
        """Re-draw weights from N(0, 0.02).

        ``Conv2d`` weights and (when present) biases are re-drawn;
        ``ConvTranspose2d`` layers only have their weights re-drawn.  This
        method is not called from ``__init__``; callers must invoke it
        explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)

    def lrelu(self, x):
        """Leaky ReLU with slope 0.2, computed as ``max(0.2 * x, x)``."""
        return torch.max(0.2 * x, x)


def reg_pruner():
    """Exercise several Torch-Pruning regularizing pruners on a UNet.

    Loads a ``UNetSeeInDark(3, 3)`` checkpoint onto the CPU, then for each
    (importance, pruner) pair builds a global pruner targeting a 0.5 pruning
    ratio over five iterative steps.  In every step it back-propagates a
    dummy loss, snapshots the gradients, calls ``pruner.regularize``, prints
    the absolute gradient change per parameter, and finally calls
    ``pruner.step()``.  The same model instance is reused for all pairs.
    """
    cpu = torch.device('cpu')
    checkpoint_path = r'D:\checkpoint_8000.pth'

    net = UNetSeeInDark(3, 3)
    net.load_state_dict(torch.load(checkpoint_path, map_location=cpu))
    net = net.to(cpu)
    # The DenseNet variant of this script used `entry(pretrained=True)` here.
    print(net)

    # Dummy batch used both for pruner construction and for the backward pass.
    dummy_batch = torch.randn(1, 3, 224, 224).to(cpu)

    candidates = (
        (tp.importance.GroupNormImportance, tp.pruner.GroupNormPruner),
        (tp.importance.BNScaleImportance, tp.pruner.BNScalePruner),
        (tp.importance.GroupNormImportance, tp.pruner.GrowingRegPruner),
    )
    num_steps = 5

    for importance_cls, pruner_cls in candidates:
        print('\n#####################################\n\n\n\n')
        print(importance_cls, pruner_cls)
        criterion = importance_cls()

        # Keep every Conv2d that emits 3 channels (the output head) unpruned.
        protected = [
            layer
            for layer in net.modules()
            if isinstance(layer, torch.nn.Conv2d) and layer.out_channels == 3
        ]

        pruner = pruner_cls(
            net,
            dummy_batch,
            importance=criterion,
            global_pruning=True,
            iterative_steps=num_steps,
            pruning_ratio=0.5,  # target: remove half of the channels
            ignored_layers=protected,
        )

        for _ in range(num_steps):
            net(dummy_batch).sum().backward()

            # Snapshot gradients before regularization so the delta can be shown.
            grads_before = {
                param: (param.grad.clone() if param.grad is not None else None)
                for param in net.parameters()
            }

            pruner.regularize(net)

            for param_name, param in net.named_parameters():
                if param.grad is None or grads_before[param] is None:
                    print(param_name, "has no grad")
                    continue
                print(param_name, (grads_before[param] - param.grad).abs().sum())

            pruner.step()


# Script entry point: run the pruning/regularization check only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    reg_pruner()

wangty537 avatar Oct 27 '23 02:10 wangty537

Thanks for the issue. Will check it ASAP.

VainF avatar Oct 28 '23 10:10 VainF