nncf
[WIP][Pruning algo] Filter counting fix for linear layers
Changes
The redundant input shape parameter was removed from the count_flops_and_weights_per_node function.
count_filters_num now always takes pruned linear layers into account.
count_flops_and_weights_per_node now correctly computes statistics for models with pruned linear layers.
THIS WILL CHANGE THE PRUNING LEVEL OF MODELS THAT CONTAIN PRUNABLE LINEAR LAYERS AND ARE PRUNED BY FLOPS.
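A minimal sketch of why the counts change when linear layers are pruned. LinearLayer, linear_flops_and_weights and count_filters are hypothetical names for illustration, not NNCF's API. The sketch assumes a multiply-accumulate counts as 2 FLOPs per input vector, with bias adds left out of the FLOP count, and it counts filters over linear layers only. The point it shows is that the pruned input_channels/output_channels values, when present, take precedence over the unpruned in_features/out_features.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass
class LinearLayer:
    """Hypothetical stand-in for a linear node's attributes (not NNCF's API)."""
    in_features: int                        # unpruned input size (weight columns)
    out_features: int                       # unpruned output size (weight rows)
    input_channels: Optional[int] = None    # input size after pruning, if pruned
    output_channels: Optional[int] = None   # output size after pruning, if pruned
    use_bias: bool = True

    def effective_in(self) -> int:
        # Prefer the pruned value; fall back to the original attribute.
        return self.in_features if self.input_channels is None else self.input_channels

    def effective_out(self) -> int:
        return self.out_features if self.output_channels is None else self.output_channels


def linear_flops_and_weights(layer: LinearLayer) -> Tuple[int, int]:
    """FLOPs and parameter count for one input vector through a linear layer.

    Assumption: 1 MAC = 2 FLOPs; bias adds are not counted as FLOPs.
    """
    n_in, n_out = layer.effective_in(), layer.effective_out()
    flops = 2 * n_in * n_out
    weights = n_in * n_out + (n_out if layer.use_bias else 0)
    return flops, weights


def count_filters(layers: Iterable[LinearLayer]) -> int:
    """Remaining filters (weight rows) across the given linear layers."""
    return sum(layer.effective_out() for layer in layers)


if __name__ == "__main__":
    fc1 = LinearLayer(in_features=512, out_features=256)
    print(linear_flops_and_weights(fc1))   # (262144, 131328)
    fc1.output_channels = 128              # prune half of fc1's rows
    # fc2 consumes fc1's output, so its input columns shrink accordingly.
    fc2 = LinearLayer(in_features=256, out_features=10, input_channels=128)
    print(linear_flops_and_weights(fc1))   # (131072, 65664)
    print(linear_flops_and_weights(fc2))   # (2560, 1290)
    print(count_filters([fc1, fc2]))       # 138
```

If only in_features/out_features were read, both layers would report their unpruned cost. That is the kind of discrepancy that shifts the achieved pruning level when pruning by FLOPs.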
Reason for changes
The collect_input_shapes function does not handle nodes that have no input edges, unlike collect_output_shapes, which processes nodes without output edges correctly. It turns out that the output of collect_input_shapes is used only in the count_flops_and_weights_per_node function and can easily be replaced by the in_features and out_features attributes of linear layers.
Besides, count_flops_and_weights_per_node did not take pruned rows/columns of linear layers into account. This was fixed by taking the layer sizes from the input_channels and output_channels params.
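A simplified illustration of the edge case described above. The graph here is a hypothetical plain dict mapping each node name to the shapes on its incoming edges, not NNCF's graph class. The helper names are also hypothetical. The unguarded lookup fails on a node with no input edges, such as a graph input. The guarded lookup returns None instead.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]

# Hypothetical, simplified graph: node name -> shapes on its incoming edges.
# A graph input node has no incoming edges at all.
IN_EDGES: Dict[str, List[Shape]] = {
    "input": [],
    "fc1": [(1, 512)],
}


def first_input_shape_unguarded(in_edges: Dict[str, List[Shape]], name: str) -> Shape:
    # Assumes at least one input edge exists; raises IndexError otherwise.
    return in_edges[name][0]


def first_input_shape_guarded(in_edges: Dict[str, List[Shape]], name: str) -> Optional[Shape]:
    # Nodes without input edges (or unknown nodes) yield None instead of raising.
    edges = in_edges.get(name, [])
    return edges[0] if edges else None


if __name__ == "__main__":
    print(first_input_shape_guarded(IN_EDGES, "fc1"))    # (1, 512)
    print(first_input_shape_guarded(IN_EDGES, "input"))  # None
```

For a linear node, the last dimension of its input shape matches in_features. This is why, per the description above, reading the layer attributes can replace the shape lookup entirely for FLOPs/weights counting.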
Related tickets
86889
Tests
TODO