
Add a default PruningPolicy that filters out any layers not supported by the API

Open · annietllnd opened this issue 1 year ago · 1 comment

Hi TFMOT team! I have created a workaround for an edge case in my project, and I think it could reasonably become the default behavior in the API. I'm opening this feature request as a suggestion; let me know what you think!

System information

Running on TF 2.11. Unfortunately, I currently don't have the bandwidth to implement and contribute this feature myself.

Motivation

This feature request is for a PruningPolicy implementation that only allows pruning of layers supported by the PruneRegistry.

Some background: when calling prune_low_magnitude or similar functions, you can skip certain layers according to a pruning policy. By implementing the abstract class PruningPolicy (its allow_pruning and ensure_model_supports_pruning methods), you can decide per layer whether it should be pruned and check that the model as a whole fulfills certain requirements. One built-in implementation already exists: PruneForLatencyOnXNNPack. A call can look like this:

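```python
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras import layers
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
PruneForLatencyOnXNNPack = tfmot.sparsity.keras.PruneForLatencyOnXNNPack
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, 0)}

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]),
    pruning_policy=PruneForLatencyOnXNNPack(),
    **pruning_params)
```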

Currently, because there is no default pruning policy, the API tries to prune layers that are not compatible. In one of our use cases, we had to implement a policy that checks whether each layer is supported (for example, to avoid trying to prune a TFOpLambda layer). In effect, we are safeguarding the API against layers that it already knows it doesn't support. I'm suggesting adding another implementation to the API that simply calls PruneRegistry.supports. If possible, I'd also make it the default value of the pruning_policy parameter (this part may cause issues for some use cases, so it would be an optional part of this feature request).

I think this would help API users avoid confusing bugs in edge cases. If there's an existing way to do this that we have overlooked, I'm happy to get that feedback. Let me know if the suggestion needs elaboration.

Thank you for your time! Annie

annietllnd · Sep 26 '23 08:09

@cdh4696 Could you take a look at this? Thank you! :)

doyeonkim0 · Oct 04 '23 05:10