StyleGAN2-Tensorflow-2.0

conv2d_mod/Conv2D NCHW not implemented

Open xiaoliangbai opened this issue 4 years ago • 2 comments

generated_images = self.GAN.GM.predict(n1 + [n2], batch_size = BATCH_SIZE)

File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 909, in predict
    use_multiprocessing=use_multiprocessing)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 722, in predict
    callbacks=callbacks)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 393, in model_iteration
    batch_outs = f(ins_batch)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py", line 3740, in call
    outputs = self._graph_fn(*converted_inputs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1081, in call
    return self._call_impl(args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
File "", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnimplementedError: The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
    [[node model_1/conv2d_mod/Conv2D (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_11413]

Function call stack: keras_scratch_graph

It seems Conv2D does not accept the NCHW data format on the CPU. I tried forcing it to run on the GPU (with tf.device('/gpu:1'): ...), but that did not work. I also tried different TF versions (2.0, 2.3), and even the Docker image for TF 2.0; all of them hit the same issue.

Does anyone know how to get around this issue? Thanks.

xiaoliangbai, Oct 19 '20 00:10

This happens because the model is running on the CPU. Try batch_size = 1, and in conv_mod.py:

# add this
x = tf.transpose(x, [0, 2, 3, 1])

# change NCHW to NHWC
x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NHWC")

# add this
x = tf.transpose(x, [0, 3, 1, 2])

anthonyivol, Oct 19 '20 15:10

Thanks Anthony, your solution works. I thought the weights would also need their in_chan axis transposed to match the activation data format, but it turns out they don't.

xiaoliangbai, Oct 19 '20 20:10
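
For anyone wondering why the weights can stay as they are: tf.nn.conv2d always expects its filter as [filter_height, filter_width, in_channels, out_channels], whatever data_format is, so only the activations need transposing. Below is a rough NumPy-only sketch of that idea. It is not the code from conv_mod.py; the helpers conv2d_nhwc and conv2d_nchw are made up for illustration and assume stride 1 with SAME padding. It checks that transposing NCHW to NHWC, convolving, and transposing back gives the same result as convolving directly in NCHW with the same filter.

```python
import numpy as np

def conv2d_nhwc(x, w):
    """Naive stride-1 SAME convolution (hypothetical helper).
    x: [N, H, W, C_in], w: [kh, kw, C_in, C_out]."""
    n, h, wd, _ = x.shape
    kh, kw, _, c_out = w.shape
    # SAME padding for stride 1: k - 1 zeros in total, extra one on the bottom/right
    pt, pl = (kh - 1) // 2, (kw - 1) // 2
    xp = np.pad(x, ((0, 0), (pt, kh - 1 - pt), (pl, kw - 1 - pl), (0, 0)))
    out = np.zeros((n, h, wd, c_out), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # [N, H, W, C_in] @ [C_in, C_out] -> [N, H, W, C_out]
            out += xp[:, i:i + h, j:j + wd, :] @ w[i, j]
    return out

def conv2d_nchw(x, w):
    """Same convolution computed directly in NCHW (hypothetical helper).
    x: [N, C_in, H, W], w: [kh, kw, C_in, C_out] (filter layout unchanged)."""
    n, _, h, wd = x.shape
    kh, kw, _, c_out = w.shape
    pt, pl = (kh - 1) // 2, (kw - 1) // 2
    xp = np.pad(x, ((0, 0), (0, 0), (pt, kh - 1 - pt), (pl, kw - 1 - pl)))
    out = np.zeros((n, c_out, h, wd), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            out += np.einsum("nchw,co->nohw", xp[:, :, i:i + h, j:j + wd], w[i, j])
    return out

rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 3, 5, 5))   # [N, C_in, H, W]
w = rng.standard_normal((3, 3, 3, 4))        # [kh, kw, C_in, C_out]

# Same pattern as the fix above: NCHW -> NHWC, convolve, NHWC -> NCHW
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
y_nhwc = conv2d_nhwc(x_nhwc, w)
y_via_transpose = np.transpose(y_nhwc, (0, 3, 1, 2))

# Reference: convolve directly in NCHW with the very same filter
y_direct = conv2d_nchw(x_nchw, w)
matches = np.allclose(y_via_transpose, y_direct)
print(y_via_transpose.shape, matches)
```

The two transposes add some overhead, so this is probably best treated as a CPU-only workaround; on a GPU the original NCHW call should not need it.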