StyleGAN2-Tensorflow-2.0
conv2d_mod/Conv2D NCHW not implemented
```
    generated_images = self.GAN.GM.predict(n1 + [n2], batch_size = BATCH_SIZE)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 909, in predict
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 722, in predict
    callbacks=callbacks)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 393, in model_iteration
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py", line 3740, in call
    outputs = self._graph_fn(*converted_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1081, in call
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "
Function call stack: keras_scratch_graph
```
It seems `conv2d` does not accept the NCHW data format. I tried forcing it to run on the GPU (`with tf.device('/gpu:1'): ...`), but that did not work. I also tried different TF versions (2.0 and 2.3), including the TF 2.0 Docker image, and all of them hit the same issue.
Does anyone know how to get around this issue? Thanks.
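This happens because the model is running on the CPU, where TensorFlow's Conv2D kernel only supports the NHWC format. Try `batch_size = 1`, and in `conv_mod.py` make the following changes: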
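```python
# add this: transpose the input from NCHW to NHWC
x = tf.transpose(x, [0, 2, 3, 1])
# change data_format="NCHW" to data_format="NHWC"
x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NHWC")
# add this: transpose the output back from NHWC to NCHW
x = tf.transpose(x, [0, 3, 1, 2])
```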
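The patch can be sketched as a standalone helper to show the idea in isolation. This is a minimal sketch, not the repository's code: the function name `conv2d_nchw_via_nhwc` is hypothetical, it assumes the input `x` is NCHW and the filter `w` uses TensorFlow's `[filter_height, filter_width, in_channels, out_channels]` layout, and it does not reproduce anything else the layer in `conv_mod.py` does.

```python
import tensorflow as tf


def conv2d_nchw_via_nhwc(x, w, strides=1, padding="SAME"):
    """Hypothetical helper: convolve NCHW activations using the NHWC kernel.

    x: activations, shape [batch, channels, height, width] (NCHW).
    w: filter, shape [filter_h, filter_w, in_ch, out_ch]; this layout does not
       depend on data_format, so the filter is passed through unchanged.
    """
    # NCHW -> NHWC so the CPU Conv2D kernel accepts the input
    x = tf.transpose(x, [0, 2, 3, 1])
    x = tf.nn.conv2d(x, w, strides=strides, padding=padding, data_format="NHWC")
    # NHWC -> NCHW so the rest of the layer keeps seeing channels-first data
    return tf.transpose(x, [0, 3, 1, 2])


if __name__ == "__main__":
    x = tf.random.normal([1, 8, 16, 16])  # NCHW, batch of 1
    w = tf.random.normal([3, 3, 8, 4])    # [kh, kw, in_ch, out_ch]
    y = conv2d_nchw_via_nhwc(x, w)
    print(y.shape)  # (1, 4, 16, 16)
```

The two transposes cost extra memory traffic, but they let the surrounding channels-first code stay as it is while only the convolution itself runs in the layout the CPU kernel supports.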
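Thanks Anthony, your solution works. I thought the weights would also need their in_channels axis transposed to match the activation data format, but it turns out they don't: `tf.nn.conv2d` expects the filter as `[filter_height, filter_width, in_channels, out_channels]` regardless of `data_format`.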
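To check why the weights need no transpose, the sketch below computes the same convolution two ways in plain NumPy: once by transposing NCHW activations to NHWC and back (as in the workaround), and once directly on NCHW data, both with the same untouched filter. The helper names `conv2d_nhwc` and `conv2d_nchw` are hypothetical, and they assume stride 1 and `VALID` padding to keep the loops short; like `tf.nn.conv2d`, they compute a cross-correlation.

```python
import numpy as np


def conv2d_nhwc(x, w):
    """Hypothetical naive VALID, stride-1 conv on NHWC input.

    x: [N, H, W, C_in]; w: [KH, KW, C_in, C_out] -> [N, H-KH+1, W-KW+1, C_out]
    """
    n, h, wd, _ = x.shape
    kh, kw, _, c_out = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # [N, KH, KW, C_in]
            out[:, i, j, :] = np.einsum("nabc,abcd->nd", patch, w)
    return out


def conv2d_nchw(x, w):
    """Same conv computed directly on NCHW input with the SAME [KH, KW, C_in, C_out] filter."""
    n, _, h, wd = x.shape
    kh, kw, _, c_out = w.shape
    out = np.zeros((n, c_out, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # [N, C_in, KH, KW]
            out[:, :, i, j] = np.einsum("ncab,abcd->nd", patch, w)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x_nchw = rng.standard_normal((2, 3, 5, 7))
    w = rng.standard_normal((3, 3, 3, 4))  # filter is never transposed
    via_nhwc = conv2d_nhwc(x_nchw.transpose(0, 2, 3, 1), w).transpose(0, 3, 1, 2)
    print(np.allclose(via_nhwc, conv2d_nchw(x_nchw, w)))  # True
```

Only the activation axes are permuted; the filter indices `[kh, kw, c_in, c_out]` are summed the same way in both functions, which is why the in_channels axis of the weights stays where it is.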