keras-unet-collection

ValueError: Shapes (739, 128, 128, 3) and (128, 128, 3) are incompatible

Open: alqurri77 opened this issue 2 years ago • 1 comment

I got the error below when I tried to use transunet_2d:

from keras_unet_collection import models

model = models.transunet_2d((128, 128, 3), filter_num=[64, 128, 256, 512], n_labels=3,
                            stack_num_down=2, stack_num_up=2, embed_dim=768, num_mlp=3072,
                            num_heads=12, num_transformer=12, activation='ReLU',
                            mlp_activation='GELU', output_activation='Softmax',
                            batch_norm=True, pool=True, unpool='bilinear', name='transunet')

Here is the output and traceback:

24/24 [==============================] - 1975s 82s/step

ValueError                                Traceback (most recent call last)
     14 temp_out = model.predict([valid_input])
     15 y_pred = temp_out[-1]
---> 16 record = np.mean(keras.losses.categorical_crossentropy(valid_target, y_pred))
     17 print('\tInitial loss = {}'.format(record))
     18 print("step1")

~/.local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

~/.local/lib/python3.8/site-packages/keras/losses.py in categorical_crossentropy(y_true, y_pred, from_logits, label_smoothing, axis)
   1785       lambda: y_true)
   1786
-> 1787   return backend.categorical_crossentropy(
   1788       y_true, y_pred, from_logits=from_logits, axis=axis)
   1789

~/.local/lib/python3.8/site-packages/keras/backend.py in categorical_crossentropy(target, output, from_logits, axis)
   5117   target = tf.convert_to_tensor(target)
   5118   output = tf.convert_to_tensor(output)
-> 5119   target.shape.assert_is_compatible_with(output.shape)
   5120
   5121   # Use logits whenever they are available. softmax and sigmoid

ValueError: Shapes (739, 128, 128, 3) and (128, 128, 3) are incompatible
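The two shapes in the error hint at where the batch axis goes missing: valid_target is (739, 128, 128, 3), while y_pred is (128, 128, 3). Assuming the transunet_2d model above has a single output, model.predict returns one NumPy array of shape (batch, 128, 128, 3), so temp_out[-1] selects the last sample instead of the last output head. The [-1] pattern only makes sense when predict returns a list of outputs (for example, from a model with deep-supervision heads). The sketch below reproduces the shape arithmetic with random stand-in arrays, using a small batch of 8 instead of 739 samples. select_final_output is a hypothetical helper, and the cross-entropy is a hand-written NumPy stand-in for the Keras loss, not the Keras function itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: 8 samples (instead of 739), 128x128 pixels,
# 3 one-hot classes, matching the layout of valid_target in the traceback.
valid_target = np.eye(3)[rng.integers(0, 3, size=(8, 128, 128))]

# Assumption: a single-output model, so predict() returns ONE array of
# shape (batch, H, W, n_labels). Simulate it with a per-pixel softmax.
logits = rng.normal(size=(8, 128, 128, 3))
e = np.exp(logits - logits.max(axis=-1, keepdims=True))
temp_out = e / e.sum(axis=-1, keepdims=True)

# On a single array, [-1] picks the LAST SAMPLE and drops the batch axis.
print(temp_out[-1].shape)  # (128, 128, 3)


def select_final_output(out):
    """Return the last head if predict() gave a list/tuple, else the array itself."""
    return out[-1] if isinstance(out, (list, tuple)) else out


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-pixel categorical cross-entropy over the last (class) axis."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(y_pred), axis=-1)


y_pred = select_final_output(temp_out)
assert y_pred.shape == valid_target.shape  # both (8, 128, 128, 3)
record = np.mean(categorical_crossentropy(valid_target, y_pred))
print('\tInitial loss = {}'.format(record))
```

With the batch axis intact, target and prediction shapes agree and the loss reduces to a single scalar.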

alqurri77, May 30 '22 22:05

Try adding a batch dimension to valid_input:

valid_input = np.expand_dims(valid_input, axis=0)
temp_out = model.predict([valid_input])
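np.expand_dims(x, axis=0) inserts a new leading axis, which is what Keras expects when valid_input is a single (128, 128, 3) image. One hedged caveat: the 24/24 steps in the log are consistent with ceil(739 / 32) = 24, i.e. 739 samples at the default predict batch size of 32. That would mean valid_input already has a batch axis, and an extra axis would make it 5-D. The sketch below shows both cases with zero-filled stand-in arrays. ensure_batch_axis is a hypothetical helper, not part of Keras or keras-unet-collection.

```python
import numpy as np

# Hypothetical single image: height x width x channels, no batch axis.
single_image = np.zeros((128, 128, 3), dtype=np.float32)
batched = np.expand_dims(single_image, axis=0)
print(batched.shape)  # (1, 128, 128, 3)

# An array that already has a batch axis gains a fifth dimension instead.
valid_batch = np.zeros((8, 128, 128, 3), dtype=np.float32)
print(np.expand_dims(valid_batch, axis=0).shape)  # (1, 8, 128, 128, 3)


def ensure_batch_axis(x, image_ndim=3):
    """Add a leading batch axis only when x looks like a single image."""
    return np.expand_dims(x, axis=0) if x.ndim == image_ndim else x


print(ensure_batch_axis(single_image).shape)  # (1, 128, 128, 3)
print(ensure_batch_axis(valid_batch).shape)   # (8, 128, 128, 3)
```

Checking x.ndim before expanding keeps the same code path safe for both a single image and a full validation batch.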

akaiml, Jun 04 '22 09:06