pytorch2keras

Problem with converting BatchNorm layer

Open Golbstein opened this issue 6 years ago • 3 comments

I've added a BatchNorm layer to your dummy example and it crashed.

The code:

from keras import backend as K
from pytorch2keras.converter import pytorch_to_keras
import torch 
import torch.nn as nn
import numpy as np
from torch.autograd import Variable

class TestConv2d(nn.Module):
    """Minimal Conv2d -> BatchNorm2d network used to reproduce the conversion error.

    Args:
        inp: number of input channels of the convolution.
        out: number of output channels of the convolution; also the number
            of features normalized by the BatchNorm layer.
        kernel_size: kernel size of the convolution (passed straight to
            ``nn.Conv2d``).
    """

    def __init__(self, inp=10, out=16, kernel_size=3):
        super().__init__()
        # The attribute names are deliberate: they become the state_dict
        # prefixes ('conv2d.*', 'bn_1.*') that show up in the converter log.
        self.conv2d = nn.Conv2d(
            inp,
            out,
            kernel_size=kernel_size,
            stride=1,
            bias=True,
        )
        self.bn_1 = nn.BatchNorm2d(num_features=out)

    def forward(self, x):
        """Convolve ``x`` and batch-normalize the result."""
        return self.bn_1(self.conv2d(x))

model = TestConv2d()
# The converted Keras model is meant for inference, so switch BatchNorm to
# eval mode: it then normalizes with its running statistics instead of the
# statistics of the (single-sample) example batch.
model.eval()

input_np = np.random.uniform(0, 1, (1, 10, 32, 32))
# torch.autograd.Variable has been a deprecated no-op wrapper since
# PyTorch 0.4; a plain float32 tensor is accepted wherever it was used.
input_var = torch.tensor(input_np, dtype=torch.float32)

# We must specify the shape of the input tensor, without the batch
# dimension (input_np is (1, 10, 32, 32); the shape passed is (10, 32, 32)).
# NOTE(review): a later comment on this issue reports that the BatchNorm
# "Shape must be rank 1 but is rank 0" error goes away after upgrading to
# tensorflow 1.13.1.
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32)], verbose=True)

The error:

graph(%0 : Float(1, 10, 32, 32)
      %1 : Float(16, 10, 3, 3)
      %2 : Float(16)
      %3 : Float(16)
      %4 : Float(16)
      %5 : Float(16)
      %6 : Float(16)
      %7 : Long()) {
  %8 : Float(1, 16, 30, 30) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%0, %1, %2), scope: TestConv2d/Conv2d[conv2d]
  %9 : Float(1, 16, 30, 30) = onnx::BatchNormalization[epsilon=1e-05, is_test=1, momentum=1](%8, %3, %4, %5, %6), scope: TestConv2d/BatchNorm2d[bn_1]
  return (%9);

Graph inputs: ['0', '1', '2', '3', '4', '5', '6', '7']
Graph outputs: ['9']
State dict: ['conv2d.weight', 'conv2d.bias', 'bn_1.weight', 'bn_1.bias', 'bn_1.running_mean', 'bn_1.running_var', 'bn_1.num_batches_tracked']
graph node: TestConv2d/Conv2d[conv2d]
node id: 8
type: onnx::Conv
inputs: ['0', '1', '2']
outputs: ['TestConv2d/Conv2d[conv2d]']
name in state_dict: conv2d
attrs: {'dilations': [1, 1], 'group': 1, 'kernel_shape': [3, 3], 'pads': [0, 0, 0, 0], 'strides': [1, 1]}
is_terminal: False
Converting convolution ...
graph node: TestConv2d/BatchNorm2d[bn_1]
node id: 9
type: onnx::BatchNormalization
inputs: ['8', '3', '4', '5', '6']
outputs: ['TestConv2d/BatchNorm2d[bn_1]']
name in state_dict: bn_1
attrs: {'epsilon': 1e-05, 'is_test': 1, 'momentum': 1.0}
is_terminal: True
Converting batchnorm ...
InvalidArgumentError                      Traceback (most recent call last)
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ in _create_c_op(graph, node_def, inputs, control_inputs)
   1575   try:
-> 1576     c_op = c_api.TF_FinishOperation(op_desc)
   1577   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'bn_10.9351925092536912/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,16,1,1], [].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-1-7f985261f348> in <module>()
     25 # we should specify shape of the input tensor
---> 26 k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)

~\Anaconda3\lib\site-packages\pytorch2keras\ in pytorch_to_keras(model, args, input_shapes, change_ordering, training, verbose, names)
    313             node_input_names,
    314             layers, state_dict,
--> 315             names
    316         )
    317         if node_id in graph_outputs:

~\Anaconda3\lib\site-packages\pytorch2keras\ in convert_batchnorm(params, w_name, scope_name, inputs, layers, weights, names)
     59             name=tf_name
     60         )
---> 61     layers[scope_name] = bn(layers[inputs[0]])

