Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

Dose it support diffusion model now?

Open rex-29 opened this issue 1 year ago • 0 comments

Thanks for your great job! I want to apply it in the VAE model but it got some problems. Here is my code: import torch import torch_pruning as tp from torch import nn from PIL import Image from diffusers import AutoencoderKL from tqdm.auto import tqdm import numpy as np

device = "cuda:6" if torch.cuda.is_available() else "cpu"

vae_model_id = "stabilityai/stable-diffusion-2-inpainting/vae" vae = AutoencoderKL.from_pretrained(vae_model_id).to(torch.float16).to(device) model = vae example_inputs = torch.randn(1, 3, 224, 224, dtype=torch.float16).to(device) imp = tp.importance.MagnitudeImportance() ignored_layers = [] for m in model.modules(): if isinstance(m, torch.nn.Linear) and m.out_features == 1000: ignored_layers.append(m) # DO NOT prune the final classifier!

def forward_fn(model, inputs): return model(inputs)[0]

iterative_steps = 3 # progressive pruning pruner = tp.pruner.MagnitudePruner( model.encoder, example_inputs, importance=imp, iterative_steps=iterative_steps, ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} ignored_layers=ignored_layers, forward_fn=forward_fn, )

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) pruner.step() macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")`

`

AssertionError Traceback (most recent call last) Cell In[5], line 1 ----> 1 macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 2 print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator..decorate_context(*args, **kwargs) 113 @functools.wraps(func) 114 def decorate_context(*args, **kwargs): 115 with ctx_factory(): --> 116 return func(*args, **kwargs)

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch_pruning/utils/op_counter.py:35, in count_ops_and_params(model, example_inputs, layer_wise) 33 _ = flops_model(**example_inputs) 34 else: ---> 35 _ = flops_model(example_inputs) 36 flops_count, params_count, _layer_flops, _layer_params = flops_model.compute_average_flops_cost() 37 layer_flops = {}

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1844, in Module._call_impl(self, *args, **kwargs) 1841 return inner() 1843 try: -> 1844 return inner() 1845 except Exception: 1846 # run always called hooks if they have not already been run 1847 # For now only forward hooks have the always_call option but perhaps 1848 # this functionality should be added to full backward hooks as well. 1849 for hook_id, hook in _global_forward_hooks.items():

File ~/.conda/envs/bydeng/lib/python3.11/site-packages/torch/nn/modules/module.py:1790, in Module._call_impl..inner() 1787 bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) 1788 args = bw_hook.setup_input_hook(args) -> 1790 result = forward_call(*args, **kwargs) 1791 if _global_forward_hooks or self._forward_hooks: 1792 for hook_id, hook in ( 1793 *_global_forward_hooks.items(), 1794 *self._forward_hooks.items(), 1795 ): 1796 # mark that always called hook is run ... --> 136 assert hidden_states.shape[1] == self.channels 138 if self.norm is not None: 139 hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

AssertionError: Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...`

It seems like the error was raised due to a change in the channel shape.

rex-29 avatar Dec 24 '24 01:12 rex-29