djl

Object detection

Open · warthecatalyst opened this issue 3 years ago · 1 comment

This PR tracks my current progress on the object detection work. I believe I have successfully built the YOLOv3 network structure, and I have started working on the loss function.

warthecatalyst · Aug 19 '22

This is just a test PR so that my mentor, Zack, can review my current progress and make suggestions.

At 2022-08-29 13:50:35, "Frank Liu" @.***> wrote:

@frankfliu commented on this pull request.

In api/src/main/java/ai/djl/training/loss/SingleShotDetectionLoss.java:

```diff
@@ -48,6 +48,8 @@ public SingleShotDetectionLoss() {
     @Override
     protected Pair<NDList, NDList> inputForComponent(
             int componentIndex, NDList labels, NDList predictions) {
+    System.out.println(labels.singletonOrThrow());  //print labels for test
```

Remove testing code

In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:

```diff
@@ -0,0 +1,275 @@
+package ai.djl.training.loss;
```

Add license header

In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:

```diff
+import ai.djl.modality.cv.output.Rectangle;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrays;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Activation;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+public class YOLOv3Loss extends Loss{
```

Run `./gradlew formatJava`.

In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:

```diff
+            boxBTrue.get(":,2:").expandDims(0).broadcast(A,B,2));
+    NDArray minXY = NDArrays.maximum(boxATrue.get(":,:2").expandDims(1).broadcast(A,B,2),
+            boxBTrue.get(":,:2").expandDims(0).broadcast(A,B,2));
+    NDArray inter = NDArrays.minimum(maxXY.sub(minXY),0);
+    inter = inter.get(":,:,0").mul(inter.get(":,:,1"));
+    //to calculate the area of prediction bbox and true bbox
+    NDArray areaA = boxATrue.get(":,2").sub(boxATrue.get(":,0")).mul(boxATrue.get(":,3").sub(boxATrue.get(":,1"))).expandDims( 1).broadcast(inter.getShape()),
+            areaB = boxBTrue.get(":,2").sub(boxBTrue.get(":,0")).mul(boxBTrue.get(":,3").sub(boxBTrue.get(":,1"))).expandDims( 0).broadcast(inter.getShape());
+    NDArray union = areaA.add(areaB).sub(inter);
+    return inter.div(union);
+ }
+ public ArrayList<PairList<Long,Rectangle>> getTargetFromCurrentLabel(NDArray labels){
```

Use `List` instead of `ArrayList` in the method declaration.
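A minimal sketch of this suggestion, assuming callers only need to read the per-image targets: declare the return type as the `List` interface and keep `ArrayList` as an implementation detail. The class `TargetSketch`, the method `getTargets`, and the use of `String` in place of DJL's `PairList<Long, Rectangle>` are hypothetical simplifications so the block compiles without DJL.

```java
import java.util.ArrayList;
import java.util.List;

public class TargetSketch {

    /**
     * Groups flat labels by image index.
     *
     * <p>The declared return type is the List interface, so callers are not coupled to
     * ArrayList and the concrete implementation can change without breaking them.
     *
     * @param numImages number of images in the batch
     * @param labels one label per object (hypothetical stand-in for class/box pairs)
     * @param imageIds image index of each label, in the range [0, numImages)
     * @return one list of labels per image
     */
    public static List<List<String>> getTargets(int numImages, String[] labels, int[] imageIds) {
        // ArrayList is still used internally; only the signature changes.
        List<List<String>> targets = new ArrayList<>(numImages);
        for (int i = 0; i < numImages; i++) {
            targets.add(new ArrayList<>());
        }
        for (int k = 0; k < labels.length; k++) {
            targets.get(imageIds[k]).add(labels[k]);
        }
        return targets;
    }

    public static void main(String[] args) {
        List<List<String>> t =
                getTargets(2, new String[] {"cat", "dog", "car"}, new int[] {0, 1, 0});
        System.out.println(t);
    }
}
```

Programming against `List` also lets the method later return, for example, an unmodifiable view without changing any call sites.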

In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:

```diff
+                float interHeight = Math.min(trueBottom,predBottom)-Math.max(trueTop,predTop);
+                float inter = interWidth*interHeight, union = wgt.get(i).getFloat()*inW*hgt.get(i).getFloat()*inH
+                        + anchors.get(j,0).getFloat()*anchors.get(j,1).getFloat()-inter;
+                iou.set(new NDIndex(curIndex),inter/union);
+            }
+        }
+        return new NDList(iou,boxLossScale,groundTruth);
+    }
+    return null;
+ }
+ //calculate IOU is already defined in Rectangle
+ public NDArray calculateIOU(NDArray boxA,NDArray boxB){
```

Add Javadoc.
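As a sketch of what the requested Javadoc could look like, here is a plain-Java pairwise IoU over corner-format boxes `[xmin, ymin, xmax, ymax]`, plus a shape-only IoU of the kind commonly used to match a ground-truth box to an anchor (both boxes assumed to share a center). It uses `float[][]` instead of `NDArray` so it runs without DJL; the class `IoUSketch` and both method names are hypothetical. One detail worth checking against the quoted diff: the overlap extent should be clamped with a *maximum* against 0, because a negative width or height means the boxes do not intersect. `NDArrays.minimum(maxXY.sub(minXY),0)` appears to zero out every real overlap and keep the negative ones instead. As the comment in the diff notes, `Rectangle` already defines an IoU computation, which may be simpler for single boxes.

```java
public class IoUSketch {

    /**
     * Computes the pairwise intersection-over-union (IoU) between two sets of boxes.
     *
     * @param boxesA A boxes in corner format [xmin, ymin, xmax, ymax]
     * @param boxesB B boxes in corner format [xmin, ymin, xmax, ymax]
     * @return an A x B matrix where entry [i][j] is the IoU of boxesA[i] and boxesB[j]
     */
    public static float[][] calculateIou(float[][] boxesA, float[][] boxesB) {
        float[][] iou = new float[boxesA.length][boxesB.length];
        for (int i = 0; i < boxesA.length; i++) {
            float[] a = boxesA[i];
            float areaA = (a[2] - a[0]) * (a[3] - a[1]);
            for (int j = 0; j < boxesB.length; j++) {
                float[] b = boxesB[j];
                // Intersection: max of the top-left corners, min of the bottom-right corners.
                // Clamp at 0 so disjoint boxes contribute no overlap.
                float interW = Math.max(Math.min(a[2], b[2]) - Math.max(a[0], b[0]), 0f);
                float interH = Math.max(Math.min(a[3], b[3]) - Math.max(a[1], b[1]), 0f);
                float inter = interW * interH;
                float areaB = (b[2] - b[0]) * (b[3] - b[1]);
                float union = areaA + areaB - inter;
                // Guard against degenerate (zero-area) boxes.
                iou[i][j] = union > 0f ? inter / union : 0f;
            }
        }
        return iou;
    }

    /**
     * Computes the IoU of a ground-truth box and an anchor when both share the same center,
     * so only their widths and heights matter.
     *
     * @param w ground-truth width
     * @param h ground-truth height
     * @param anchorW anchor width
     * @param anchorH anchor height
     * @return the shape-only IoU of the two boxes
     */
    public static float anchorIou(float w, float h, float anchorW, float anchorH) {
        float inter = Math.min(w, anchorW) * Math.min(h, anchorH);
        return inter / (w * h + anchorW * anchorH - inter);
    }

    public static void main(String[] args) {
        float[][] a = {{0, 0, 2, 2}};
        float[][] b = {{1, 1, 3, 3}, {5, 5, 6, 6}};
        float[][] iou = calculateIou(a, b);
        System.out.println(iou[0][0] + " " + iou[0][1]);
        System.out.println(anchorIou(2, 4, 4, 2));
    }
}
```

In the example, the first pair overlaps in a 1 x 1 square out of a union of 7, giving an IoU of 1/7; the second pair is disjoint and gets 0 only because of the clamp.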

In examples/src/main/java/ai/djl/examples/training/TrainPikachu.java:

```diff
@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
         try (Trainer trainer = model.newTrainer(config)) {
             trainer.setMetrics(new Metrics());
-            Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
+            Shape inputShape = new Shape(1, 3, 256, 256);
```
    

Why?


warthecatalyst · Aug 29 '22