Mobile Mask RCNN
Current commit is bf533ea.
Images show inference on a GTX 1080.
Seems to have a memory leak; I haven't focused on that yet. EDIT: the memory leak may be a bug in TensorFlow: https://github.com/tensorflow/tensorflow/issues/13359
Let me know how you'd like to deal with the model I produced.
Thanks! Cory
@Cpruce I went through the commit superficially. I just wanted to ask whether you trained the MobileNet with pretrained ImageNet weights, and if so, can you add the functionality to use them via the --model=imagenet flag?
You can check my last commit. I haven't tested the addition to get_imagenet_weights out yet.
TODO: Something to change is "mobilenet" -> "mobilenet224", etc. EDIT: Done
I just checked it: it should be mobilenet_1_0_224_tf_no_top.h5 that is used. The "no top" file does not include the final layers like the global average pooling / final FC / softmax. Link for mobilenet_1_0_224_tf_no_top.h5: you can replace the path with this link.
One question: is there any specific reason for using this version of MobileNet?
@ps48 Yup, I've got that file. I fine-tuned on it to bootstrap and have been using checkpoints I've produced since. I wasn't sure which one I had gotten working, since the one I listed was just the most recent one I downloaded, which is why I mentioned I hadn't tested it yet :) You're probably right, but we may be able to use both when we exclude layers (in this case the final dense network).
I chose the larger input-size model because I wanted to get MobileNet on par with the resnet50 backbone. Not surprisingly, it's proving to be non-trivial :D
Edit: I think it might not matter actually, as long as you continue to use by_name=True when loading weights:
def load_weights(self, filepath, by_name=False,
                 skip_mismatch=False, reshape=False):
    ...
    If `by_name` is True, weights are loaded into layers
    only if they share the same name. This is useful
    for fine-tuning or transfer-learning models where
    some of the layers have changed.
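To make that concrete, here is a minimal sketch (not the code from the commit) of how the MobileNet no-top weights could be fetched and then loaded with by_name=True. The URL placeholder stands in for the link mentioned above, and the standalone function name is only illustrative:

from keras.utils.data_utils import get_file

# Placeholder for the mobilenet_1_0_224_tf_no_top.h5 link referenced above.
MOBILENET_NO_TOP_WEIGHTS_URL = "..."

def get_mobilenet_imagenet_weights():
    """Download the ImageNet-pretrained MobileNet (no-top) weights.
    Sketch only: the real get_imagenet_weights in model.py handles resnet50;
    a mobilenet224 branch could follow the same pattern."""
    return get_file("mobilenet_1_0_224_tf_no_top.h5",
                    MOBILENET_NO_TOP_WEIGHTS_URL,
                    cache_subdir="models")

# Loading with by_name=True means layers that kept their names pick up the
# pretrained weights even if other parts of the graph changed, e.g.:
#   model.load_weights(get_mobilenet_imagenet_weights(), by_name=True)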
@ps48 I can confirm that fine-tuning on ImageNet with the MobileNet model including the top does work. Could you try it out, please?
@Cpruce: your bbox mAP is better than the last version of the repo (0.37 vs 0.34) for resnet101. Could you also show the mAP for segmentation? The repo achieved 0.29. And if possible, please share your resnet101 COCO pretrained weights with us.
@John1231983 I've added mobilenet224 as a backbone, as opposed to resnet101. (If you're comparing my mobilenet against a resnet101 backbone, then wow! I think resnet101 could do better.) AFAIK, you should be able to use resnet50 and resnet101 pretrained weights interchangeably, as the code seems to support it (i.e., layers_regex). Unfortunately my mAP isn't on par with the original backbone, resnet50; see Inference Performance in the README.md at https://github.com/Cpruce/Mask_RCNN. It is much faster and smaller, though. The inference performance results in that image are bbox, but segm is extremely close, so the picture is representative. You can also see some COCO test images in the README. However, segmentation and bounding-box results were poor on the shapes images. I believe this is due either to overfitting from COCO or to not fine-tuning on shapes enough.
Good job @Cpruce! Could you please share your COCO-pretrained resnet101 weights? Can I use them as a pretrained model for instance segmentation on another dataset? I am very interested in why it gives better performance than matterport's (0.34 on bbox) compared with your 0.37.
@John1231983 Do you mean mobilenet224? I haven't tried resnet101 as a backbone. I'm also not sure where you're getting the numbers from. These are the results for mobilenet224 after 500-something epochs:
And these are the results for resnet50:
Not sure exactly how many epochs, but probably much less than 500.
Sorry, I meant resnet50 (the second image). Could you give me your resnet50 COCO pretrained weights?
@John1231983 https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5
@Cpruce: So you just used the matterport pretrained model and it gave 0.37 AP, am I right? Because matterport's report shows 0.34.
@John1231983 I fine-tuned the model a bit on COCO. Are you able to do the same? i.e., do you have a GPU?
Yes, I have one Titan X Pascal (12 GB). So if I understand correctly, you start training COCO from the ImageNet resnet50 pretrained weights, and you also fine-tune the config. After training, you got a COCO pretrained model, and the 0.37 comes from that. If that's correct, I would be very happy to see your COCO pretrained model, because I would like to use it (instead of the matterport COCO pretrained model you gave me).
@John1231983 With the model posted above and a clean checkout of master, I consistently produce 0.37 mAP on the first evaluation. Could you post the results you get from running evaluate on that model?
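For reference, this is the kind of invocation I mean (a sketch using the repo's coco.py evaluate command; the paths are placeholders):

python3 coco.py evaluate --dataset=/path/to/coco/ --model=/path/to/mask_rcnn_coco.h5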
On a tangent, my MobileNet backbone is not as fast as I thought, and a decent amount of the speed was due to the GPUs. There's still a significant increase in speed, though it is still rather slow and the scores could be better. You can check the README now for the update.
@Cpruce: I don't have enough GPU to train COCO. I just use the author's pretrained COCO model; it gives 0.29 AP on the segmentation task. Could you provide your COCO pretrained model that showed 0.37 on bbox?
@John1231983 I produced the 0.37 mAP from the base model in the release, with no fine-tuning. I think you should also be able to train COCO on a Titan X, as I am able to on a 1080 and a 1080 Ti individually. Let's move the conversation elsewhere. Could you please email me your Keras and TensorFlow versions?
@Cpruce Thanks! This is a great addition. I'm starting to review the code.
In keeping with the goal of keeping the repo small and light, I'd suggest removing anything that can be removed. The things that caught my eye right away are:
- Validation results can probably be included in the Releases page like we did with the previous release. Or, alternatively, I'm thinking of adding another page, other than the main README, that covers more details.
- The sample images can either be moved to the details page if I add it (still thinking about that), or they can be included in a separate blog post or an external resource. And then, I think they're only useful if we show the same image being segmented with two different backbones next to each other so it's easy to compare.
- Training loss curves are not essential and can be removed.
- supported_architectures.py doesn't have enough code to justify being a separate file.
I'm yet to go through the code, though. I'll let you know if I have any questions.
The image preprocessing of MobileNet is different from ResNet101, but you don't change this in your code. In the function "mold_image", you still use the mean pixel for resnet101? I think this will influence the performance of the network.
@changzhonghan Mean pixel values are artifacts of the imagenet dataset https://github.com/matterport/Mask_RCNN/issues/102
https://www.learnopencv.com/keras-tutorial-using-pre-trained-imagenet-models/
The preprocess_input function in MobileNet is different from ResNet's. The preprocessing for ResNet uses the ImageNet mean pixel, but the preprocessing for MobileNet is (x/122.5 - 1.0), which doesn't use the mean pixel.
@waleedka Thanks for the great repository and your effort! I will make your requested changes after you are done reviewing. Note that only a few commit hunks are meant to be committed; some things, like the notebook, I found to be broken with the addition of ARCH in the config. I was planning on going through it soon to make sure everything works. Please feel free to reach out if you want to discuss anything.
@changzhonghan I did a quick test with one of my lower checkpoints for a few epochs, but the loss kept increasing. That said, this is not a completely fair test.
I found that the keras mobilenet does use mean pixel values of imagenet, like resnet50.
def preprocess_input(x):
    """Preprocesses a numpy array encoding a batch of images.

    # Arguments
        x: a 4D numpy array consists of RGB values within [0, 255].

    # Returns
        Preprocessed array.
    """
    return imagenet_utils.preprocess_input(x, mode='tf')
which leads to
def _preprocess_numpy_input(x, data_format, mode):
    """Preprocesses a Numpy array encoding a batch of images.

    # Arguments
        x: Input array, 3D or 4D.
        data_format: Data format of the image array.
        mode: One of "caffe", "tf" or "torch".
            - caffe: will convert the images from RGB to BGR,
                then will zero-center each color channel with
                respect to the ImageNet dataset,
                without scaling.
            - tf: will scale pixels between -1 and 1,
                sample-wise.
            - torch: will scale pixels between 0 and 1 and then
                will normalize each channel with respect to the
                ImageNet dataset.

    # Returns
        Preprocessed Numpy array.
    """
    if mode == 'tf':
        x /= 127.5
        x -= 1.
        return x

    if mode == 'torch':
        x /= 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        if data_format == 'channels_first':
            # 'RGB'->'BGR'
            if x.ndim == 3:
                x = x[::-1, ...]
            else:
                x = x[:, ::-1, ...]
        else:
            # 'RGB'->'BGR'
            x = x[..., ::-1]
        mean = [103.939, 116.779, 123.68]
        std = None

    # Zero-center by mean pixel
    if data_format == 'channels_first':
        if x.ndim == 3:
            x[0, :, :] -= mean[0]
            x[1, :, :] -= mean[1]
            x[2, :, :] -= mean[2]
            if std is not None:
                x[0, :, :] /= std[0]
                x[1, :, :] /= std[1]
                x[2, :, :] /= std[2]
        else:
            x[:, 0, :, :] -= mean[0]
            x[:, 1, :, :] -= mean[1]
            x[:, 2, :, :] -= mean[2]
            if std is not None:
                x[:, 0, :, :] /= std[0]
                x[:, 1, :, :] /= std[1]
                x[:, 2, :, :] /= std[2]
    else:
        x[..., 0] -= mean[0]
        x[..., 1] -= mean[1]
        x[..., 2] -= mean[2]
        if std is not None:
            x[..., 0] /= std[0]
            x[..., 1] /= std[1]
            x[..., 2] /= std[2]
    return x
Can you post what you're referencing?
There was something wrong in my previous comment: the preprocessing of MobileNet is (x/127.5 - 1.0), which does not use the ImageNet mean pixel. In Keras MobileNet, they use imagenet_utils.preprocess_input(x, mode='tf'), which returns: if mode == 'tf': x /= 127.5; x -= 1.; return x. This function is different from ResNet's.
You're right. I'll test whether this performs better:
def mold_image(images, config):
    """Takes RGB images with 0-255 values, subtracts
    the mean pixel, and converts it to float. Expects image
    colors in RGB order.
    """
    return images.astype(np.float32)/127.5 - 1.0  # images.astype(np.float32) - config.MEAN_PIXEL

def unmold_image(normalized_images, config):
    """Takes an image normalized with mold() and returns the original."""
    return ((normalized_images + 1)*127.5).astype(np.uint8)  # (normalized_images + config.MEAN_PIXEL).astype(np.uint8)
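As a quick sanity check on that change (a sketch; it only assumes numpy), the two functions should invert each other up to quantization:

import numpy as np

# Round-trip check for the 'tf'-style scaling proposed above.
image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

molded = image.astype(np.float32) / 127.5 - 1.0        # mold_image, as above
recovered = ((molded + 1) * 127.5).astype(np.uint8)    # unmold_image, as above

assert molded.min() >= -1.0 and molded.max() <= 1.0
# float32 rounding plus the uint8 truncation in unmold_image can shift a
# pixel by one level, so compare with a +/-1 tolerance rather than exactly.
assert np.abs(recovered.astype(np.int16) - image.astype(np.int16)).max() <= 1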
@Cpruce Is there an advantage to including stage 13 (block 13) of the MobileNet base in your architecture? (As can be seen here.) The spatial dimensions of block 12 are consistent with the resnet requirements. In my experiments, I used block 12 as my C5 output.
@JonathanCMitchell I believe the advantage would be another block (more parameters) for the network to work with. Except for stage 1, I tried to remain as consistent as possible with the resnet backbone, extracting feature maps at the end of each pointwise-convolution output-width group (64, 128, 256, 512, 1024). However, I can't say I've tried C5 at block 12 yet. Have you tried C5 at block 13, and have you been able to produce results yet?
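For illustration, here is a minimal sketch of those extraction points using the stock keras.applications MobileNet layer names (the names in the actual backbone graph on my branch may differ), with C5 taken from block 13; swapping in conv_pw_12_relu gives the block-12 variant you describe:

from keras.applications.mobilenet import MobileNet

# Sketch: pick FPN inputs from a stock Keras MobileNet (alpha=1.0, 224x224).
base = MobileNet(input_shape=(224, 224, 3), alpha=1.0,
                 include_top=False, weights=None)

# Pointwise-convolution outputs at channel widths 64/128/256/512/1024.
C1 = base.get_layer("conv_pw_1_relu").output    # 112x112x64
C2 = base.get_layer("conv_pw_3_relu").output    # 56x56x128
C3 = base.get_layer("conv_pw_5_relu").output    # 28x28x256
C4 = base.get_layer("conv_pw_11_relu").output   # 14x14x512
C5 = base.get_layer("conv_pw_13_relu").output   # 7x7x1024; use "conv_pw_12_relu" for block 12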
@changzhonghan So far, so good. Thank you for spotting that preprocessing discrepancy. The model with the above normalization has been converging faster
@Cpruce It appears that you are using trainable BatchNormalization layers in your MobileNet graph, whereas the current master branch implementation sets the BatchNorm layers to non-trainable. You can see this where they overwrite the layer with BatchNorm on line 56.
Also, I believe you can convert the fpn_mask_graph inside build_fpn_mask_graph (L956) to use depthwise-separable layers. Did you consider this option?
@JonathanCMitchell You're right. Thanks, I am testing the BatchNorm change now, as my batch size is small. There is a noticeable improvement when training (<1.2 training loss and decreasing). I'll experiment with the depthwise-separable convolution after I return and the BatchNorm change has run for a number of epochs.
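For anyone following along, the pattern on master is roughly the following (a sketch close to, but not necessarily identical to, the BatchNorm class in model.py):

import keras.layers as KL

class BatchNorm(KL.BatchNormalization):
    """Batch normalization that always runs in inference mode, so batch
    statistics are never updated from the (very small) training batches
    and the pretrained moving mean/variance are used instead."""

    def call(self, inputs, training=None):
        # Force training=False regardless of the learning phase.
        return super(BatchNorm, self).call(inputs, training=False)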
@JonathanCMitchell have you tried using depthwise-separable layers on an FPN before? Can you help me understand how it would differ from the TimeDistributed layer? If I understand correctly, even in the TimeDistributed layer the convolution is done on the sliced time dimension.
@ps48 The TimeDistributed layer is a means of performing the 2D convolution on each feature map from the feature pyramid network. So if the FPN contains four stages (C2, C3, C4, C5), the TimeDistributed layer is a means of looping through each stage and performing a 2D convolution on each feature map. The 2D convolution being performed is still not a depthwise-separable convolution, but you can choose to make it one. However, typically with depthwise-separable convolutions you want to include a BatchNormalization layer, and for such a small batch size that may not make sense. (My batch size is 4 images, one per GPU.)
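To illustrate the suggestion (a sketch, not code from either branch; the function and layer names here are made up), one mask-head convolution wrapped in TimeDistributed could be swapped for a depthwise-separable version like this:

import keras.layers as KL

def sep_mask_conv(x, filters=256, name="mrcnn_mask_sepconv1"):
    """One mask-head conv block with Conv2D replaced by SeparableConv2D.
    TimeDistributed applies the layer across the leading ROI dimension of
    the [batch, num_rois, H, W, C] tensor. Whether to keep BatchNorm here
    with such a small batch size is an open question, as noted above."""
    x = KL.TimeDistributed(KL.SeparableConv2D(filters, (3, 3), padding="same"),
                           name=name)(x)
    x = KL.TimeDistributed(KL.BatchNormalization(), name=name + "_bn")(x)
    x = KL.Activation("relu")(x)
    return x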