Is there a way to automatically ignore layers with unsupported ops?
Hi, I am trying to prune a slight variation of a ResNet model used to extract discriminative features:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class ResNetArc_Classifier(nn.Module):
    def __init__(self, model_type: str = "resnet50", class_count: int = 1000):
        super(ResNetArc_Classifier, self).__init__()
        resnet_model_fn = getattr(torchvision.models, model_type)
        self.inp_batch_norm = nn.BatchNorm2d(1)
        self.feature_network = resnet_model_fn()
        # single-channel input instead of RGB
        self.feature_network.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.feature_network.fc = nn.Identity()
        self.bn_final = nn.BatchNorm1d(2048)
        self.mapper = nn.Linear(2048, class_count, bias=False)
        self.s = 8
        self.m = 0
        self.update_margin(self.m, self.s)

    def update_margin(self, m=0, s=8):
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

    def forward(self, x, y=None):
        assert self.s != 0
        x_norm = self.inp_batch_norm(x)
        feat = self.feature_network(x_norm)
        feat_norm = self.bn_final(feat)
        feat_l2 = F.normalize(feat_norm, dim=1)
        cosine_dist = torch.matmul(
            feat_l2,
            F.normalize(self.mapper.weight.T, dim=0)
        )
        sine_dist = torch.sqrt(1 - torch.square(cosine_dist))  # .clamp(-1, 1)
        out_dist = cosine_dist
        if y is not None:
            margined_cosine_dist = cosine_dist * self.cos_m - sine_dist * self.sin_m  # cos(a+b) identity
            margined_cosine_dist = torch.where(
                cosine_dist > 0,
                margined_cosine_dist,
                cosine_dist
            )  # easy margin
            out_dist = margined_cosine_dist * y + cosine_dist * (1 - y)
        out = self.s * out_dist
        return out
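For context, here is a minimal sketch of how the model is invoked. The single-channel (1, 1, 149, 64) input shape matches what is passed to compress_model below; the class count and one-hot target construction are illustrative only, not taken from my training code:

model = ResNetArc_Classifier(model_type="resnet50", class_count=10)
dummy_input = torch.randn(4, 1, 149, 64)                                       # batch of single-channel inputs
dummy_target = F.one_hot(torch.randint(0, 10, (4,)), num_classes=10).float()   # illustrative one-hot labels

logits_eval = model(dummy_input)                 # inference path: scaled cosine similarities
logits_train = model(dummy_input, dummy_target)  # training path: margin applied to the target class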
When I try to prune the above model using the following snippet
from decimal import Decimal

from aimet_common.defs import CompressionScheme, CostMetric
from aimet_torch.defs import GreedySelectionParameters, ChannelPruningParameters
from aimet_torch.compress import ModelCompressor
from simple_model import EmptyModule

greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(0.9),
                                          num_comp_ratio_candidates=3)
data_loader_pl.setup(23)

modules_to_ignore = [pl_model.pytorch_model.feature_network.fc]
auto_params = ChannelPruningParameters.AutoModeParams(greedy_select_params=greedy_params,
                                                      modules_to_ignore=modules_to_ignore)
params = ChannelPruningParameters(data_loader=data_loader_pl.val_dataloader(),
                                  num_reconstruction_samples=10,
                                  allow_custom_downsample_ops=False,
                                  mode=ChannelPruningParameters.Mode.auto,
                                  params=auto_params)

def eval_callback(model, iterations, use_cuda: bool):
    pl_model = LightningWordClassifier()
    pl_model.pytorch_model = model
    trainer = pl.Trainer(
        precision=16,
        accelerator="gpu",
        devices=1,
        deterministic=True,
    )
    results = trainer.test(model=pl_model, dataloaders=data_loader_pl)
    return results[0]["test_top1"]  # trainer.test returns a list of dicts

eval_iterations = 1
compress_scheme = CompressionScheme.channel_pruning
cost_metric = CostMetric.mac

compressed_model, comp_stats = ModelCompressor.compress_model(model=pl_model.pytorch_model,
                                                              eval_callback=eval_callback,
                                                              eval_iterations=eval_iterations,
                                                              input_shape=(1, 1, 149, 64),
                                                              compress_scheme=compress_scheme,
                                                              cost_metric=cost_metric,
                                                              parameters=params)
I get the following error:
WARNING:param.ParameterizedMetaclass: Use method 'params' via param namespace
2023-02-27 06:48:45,849 - param.ParameterizedMetaclass - WARNING - Use method 'params' via param namespace
2023-02-27 06:48:46,491 - CompRatioSelect - INFO - Analyzing compression ratio: 0.3333333333333333333333333333 =====================>
2023-02-27 06:48:47,048 - Winnow - ERROR - Unsupported op_type norm, dotted norm_199, input_ops: [Split_16]
NotImplementedError: Unsupported op_type norm, dotted norm_199, input_ops: [Split_16]
I suspect the error comes from the F.normalize calls. How can I fix this? Is there a way to automatically ignore such layers or ops?
@quic-hitameht, could you help answer this question? Thanks.
Same issue here. Is there a way to work around it?
I have the same problem
Same issue! Why has this question not been answered even after eight months?
@TheSeriousProgrammer, sorry for the late reply. You can add the modules you want to ignore to the modules_to_ignore list. Currently there is no way to ignore a given op type.
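For illustration, a rough sketch of what that could look like for this model, assuming the layers feeding the F.normalize calls (bn_final and mapper, as named in the model above) are the ones to skip; whether ignoring them actually silences this particular Winnow error has not been verified:

model = pl_model.pytorch_model
modules_to_ignore = [
    model.feature_network.fc,  # already ignored in the original snippet
    model.bn_final,            # BatchNorm1d just before the first F.normalize
    model.mapper,              # Linear whose weight is normalized in forward()
]
auto_params = ChannelPruningParameters.AutoModeParams(
    greedy_select_params=greedy_params,
    modules_to_ignore=modules_to_ignore,
)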
Oh, OK. When such an error occurs, can we also get the parent-module information for the op, or some other trace through which we can manually figure out the offending layer?
Unfortunately we don't have that utility at the moment, but I will create an issue for it. If you would like to contribute it, please let us know. Thanks.
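In the meantime, here is a rough manual sketch for narrowing down where an unsupported op comes from, using only torch utilities. It assumes torch.jit.trace succeeds on the model and reuses the (1, 1, 149, 64) input shape from above; the traced node kinds will not necessarily match AIMET's dotted op names (e.g. norm_199) exactly:

import torch

example = torch.randn(1, 1, 149, 64)
traced = torch.jit.trace(pl_model.pytorch_model.eval(), example)

# Print every traced node whose ATen kind mentions "norm"; printed nodes from a
# trace usually carry a source-location comment pointing back at the Python call.
for node in traced.graph.nodes():
    if "norm" in node.kind().lower():
        print(node)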