Does it support diffusion models now?
Thanks for your great job!
I want to apply it to the VAE model, but I ran into some problems.
Here is my code:
import torch
import torch_pruning as tp
from torch import nn
from PIL import Image
from diffusers import AutoencoderKL
from tqdm.auto import tqdm
import numpy as np
# Run on GPU 6 when CUDA is available, otherwise fall back to the CPU.
device = "cuda:6" if torch.cuda.is_available() else "cpu"

# The VAE weights live in the "vae" subfolder of the pipeline repository, so
# the subfolder is passed separately instead of being appended to the repo id
# (a three-segment "org/name/vae" string is not a valid Hub repo id).
vae_model_id = "stabilityai/stable-diffusion-2-inpainting"
vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder="vae")
vae = vae.to(torch.float16).to(device)
model = vae

# Random half-precision image batch used as the example input.
example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float16).to(device)

imp = tp.importance.MagnitudeImportance()

# Only `model.encoder` is handed to the pruner, while the rest of the VAE is
# left untouched. The encoder's final convolution must therefore keep its
# output channels, otherwise the un-pruned layers that consume the encoder
# output see a channel mismatch. This replaces the `out_features == 1000`
# classifier filter from the ImageNet example, which presumably matches no
# layer in this model.
ignored_layers = [model.encoder.conv_out]
def forward_fn(model, inputs):
    """Call ``model`` on ``inputs`` and return the first element of the result.

    Args:
        model: Any callable taking a single positional argument.
        inputs: The value forwarded to ``model``.

    Returns:
        The item at index 0 of whatever ``model(inputs)`` returns.
    """
    outputs = model(inputs)
    return outputs[0]
# Progressive pruning: the target sparsity is reached over this many steps.
iterative_steps = 3
pruner = tp.pruner.MagnitudePruner(
    model.encoder,  # only the encoder sub-module is pruned
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # target: remove 50% of the prunable channels
    ignored_layers=ignored_layers,
    forward_fn=forward_fn,
)

# Baseline cost of the full model before any channel is removed.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

# Each `pruner.step()` applies one slice of the schedule; all
# `iterative_steps` steps are needed to reach the 0.5 target. Cost is
# reported after every step so the progression is visible.
for _ in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
`
AssertionError Traceback (most recent call last) Cell In[5], line 1 ----> 1 macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 2 print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch_pruning/utils/op_counter.py:35, in count_ops_and_params(model, example_inputs, layer_wise) 33 _ = flops_model(**example_inputs) 34 else: ---> 35 _ = flops_model(example_inputs) 36 flops_count, params_count, _layer_flops, _layer_params = flops_model.compute_average_flops_cost() 37 layer_flops = {}
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1844, in Module._call_impl(self, *args, **kwargs) 1841 return inner() 1843 try: -> 1844 return inner() 1845 except Exception: 1846 # run always called hooks if they have not already been run 1847 # For now only forward hooks have the always_call option but perhaps 1848 # this functionality should be added to full backward hooks as well. 1849 for hook_id, hook in _global_forward_hooks.items():
File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1790, in Module._call_impl.
AssertionError: Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...`
It seems like the error was raised due to a change in the channel shape.