~\Anaconda3\lib\site-packages\keras\engine\ in __call__(self, inputs, **kwargs)
    455             # Actually call the layer,
    456             # collecting output(s), mask(s), and shape(s).
--> 457             output = self.call(inputs, **kwargs)
    458             output_mask = self.compute_mask(inputs, previous_mask)

~\Anaconda3\lib\site-packages\keras\layers\ in call(self, inputs, training)
    204         return K.in_train_phase(normed_training,
    205                                 normalize_inference,
--> 206                                 training=training)
    208     def get_config(self):

~\Anaconda3\lib\site-packages\keras\backend\ in in_train_phase(x, alt, training)
   3122     # else: assume learning phase is a placeholder tensor.
-> 3123     x = switch(training, x, alt)
   3124     if uses_learning_phase:
   3125         x._uses_learning_phase = True

~\Anaconda3\lib\site-packages\keras\backend\ in switch(condition, then_expression, else_expression)
   3056         x = tf.cond(condition,
   3057                     then_expression_fn,
-> 3058                     else_expression_fn)
   3059     else:
   3060         # tf.where needs its condition tensor

~\Anaconda3\lib\site-packages\tensorflow\python\util\ in new_func(*args, **kwargs)
    452                 'in a future version' if date is None else ('after %s' % date),
    453                 instructions)
--> 454       return func(*args, **kwargs)
    455     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    456                                        _add_deprecated_arg_notice_to_docstring(

~\Anaconda3\lib\site-packages\tensorflow\python\ops\ in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   2055     context_f = CondContext(pred, pivot_2, branch=0)
   2056     context_f.Enter()
-> 2057     orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   2058     if orig_res_f is None:
   2059       raise ValueError("false_fn must have a return value.")

~\Anaconda3\lib\site-packages\tensorflow\python\ops\ in BuildCondBranch(self, fn)
   1893     """Add the subgraph defined by fn() to the graph."""
   1894     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 1895     original_result = fn()
   1896     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1897     if len(post_summaries) > len(pre_summaries):

~\Anaconda3\lib\site-packages\keras\layers\ in normalize_inference()
    165                     broadcast_gamma,
    166                     axis=self.axis,
--> 167                     epsilon=self.epsilon)
    168             else:
    169                 return K.batch_normalization(

~\Anaconda3\lib\site-packages\keras\backend\ in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
   1906             # so it may have extra axes with 1, it is not needed and should be removed
   1907             if ndim(mean) > 1:
-> 1908                 mean = tf.reshape(mean, (-1))
   1909             if ndim(var) > 1:
   1910                 var = tf.reshape(var, (-1))

~\Anaconda3\lib\site-packages\tensorflow\python\ops\ in reshape(tensor, shape, name)
   7432   if _ctx is None or not _ctx._eager_context.is_eager:
   7433     _, _, _op = _op_def_lib._apply_op_helper(
-> 7434         "Reshape", tensor=tensor, shape=shape, name=name)
   7435     _result = _op.outputs[:]
   7436     _inputs_flat = _op.inputs

~\Anaconda3\lib\site-packages\tensorflow\python\framework\ in _apply_op_helper(self, op_type_name, name, **keywords)
    785         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    786                          input_types=input_types, attrs=attr_protos,
--> 787                          op_def=op_def)
    788       return output_structure, op_def.is_stateful, op

~\Anaconda3\lib\site-packages\tensorflow\python\util\ in new_func(*args, **kwargs)
    452                 'in a future version' if date is None else ('after %s' % date),
    453                 instructions)
--> 454       return func(*args, **kwargs)
    455     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    456                                        _add_deprecated_arg_notice_to_docstring(

~\Anaconda3\lib\site-packages\tensorflow\python\framework\ in create_op(***failed resolving arguments***)
   3153           input_types=input_types,
   3154           original_op=self._default_original_op,
-> 3155           op_def=op_def)
   3156       self._create_op_helper(ret, compute_device=compute_device)
   3157     return ret

~\Anaconda3\lib\site-packages\tensorflow\python\framework\ in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1729           op_def, inputs, node_def.attr)
   1730       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1731                                 control_input_ops)
   1733     # Initialize self._outputs.

~\Anaconda3\lib\site-packages\tensorflow\python\framework\ in _create_c_op(graph, node_def, inputs, control_inputs)
   1577   except errors.InvalidArgumentError as e:
   1578     # Convert to ValueError for backwards compatibility.
-> 1579     raise ValueError(str(e))
   1581   return c_op

ValueError: Shape must be rank 1 but is rank 0 for 'bn_10.9351925092536912/cond/Reshape_4' (op: 'Reshape') with input shapes: [1,16,1,1], [].

Golbstein avatar Feb 11 '19 13:02 Golbstein

Got the same error; upgrading to tensorflow 1.13.1 made it go away.

ShuangLiu1992 avatar May 09 '19 10:05 ShuangLiu1992


JaeDukSeo avatar Jul 08 '19 00:07 JaeDukSeo

Can you tell me your versions of the following libraries:

Thank you.


houweidong avatar Dec 26 '19 07:12 houweidong