Keras-RFCN
mrcnn_class_loss function/dataset problem
I tried to train a simple model with your code, but unfortunately I get the following error when I try to train:
InvalidArgumentError (see above for traceback): assertion failed: [] [Condition x == y did not hold element-wise:] [x (mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/Shape_1:0) = ] [4 16] [y (mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/strided_slice:0) = ] [1 64]
[[Node: mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/Assert/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_INT32, DT_STRING, DT_INT32], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/All/_4229, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/Assert/Assert/data_0, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/Assert/Assert/data_1, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/Assert/Assert/data_2, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/Shape/_4231, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/assert_equal/Assert/Assert/data_4, mrcnn_class_loss/SparseSoftmaxCrossEntropyWithLogits/strided_slice/_4233)]]
As far as I can tell, the problem is with the tensor shapes, so my dataset may be at fault.
- When I call
dataset.load_image(...)
it returns an image with shape (128, 128, 3)
- When I call
dataset.load_bbox(...)
it returns a list of bounding boxes with shape (nb_of_bboxes, 4)
and a list of class ids for the corresponding bboxes
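A quick sanity check along these lines could help rule the dataset in or out. It is only a sketch: check_sample is a hypothetical helper, not part of Keras-RFCN, and it assumes the image, boxes, and class ids can be converted to NumPy arrays.

```python
import numpy as np

def check_sample(image, bboxes, class_ids):
    """Raise ValueError if a sample does not match the shapes described above.

    Hypothetical helper for illustration; not part of Keras-RFCN.
    """
    image = np.asarray(image)
    bboxes = np.asarray(bboxes)
    class_ids = np.asarray(class_ids)

    # Image must be (height, width, 3).
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected (H, W, 3) image, got {image.shape}")
    # Boxes must be (nb_of_bboxes, 4).
    if bboxes.ndim != 2 or bboxes.shape[1] != 4:
        raise ValueError(f"expected (N, 4) bboxes, got {bboxes.shape}")
    # Exactly one class id per box.
    if class_ids.shape != (bboxes.shape[0],):
        raise ValueError(
            f"expected {bboxes.shape[0]} class ids, got shape {class_ids.shape}"
        )
    return True
```

Calling it on the outputs of dataset.load_image(...) and dataset.load_bbox(...) for a few samples would show whether any sample deviates from the expected shapes.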
I hope you have run into the same problem and can help me with this.
I think the problem is with VotePooling.
Instead of pooled = tf.expand_dims(pooled, 0)
the correct output should be pooled = tf.reshape(pooled, (self.batch_size, self.num_rois, self.channel_num))
since your comment says that the returned tensor should have shape (batch, num_rois, class_num).
Could you confirm?
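To illustrate the difference, here is a minimal shape sketch. It uses NumPy as a stand-in for TensorFlow, since np.expand_dims and np.reshape follow the same shape rules as their tf counterparts. The values batch_size=4 and num_rois=16 are inferred from the [4 16] vs [1 64] shapes in the error message, channel_num=3 is arbitrary, and the assumption that the pooled tensor has one row per ROI across the whole batch is mine, not confirmed from the layer code.

```python
import numpy as np

# batch_size and num_rois are inferred from the error message; channel_num is arbitrary.
batch_size, num_rois, channel_num = 4, 16, 3

# Assumption: pooling yields one row per ROI across the whole batch.
pooled = np.zeros((batch_size * num_rois, channel_num))

# Current behaviour: a single leading axis of size 1 is added.
expanded = np.expand_dims(pooled, 0)

# Proposed behaviour: the ROI rows are split back into per-image groups.
reshaped = np.reshape(pooled, (batch_size, num_rois, channel_num))

print(expanded.shape)  # (1, 64, 3)
print(reshaped.shape)  # (4, 16, 3)
```

With a batch size of 1 the two calls produce the same shape, which would explain why the mismatch only shows up for larger batches.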
I only tested the model with batch_size=1. It seems that the model breaks when the batch size is greater than 1.
Bug labeled.