djl
djl copied to clipboard
Object detection
This PR shows my current progress on the object detection work. So far, I believe I have successfully built the network structure of YOLOv3 and have started on the loss function.
This is just a test PR so that my mentor Zack can review my current progress and make suggestions.
At 2022-08-29 13:50:35, "Frank Liu" @.***> wrote:
@frankfliu commented on this pull request.
In api/src/main/java/ai/djl/training/loss/SingleShotDetectionLoss.java:
@@ -48,6 +48,8 @@ public SingleShotDetectionLoss() { @Override protected Pair<NDList, NDList> inputForComponent( int componentIndex, NDList labels, NDList predictions) {
-
System.out.println(labels.singletonOrThrow()); //print labels for test
Remove testing code
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
@@ -0,0 +1,275 @@ +package ai.djl.training.loss;
Add license header
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
+import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Activation; +import ai.djl.util.Pair; +import ai.djl.util.PairList;
+import java.util.ArrayList; +import java.util.Arrays; + +public class YOLOv3Loss extends Loss{
Run `./gradlew formatJava` to fix the formatting.
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
boxBTrue.get(":,2:").expandDims(0).broadcast(A,B,2));
-
NDArray minXY = NDArrays.maximum(boxATrue.get(":,:2").expandDims(1).broadcast(A,B,2), -
boxBTrue.get(":,:2").expandDims(0).broadcast(A,B,2)); -
NDArray inter = NDArrays.minimum(maxXY.sub(minXY),0); -
inter = inter.get(":,:,0").mul(inter.get(":,:,1")); -
//to calculate the area of prediction bbox and true bbox -
NDArray areaA = boxATrue.get(":,2").sub(boxATrue.get(":,0")).mul(boxATrue.get(":,3").sub(boxATrue.get(":,1"))).expandDims( 1).broadcast(inter.getShape()), -
areaB = boxBTrue.get(":,2").sub(boxBTrue.get(":,0")).mul(boxBTrue.get(":,3").sub(boxBTrue.get(":,1"))).expandDims( 0).broadcast(inter.getShape()); -
NDArray union = areaA.add(areaB).sub(inter); -
return inter.div(union); - }
- public ArrayList<PairList<Long,Rectangle>> getTargetFromCurrentLabel(NDArray labels){
Use List instead of ArrayList in function declaration
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
float interHeight = Math.min(trueBottom,predBottom)-Math.max(trueTop,predTop);
-
float inter = interWidth*interHeight, union = wgt.get(i).getFloat()*inW*hgt.get(i).getFloat()*inH -
+ anchors.get(j,0).getFloat()*anchors.get(j,1).getFloat()-inter; -
iou.set(new NDIndex(curIndex),inter/union); -
} -
} -
return new NDList(iou,boxLossScale,groundTruth); -
} -
return null; - }
- //calculate IOU is already defined in Rectangle
- public NDArray calculateIOU(NDArray boxA,NDArray boxB){
Add Javadoc.
In examples/src/main/java/ai/djl/examples/training/TrainPikachu.java:
@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans try (Trainer trainer = model.newTrainer(config)) { trainer.setMetrics(new Metrics());
-
Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
-
Shape inputShape = new Shape(1, 3, 256, 256);
Why?
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>