keras-surgeon
Support for BatchNorm
When a model contains BatchNormalization layers, keras-surgeon's `delete_channels` fails with a `TypeError`.
Minimal code to reproduce the error:
#%%
# Minimal reproducer: deleting a channel from a Conv2D layer that feeds a
# BatchNormalization layer makes kerassurgeon's delete_channels fail.
import tensorflow as tf
from kerassurgeon.operations import delete_channels

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, kernel_size=(3, 3), padding="same", strides=(1, 1), activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(8, kernel_size=(3, 3), padding="same", strides=(1, 1), activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(2, kernel_size=(3, 3), padding="same", strides=(1, 1), activation="relu"),
])
# Build with a fixed batch size of 8 (input shape includes the batch dimension).
model.build((8, 256, 64, 2))

layer_idx = 0   # index into model.layers of the layer to prune
filter_idx = 0  # output channel (filter) of that layer to delete
layer = model.layers[layer_idx]

# Expected to remove one filter from the first Conv2D; instead this raises
# TypeError ("object of type 'int' has no len()") inside kerassurgeon.
model = delete_channels(model, layer, [filter_idx])
Stack Trace
KeyError Traceback (most recent call last)
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
233 try:
--> 234 output_mask = mask_map[node_output]
235 logging.debug('bottomed out at a model input')
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
24 try:
---> 25 return super().__getitem__(item.ref())
26 except AttributeError:
KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 2) dtype=float32 (created by layer 'conv2d_5')>>
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
233 try:
--> 234 output_mask = mask_map[node_output]
235 logging.debug('bottomed out at a model input')
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
24 try:
---> 25 return super().__getitem__(item.ref())
26 except AttributeError:
KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 8) dtype=float32 (created by layer 'batch_normalization_1')>>
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
233 try:
--> 234 output_mask = mask_map[node_output]
235 logging.debug('bottomed out at a model input')
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
24 try:
---> 25 return super().__getitem__(item.ref())
26 except AttributeError:
KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 8) dtype=float32 (created by layer 'conv2d_4')>>
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
233 try:
--> 234 output_mask = mask_map[node_output]
235 logging.debug('bottomed out at a model input')
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/_utils/tensor_dict.py in __getitem__(self, item)
24 try:
---> 25 return super().__getitem__(item.ref())
26 except AttributeError:
KeyError: <Reference wrapping <KerasTensor: shape=(8, 256, 64, 16) dtype=float32 (created by layer 'batch_normalization')>>
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
/tmp/ipykernel_24103/2513954395.py in <module>
15 layer = model.layers[0]
16 print(f"Layer {layer_idx}, filter {filter_idx}")
---> 17 model = delete_channels(model, layer, [filter_idx])
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/operations.py in delete_channels(model, layer, channels, node_indices, copy)
103 surgeon = Surgeon(model, copy)
104 surgeon.add_job('delete_channels', layer, node_indices=node_indices, channels=channels)
--> 105 return surgeon.operate()
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in operate(self)
164 layer, node_index, tensor_index = output._keras_history
165 output_nodes.append(layer.inbound_nodes[node_index])
--> 166 new_outputs, _ = self._rebuild_graph(self.model.inputs, output_nodes)
167 new_model = Model(self.model.inputs, new_outputs)
168
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_graph(self, graph_inputs, output_nodes, graph_input_masks)
276 # Call the recursive _rebuild_rec method to rebuild the submodel up to
277 # each output layer
--> 278 outputs, output_masks = zip(*[_rebuild_rec(n) for n in output_nodes])
279 return utils.single_element(outputs), output_masks
280
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
276 # Call the recursive _rebuild_rec method to rebuild the submodel up to
277 # each output layer
--> 278 outputs, output_masks = zip(*[_rebuild_rec(n) for n in output_nodes])
279 return utils.single_element(outputs), output_masks
280
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in <listcomp>(.0)
243 # obtain its inputs and input masks
244 inputs, input_masks = zip(
--> 245 *[_rebuild_rec(n) for n in inbound_nodes])
246
247 if all(i is None for i in inputs):
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _rebuild_rec(node)
266 output = new_layer(utils.single_element(list(inputs)))
267 else:
--> 268 new_layer, output_mask = self._apply_delete_mask(node, input_masks)
269 output = new_layer(utils.single_element(list(inputs)))
270
~/anaconda3/envs/tf38/lib/python3.8/site-packages/kerassurgeon/surgeon.py in _apply_delete_mask(self, node, inbound_masks)
570 # Get slice of mask with all singleton dimensions except
571 # channels dimension
--> 572 index = [0] * (len(input_shape))
573 assert len(layer.axis) == 1
574 index[layer.axis[0]] = slice(None)
TypeError: object of type 'int' has no len()
Any ideas how to fix this, or how to add support for BatchNorm layers?
Hi, I also encountered this issue on TF 2.7 and was able to fix it by changing this line of surgeon.py:
https://github.com/BenWhetton/keras-surgeon/blob/b0b892988e725b9203afc48e639c49d06155b59b/src/kerassurgeon/surgeon.py#L390
to simply `input_shape = node.input_shapes`.
It doesn't seem to break anything. If you still have issues, an alternative is to apply this change only in the BatchNorm-specific code path of surgeon.py.
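The issue is that `node.input_shapes` is already a single shape tuple, e.g. `(None, 32, 32, 64)`. `single_element` therefore returns its first element, the batch dimension. That is `None` in my case, or `8` in the model above, which is why `len(input_shape)` fails with `object of type 'int' has no len()`. (This applies on top of https://github.com/BenWhetton/keras-surgeon/pull/58, which fixes other issues.)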
Hi @mchan133, could you please help with #68 (continued from #6)? It's a memory issue when I run `get_model_apoz`.