djl
Mask R-CNN transfer learning example
Any chance of getting a simple example (even just an ad hoc one) of how to train Mask R-CNN via transfer learning on a custom COCO-formatted dataset?
We are currently evaluating whether this requirement forces us to write our main application in Python, or whether we can switch to the JVM/Kotlin (which we would prefer). I would really like to build a small POC for this.
Yes, there should be a way to achieve that. Which DL framework are you looking for? Currently we support transfer learning with MXNet, and transfer learning with PyTorch is experimental. We have a Mask R-CNN model in the MXNet model zoo; a PyTorch one is on the way.
MXNet is totally fine...
So here are the steps:
- Get the pretrained Mask R-CNN model from the MXNet model zoo: http://docs.djl.ai/mxnet/mxnet-model-zoo/index.html. You can choose one of the available backbones. Try to get it running with our instance segmentation example: http://docs.djl.ai/mxnet/mxnet-model-zoo/index.html.
- Prepare your dataset. We already have a COCO dataset implementation in our basicdataset module (https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java). You can implement your custom dataset based on it.
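One detail that often comes up when building a custom dataset from COCO-style annotations: category IDs in the `categories` array do not have to be contiguous (the official COCO set uses IDs between 1 and 90 for its 80 categories), while a model's class head expects indices 0..N-1. The following is a minimal, stdlib-only sketch, assuming the categories have already been parsed from the JSON annotation file; `CocoCategoryMapping` and its methods are hypothetical helpers, not part of DJL's `CocoDetection`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class CocoCategoryMapping {

    /** Hypothetical holder for one entry of a COCO "categories" array. */
    static final class Category {
        final int id;
        final String name;

        Category(int id, String name) {
            this.id = id;
            this.name = name;
        }
    }

    /** Sorts categories by id and rejects duplicate ids. */
    private static TreeMap<Integer, String> sortById(List<Category> categories) {
        TreeMap<Integer, String> sorted = new TreeMap<>();
        for (Category c : categories) {
            if (sorted.put(c.id, c.name) != null) {
                throw new IllegalArgumentException("Duplicate category id: " + c.id);
            }
        }
        return sorted;
    }

    /** Maps (possibly non-contiguous) category ids to contiguous indices 0..N-1. */
    static Map<Integer, Integer> buildIdToIndex(List<Category> categories) {
        Map<Integer, Integer> idToIndex = new LinkedHashMap<>();
        int index = 0;
        for (Integer id : sortById(categories).keySet()) {
            idToIndex.put(id, index++);
        }
        return idToIndex;
    }

    /** Class names in index order, e.g. for the list of output classes. */
    static List<String> classNames(List<Category> categories) {
        return new ArrayList<>(sortById(categories).values());
    }

    public static void main(String[] args) {
        List<Category> categories = new ArrayList<>();
        categories.add(new Category(7, "scratch"));
        categories.add(new Category(2, "dent"));
        categories.add(new Category(13, "rust"));

        System.out.println(buildIdToIndex(categories)); // {2=0, 7=1, 13=2}
        System.out.println(classNames(categories));     // [dent, scratch, rust]
    }
}
```

Sorting by ID keeps the mapping deterministic across runs, which matters if the same index order has to be reused at inference time.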
- Train the model using the dataset. In DJL, you can create a Trainer directly from the pretrained model. Since the model comes from MXNet, you can follow steps similar to those of the original model training: https://gluon-cv.mxnet.io/build/examples_instance/train_mask_rcnn_coco.html.
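For reference when porting the training step: the Mask R-CNN paper defines a multi-task loss on each sampled RoI as the sum of a classification, a box-regression and a mask loss, where the mask loss is an average per-pixel binary cross-entropy evaluated only on the mask of the RoI's ground-truth class k. When the region proposal network is trained jointly, as in end-to-end training, its objectness and box-regression losses are added too. The grouping and notation below are ours; a DJL training setup would need a loss that combines all of these terms.

```latex
% End-to-end Mask R-CNN objective (notation ours)
L = \underbrace{L^{\mathrm{rpn}}_{\mathrm{cls}} + L^{\mathrm{rpn}}_{\mathrm{box}}}_{\text{region proposal network}}
  + \underbrace{L_{\mathrm{cls}} + L_{\mathrm{box}} + L_{\mathrm{mask}}}_{\text{per sampled RoI}}

% Mask loss for an RoI with ground-truth class k, predicted m x m mask \hat{y}^{k}, target mask y
L_{\mathrm{mask}} = -\frac{1}{m^{2}} \sum_{1 \le i,j \le m}
  \Big[ y_{ij} \log \hat{y}^{k}_{ij} + (1 - y_{ij}) \log\big(1 - \hat{y}^{k}_{ij}\big) \Big]
```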
The last step is fairly difficult because the Mask R-CNN model itself is complex. If your output classes are not the same as those of the original dataset, you may need to do something similar to what we did for ResNet: https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java#L133-L156
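A rough, framework-independent sketch of the idea behind replacing output layers when the class set changes: reuse every pretrained parameter whose shape does not depend on the number of classes, and re-initialize the rest. The parameter names and shapes are hypothetical and purely illustrative; they assume a classification head with K + 1 outputs (background included), class-specific box regression with 4K outputs, and a mask head with K output channels. This is plain Java, not a DJL API.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

class HeadTransferPlan {

    /** What to do with one parameter of the new model. */
    enum Action { REUSE, REINITIALIZE }

    /**
     * Reuses a pretrained parameter only if it exists under the same name with the
     * same shape; anything else (e.g. class-dependent heads) is re-initialized.
     */
    static Map<String, Action> plan(Map<String, long[]> pretrained, Map<String, long[]> target) {
        Map<String, Action> result = new LinkedHashMap<>();
        for (Map.Entry<String, long[]> e : target.entrySet()) {
            long[] old = pretrained.get(e.getKey());
            boolean sameShape = old != null && Arrays.equals(old, e.getValue());
            result.put(e.getKey(), sameShape ? Action.REUSE : Action.REINITIALIZE);
        }
        return result;
    }

    /** Hypothetical parameter shapes for a Mask R-CNN with k foreground classes. */
    static Map<String, long[]> maskRcnnShapes(int k) {
        Map<String, long[]> shapes = new LinkedHashMap<>();
        shapes.put("backbone.conv1.weight", new long[] {64, 3, 7, 7});      // class-independent
        shapes.put("rcnn.class_predictor.weight", new long[] {k + 1, 1024}); // + background
        shapes.put("rcnn.box_predictor.weight", new long[] {4L * k, 1024});  // one box per class
        shapes.put("mask.predictor.weight", new long[] {k, 256, 1, 1});      // one mask per class
        return shapes;
    }

    public static void main(String[] args) {
        // Pretrained on 80 COCO classes, fine-tuned on a hypothetical 3-class dataset.
        Map<String, Action> plan = plan(maskRcnnShapes(80), maskRcnnShapes(3));
        plan.forEach((name, action) -> System.out.println(name + " -> " + action));
    }
}
```

Only the backbone entry comes out as `REUSE` here; the three heads are re-created, which is the Mask R-CNN counterpart of swapping the final layer in the ResNet example.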
To save time, my recommendation is to use GluonCV (Python) and change the layers to fit your use case. Once this is done, you can convert the model into an MXNet Symbol and train it in Java with the Dataset you implemented.
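On the Java side, the converted model is the pair of files that MXNet's `HybridBlock.export(path, epoch)` writes: `<path>-symbol.json` and `<path>-<epoch>.params`, with the epoch zero-padded to four digits. If that handoff is scripted, a small check like the following can locate the newest export before the model is loaded. It is a minimal sketch: `MxnetExportLocator` and the `mask_rcnn` prefix are hypothetical, and nothing here is a DJL API.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

class MxnetExportLocator {

    /** Path of the exported symbol (graph) file for the given prefix. */
    static Path symbolFile(Path dir, String prefix) {
        return dir.resolve(prefix + "-symbol.json");
    }

    /** Finds the "<prefix>-NNNN.params" file with the highest epoch, if any. */
    static Optional<Path> latestParams(Path dir, String prefix) throws IOException {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d{4,})\\.params");
        Path best = null;
        int bestEpoch = -1;
        try (Stream<Path> files = Files.list(dir)) {
            Iterator<Path> it = files.iterator();
            while (it.hasNext()) {
                Path p = it.next();
                Matcher m = pattern.matcher(p.getFileName().toString());
                if (m.matches()) {
                    int epoch = Integer.parseInt(m.group(1));
                    if (epoch > bestEpoch) {
                        bestEpoch = epoch;
                        best = p;
                    }
                }
            }
        }
        return Optional.ofNullable(best);
    }

    public static void main(String[] args) throws IOException {
        // Simulate an export directory with two saved epochs.
        Path dir = Files.createTempDirectory("export");
        Files.createFile(dir.resolve("mask_rcnn-symbol.json"));
        Files.createFile(dir.resolve("mask_rcnn-0000.params"));
        Files.createFile(dir.resolve("mask_rcnn-0012.params"));

        System.out.println(Files.exists(symbolFile(dir, "mask_rcnn"))); // true
        System.out.println(latestParams(dir, "mask_rcnn")
                .map(p -> p.getFileName().toString())
                .orElse("none")); // mask_rcnn-0012.params
    }
}
```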
That sounds like a plan :-) However, I think I will have to go with modifying the network architecture directly in Java, since we plan to train on quite a few datasets with different classes. Going the GluonCV -> MXNet -> DJL route would have to be automated, and that in itself would be quite a lot of work.