Support for BatchNorm

Open MarvinKlemp opened this issue 3 years ago • 2 comments

Once a model contains BatchNorm layers, keras-surgeon fails.

Minimal code to reproduce the error:

#%%

import tensorflow as tf
from kerassurgeon.operations import delete_channels

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, kernel_size=(3,3), padding="same", strides=(1, 1), activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(8, kernel_size=(3,3), padding="same", strides=(1, 1), activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(2, kernel_size=(3,3), padding="same", strides=(1, 1), activation="relu"),
])
model.build((8, 256, 64, 2))

layer_idx = 0
filter_idx = 0
layer = model.layers[0]
model = delete_channels(model, layer, [filter_idx])

Stack Trace

KeyError                                  Traceback (most recent call last)

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    233             try:
--> 234                 output_mask = mask_map[node_output]
    235                 logging.debug('bottomed out at a model input')

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
     24         try:
---> 25             return super().__getitem__(item.ref())
     26         except AttributeError:

KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 2) dtype=float32 (created by layer 'conv2d_5')>>


During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    233             try:
--> 234                 output_mask = mask_map[node_output]
    235                 logging.debug('bottomed out at a model input')

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
     24         try:
---> 25             return super().__getitem__(item.ref())
     26         except AttributeError:

KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 8) dtype=float32 (created by layer 'batch_normalization_1')>>


During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    233             try:
--> 234                 output_mask = mask_map[node_output]
    235                 logging.debug('bottomed out at a model input')

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
     24         try:
---> 25             return super().__getitem__(item.ref())
     26         except AttributeError:

KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 8) dtype=float32 (created by layer 'conv2d_4')>>


During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    233             try:
--> 234                 output_mask = mask_map[node_output]
    235                 logging.debug('bottomed out at a model input')

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
     24         try:
---> 25             return super().__getitem__(item.ref())
     26         except AttributeError:

KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 16) dtype=float32 (created by layer 'batch_normalization')>>


During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)

/tmp/ipykernel_24103/2513954395.py in <module>
     15 layer = model.layers[0]
     16 print(f"Layer {layer_idx}, filter {filter_idx}")
---> 17 model = delete_channels(model, layer, [filter_idx])

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/operations.py in delete_channels(model, layer, channels, node_indices, copy)
    103     surgeon = Surgeon(model, copy)
    104     surgeon.add_job('delete_channels', layer, node_indices=node_indices, channels=channels)
--> 105     return surgeon.operate()

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in operate(self)
    164             layer, node_index, tensor_index = output._keras_history
    165             output_nodes.append(layer.inbound_nodes[node_index])
--> 166         new_outputs, _ = self._rebuild_graph(self.model.inputs, output_nodes)
    167         new_model = Model(self.model.inputs, new_outputs)
    168 

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_graph(self, graph_inputs, output_nodes, graph_input_masks)
    276         # Call the recursive _rebuild_rec method to rebuild the submodel up to
    277         # each output layer
--> 278         outputs, output_masks = zip(*[_rebuild_rec(n) for n in output_nodes])
    279         return utils.single_element(outputs), output_masks
    280 

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
    276         # Call the recursive _rebuild_rec method to rebuild the submodel up to
    277         # each output layer
--> 278         outputs, output_masks = zip(*[_rebuild_rec(n) for n in output_nodes])
    279         return utils.single_element(outputs), output_masks
    280 

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
    243                 # obtain its inputs and input masks
    244                 inputs, input_masks = zip(
--> 245                     *[_rebuild_rec(n) for n in inbound_nodes])
    246 
    247                 if all(i is None for i in inputs):

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
    266                         output = new_layer(utils.single_element(list(inputs)))
    267                 else:
--> 268                     new_layer, output_mask = self._apply_delete_mask(node, input_masks)
    269                     output = new_layer(utils.single_element(list(inputs)))
    270 

~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _apply_delete_mask(self, node, inbound_masks)
    570             # Get slice of mask with all singleton dimensions except
    571             # channels dimension
--> 572             index = [0] * (len(input_shape))
    573             assert len(layer.axis) == 1
    574             index[layer.axis[0]] = slice(None)

TypeError: object of type 'int' has no len()

Any ideas how to fix this? Or how to add support for BatchNorm layers?

MarvinKlemp, Sep 07 '21 11:09

Hi, I also encountered this issue on TF 2.7 and was able to fix it by editing surgeon.py at https://github.com/BenWhetton/keras-surgeon/blob/b0b892988e725b9203afc48e639c49d06155b59b/src/kerassurgeon/surgeon.py#L390 so that the line reads just input_shape = node.input_shapes. It doesn't seem to break anything. If you still have issues, an alternative is to add this line only for the BatchNorm layer in surgeon.py.

The issue was that node.input_shapes is a tuple, e.g. (None, 32, 32, 64), so single_element returns its first element (None) instead of the full shape (following https://github.com/BenWhetton/keras-surgeon/pull/58, which fixes other issues).

mchan133, Dec 21 '21 16:12
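
The explanation in the comment above can be illustrated without TensorFlow. The following is a minimal sketch, not the actual kerassurgeon code: first_element_of is a hypothetical stand-in that mimics only the behavior described in the comment (a shape tuple collapses to its first element), and the shape and axis value are assumptions taken from the stack trace in the original post. With a batch size of 8 the first element is the int 8, which would explain why the trace ends in len() being called on an int.

```python
def first_element_of(x):
    """Hypothetical stand-in for the behavior described in the comment:
    a tuple or list collapses to its first element.
    This is NOT the real kerassurgeon utils.single_element."""
    if isinstance(x, (list, tuple)):
        return x[0]
    return x


# Input shape of the first BatchNormalization layer, as shown in the trace.
input_shapes = (8, 256, 64, 16)

# Assumed channel axis for a 4D channels-last tensor (illustrative only).
channel_axis = 3

# Failing path: the shape tuple collapses to the batch size, an int.
collapsed = first_element_of(input_shapes)
try:
    index = [0] * len(collapsed)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# Suggested change: use the shape tuple directly, so len() sees 4 dimensions.
input_shape = input_shapes
index = [0] * len(input_shape)
index[channel_axis] = slice(None)

print(index)
```

The second half mirrors the three lines around surgeon.py line 572 in the trace: once input_shape is the full tuple, the index list has one entry per dimension and only the channel axis is sliced.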

Hi @mchan133, could you please help with #68 (continued from #6)? It's a memory issue when I run get_model_apoz.

nheelam, Jul 06 '22 01:07