
Request for a tutorial on how to convert a channels-first Keras model to a channels-last model

Open · mrtpk123 opened this issue 4 years ago · 3 comments

Hello,

I have a pre-trained Keras model (in HDF5 format) in which every layer operates on channels-first data. I want to convert this model to operate on channels-last data (the default data format).

For clarity, the current model summary looks like this:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input0 (InputLayer)             [(None, 3, 240, 320) 0
__________________________________________________________________________________________________
259_pad (ZeroPadding2D)         (None, 3, 242, 322)  0           input0[0][0]
__________________________________________________________________________________________________
259 (Conv2D)                    (None, 16, 120, 160) 432         259_pad[0][0]
__________________________________________________________________________________________________
260 (BatchNormalization)        (None, 16, 120, 160) 64          259[0][0]
__________________________________________________________________________________________________
261 (Activation)                (None, 16, 120, 160) 0           260[0][0]
__________________________________________________________________________________________________
262_pad (ZeroPadding2D)         (None, 16, 122, 162) 0           261[0][0]
__________________________________________________________________________________________________
262 (DepthwiseConv2D)           (None, 16, 120, 160) 144         262_pad[0][0]

As you can see, the model is in channels-first format. I want to convert each layer to operate on channels-last data, so the desired model summary would look like this:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input0 (InputLayer)             [(None, 240, 320, 3) 0
__________________________________________________________________________________________________
259_pad (ZeroPadding2D)         (None, 242, 322, 3)  0           input0[0][0]
__________________________________________________________________________________________________
259 (Conv2D)                    (None, 120, 160, 16) 432         259_pad[0][0]
__________________________________________________________________________________________________
260 (BatchNormalization)        (None, 120, 160, 16) 64          259[0][0]
__________________________________________________________________________________________________
261 (Activation)                (None, 120, 160, 16) 0           260[0][0]
__________________________________________________________________________________________________
262_pad (ZeroPadding2D)         (None, 122, 162, 16) 0           261[0][0]
__________________________________________________________________________________________________
262 (DepthwiseConv2D)           (None, 120, 160, 16) 144         262_pad[0][0]

I raised an issue in the TensorFlow repository, and from a bit of searching I gathered that this situation is not uncommon.

From this comment, I understood that we have to perform "network surgery" to accomplish this:

The above solution works when the training files are available. In the case where I had only the output of tf.estimator.export_savedmodel, I had to do network surgery: I made a clone of the graph in channels_last format, loaded it, and assigned all variables the values from the trained model. This new graph runs fine on CPU.
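For what it's worth, here is a minimal sketch of how I imagine this "network surgery" could look in tf.keras. The file name model_cf.h5 is just a placeholder, and it assumes a fully convolutional model like the one above (no Dense layers after a Flatten, whose weights would need to be permuted):

```python
import tensorflow as tf
from tensorflow import keras

# Load the channels-first model (placeholder file name).
model_cf = keras.models.load_model("model_cf.h5")

# Rewrite the functional config so every layer uses channels_last.
config = model_cf.get_config()
for layer_conf in config["layers"]:
    layer_cfg = layer_conf["config"]
    # Flip the data_format flag on Conv2D, DepthwiseConv2D, ZeroPadding2D, ...
    if layer_cfg.get("data_format") == "channels_first":
        layer_cfg["data_format"] = "channels_last"
    # Move the channel dimension of the input to the end:
    # (None, 3, 240, 320) -> (None, 240, 320, 3).
    if "batch_input_shape" in layer_cfg:
        b, c, h, w = layer_cfg["batch_input_shape"]
        layer_cfg["batch_input_shape"] = (b, h, w, c)
    # BatchNormalization stores its channel axis explicitly.
    if layer_conf["class_name"] == "BatchNormalization":
        layer_cfg["axis"] = [-1]

# Rebuild the graph in channels_last and copy the weights layer by layer.
# Conv2D/DepthwiseConv2D kernels are stored as (kh, kw, in, out) regardless
# of data_format, and BatchNorm parameters are per-channel vectors, so for a
# fully convolutional network they can be copied as-is.
model_cl = keras.Model.from_config(config)
for old_layer, new_layer in zip(model_cf.layers, model_cl.layers):
    new_layer.set_weights(old_layer.get_weights())

model_cl.save("model_cl.h5")
```

However, I am not sure this covers all cases (for example Dense layers after a Flatten, or custom layers), so a proper tutorial would still be very helpful.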

It would be great if you could provide official sample code showing how to do this properly. Any help or pointers would be really appreciated.
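In case it helps frame the request, the kind of sanity check I have in mind after the conversion is something like this (assuming model_cf and model_cl from the sketch above, and a 4D feature-map output; channels_first convolutions may not run on CPU, so this check might need a GPU):

```python
import numpy as np

# The two models should agree (up to float tolerance) on the same image
# presented in the two layouts.
x_cf = np.random.rand(1, 3, 240, 320).astype("float32")   # NCHW
x_cl = np.transpose(x_cf, (0, 2, 3, 1))                   # NHWC

y_cf = model_cf.predict(x_cf)
y_cl = model_cl.predict(x_cl)

# Assumes a 4D output; drop the transpose for vector outputs.
np.testing.assert_allclose(np.transpose(y_cf, (0, 2, 3, 1)), y_cl,
                           atol=1e-4, rtol=1e-4)
```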

Thank you.

mrtpk123 · May 14 '21 18:05