keras-unet-collection icon indicating copy to clipboard operation
keras-unet-collection copied to clipboard

ValueError: Input 0 is incompatible with layer model_1: expected shape=(None, 512, 512, 3), found shape=(None, 512, 512, 64)

Open akaiml opened this issue 2 years ago • 2 comments

I get this error after adding a custom loss.

# Input shape: (512, 512, 3)
# Build the 2-D U-Net with a VGG16 backbone; `input_size` and `filter_num`
# are defined elsewhere by the caller.
model = models.unet_2d(
    input_size,
    filter_num,
    n_labels=1,
    stack_num_down=2,
    stack_num_up=2,
    activation='ReLU',
    output_activation=None,
    batch_norm=True,
    pool='max',
    unpool='nearest',
    backbone='VGG16',
    weights='imagenet',
    freeze_backbone=True,
    freeze_batch_norm=True,
    name='unet',
)

# VGG16 layers whose activations are compared, and the weight given to each.
selected_layers = ['block1_conv1', 'block2_conv2', 'block3_conv3']
selected_layer_weights = [0.65, 0.3, 0.05]

# Frozen ImageNet VGG16, used only as a fixed feature extractor for the loss.
vgg = tf.keras.applications.VGG16(
    weights='imagenet', include_top=False, input_shape=input_size
)
vgg.trainable = False
outputs = [vgg.get_layer(layer_name).output for layer_name in selected_layers]
# Deliberately NOT bound to the name `model`: rebinding `model` here replaces
# the U-Net, so a later `model.compile(...)` would compile this extractor and
# hand its 64-channel block1_conv1 output back to the loss -- which matches the
# reported "expected (None, 512, 512, 3), found (None, 512, 512, 64)" error.
feature_extractor = tf.keras.Model(vgg.input, outputs)


def _ensure_rgb(image):
    """Return *image* with three channels, tiling a single-channel tensor.

    VGG16 with ImageNet weights requires 3-channel input, while the network
    being trained may emit a single channel (e.g. ``n_labels=1``).
    """
    if image.shape[-1] == 1:
        return tf.image.grayscale_to_rgb(image)
    return image


@tf.function
def perceptual_loss(input_image, reconstruct_image):
    """Weighted sum of squared VGG16 feature differences between two images.

    Args:
        input_image: Reference image batch (1 or 3 channels, channels last).
        reconstruct_image: Predicted image batch (1 or 3 channels).

    Returns:
        A tensor with one loss value per batch element: for every selected
        layer the squared feature difference is summed over the flattened
        features and scaled by the layer's weight.
    """
    # NOTE(review): images are fed to VGG16 as-is; confirm whether
    # tf.keras.applications.vgg16.preprocess_input should be applied first.
    h1_list = feature_extractor(_ensure_rgb(input_image))
    h2_list = feature_extractor(_ensure_rgb(reconstruct_image))

    rc_loss = 0.0
    for h1, h2, weight in zip(h1_list, h2_list, selected_layer_weights):
        h1 = K.batch_flatten(h1)
        h2 = K.batch_flatten(h2)
        rc_loss = rc_loss + weight * K.sum(K.square(h1 - h2), axis=-1)

    return rc_loss

# Compile `model` with the perceptual loss; MAE and MSE are tracked as metrics.
# `learning_rate` replaces the deprecated `lr` alias of the optimizer argument.
model.compile(
    loss=perceptual_loss,
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(),
        tf.keras.metrics.MeanSquaredError(),
    ],
)

Pretrained weights: 'imagenet'; backbone: VGG16.

How can I resolve this?

akaiml avatar Apr 21 '22 19:04 akaiml

Have you solved this? I have a similar issue.

iMilchshake avatar May 25 '22 09:05 iMilchshake

I was trying to evaluate my network using a perceptual loss. The issue was that my model's output shape was (512, 512, 1), whereas the VGG network with pre-trained "imagenet" weights expects a 3-channel image. So, before passing the reconstructed image to the loss function, I converted it to 3 channels: reconstruct_image = tf.keras.layers.Concatenate()([reconstruct_image, reconstruct_image, reconstruct_image])

I have modified the code slightly. I also ran into an OOM issue, so I reduced my image size to (256, 256). My modified code is given below.

# VGG16 layers compared by the loss and the weight applied to each of them.
_PERCEPTUAL_LAYERS = ('block1_conv1', 'block2_conv1', 'block3_conv1')
_PERCEPTUAL_LAYER_WEIGHTS = (0.65, 0.3, 0.05)


def _build_vgg_feature_extractor():
    """Build a frozen VGG16 model returning the selected layers' activations.

    No ``input_shape`` is passed, so the extractor is not tied to one image
    size (with ``include_top=False`` only the 3-channel constraint remains).
    """
    vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet')
    vgg.trainable = False
    outputs = [vgg.get_layer(name).output for name in _PERCEPTUAL_LAYERS]
    extractor = tf.keras.Model(inputs=vgg.input, outputs=outputs)
    extractor.trainable = False
    return extractor


# Built once at module level: constructing VGG16 inside the @tf.function loss
# reloads the weights on every trace, and using one multi-output model needs
# two forward passes per call instead of six.
_VGG_FEATURES = _build_vgg_feature_extractor()


def _to_three_channels(image):
    """Return *image* with three channels, tiling a single-channel tensor."""
    if image.shape[-1] == 1:
        return tf.image.grayscale_to_rgb(image)
    return image


@tf.function
def perceptual_loss(y_true, y_pred):
    """Weighted mean squared difference of VGG16 features of two image batches.

    Args:
        y_true: Ground-truth image batch (1 or 3 channels, channels last).
        y_pred: Predicted image batch (1 or 3 channels, channels last).

    Returns:
        A scalar tensor: the sum over the selected layers of
        ``weight * mean((features(y_true) - features(y_pred)) ** 2)``.
    """
    # NOTE(review): images are fed to VGG16 as-is; confirm whether
    # tf.keras.applications.vgg16.preprocess_input should be applied first.
    true_features = _VGG_FEATURES(_to_three_channels(y_true))
    pred_features = _VGG_FEATURES(_to_three_channels(y_pred))

    perp_loss = 0.0
    for weight, feat_true, feat_pred in zip(
        _PERCEPTUAL_LAYER_WEIGHTS, true_features, pred_features
    ):
        perp_loss = perp_loss + weight * K.mean(K.square(feat_true - feat_pred))
    return perp_loss

akaiml avatar May 26 '22 10:05 akaiml