
Training Obj. Detection with custom datasets

Open androuino opened this issue 4 years ago • 16 comments

Question

Hi, I have successfully run object-detection training by following the TrainPikachu example, but I am getting this result in the console: [Screenshot: Screen Shot 2020-10-05 at 11 04 16]

As you can see, classAccuracy and boundingBoxError are empty; they have no value at all.

I then trained the model for 8 epochs and tried to test it, but I get no detections.

I have 3 classes to detect, and my input shape is: Shape inputShape = new Shape(arguments.getBatchSize(), 3, 800, 1144);

My index.file looks like this:

{"IMG_5237.jpg":[["0","0.085625","0.07473776223776224","0.043750000000000004","0.030594405594405596"],
["0","0.35125","0.28496503496503495","0.055","0.04020979020979021"],
["1","0.26875","0.3618881118881119","0.0575","0.033216783216783216"],
["0","0.7975","0.4602272727272727","0.0575","0.03583916083916084"],
["0","0.08750000000000001","0.6831293706293706","0.065","0.039335664335664336"],
["2","0.114375","0.5607517482517482","0.07875","0.028846153846153848"],
["2","0.115","0.40384615384615385","0.085","0.027972027972027972"],
["2","0.1125","0.3395979020979021","0.08","0.028846153846153848"],
["2","0.11375","0.23251748251748253","0.085","0.02972027972027972"],
["2","0.106875","0.04020979020979021","0.08125","0.026223776223776224"]], ...}

In other words, a single image has multiple annotated bounding boxes. In each entry, the first value is the class index and the remaining four values describe the bounding box.

To get each Record, I use the same approach as TrainPikachu:

@Override
protected Record get(NDManager manager, long index) throws IOException {
    int idx = Math.toIntExact(index);
    NDList d = new NDList(ImageFactory.getInstance()
            .fromFile(imagePaths.get(idx))
            .toNDArray(manager, flag));
    NDArray label = manager.create(labels.get(idx));
    NDList l = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
    return new Record(d, l);
}

And this is how I prepare the dataset, closely following the PikachuDetection class:

try (Reader reader = Files.newBufferedReader(indexFile)) {
    Type mapType = new TypeToken<Map<String, List<String[]>>>() {}.getType();
    Map<String, List<String[]>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
    for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
        String imgName = entry.getKey();
        for (String[] item : entry.getValue()) {
            float[] labelArray = new float[5];
            // Class label
            labelArray[0] = Float.parseFloat(item[0]);

            // Bounding box labels
            labelArray[1] = Float.parseFloat(item[1]);
            labelArray[2] = Float.parseFloat(item[2]);
            labelArray[3] = Float.parseFloat(item[3]);
            labelArray[4] = Float.parseFloat(item[4]);
            labels.add(labelArray);
        }
        imagePaths.add(usagePath.resolve(imgName));
    }
}

Please let me know if I am doing something significantly different from the TrainPikachu example that would explain why my trained model doesn't detect any objects. Thank you in advance.

androuino avatar Oct 05 '20 02:10 androuino
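Editor's note on the sample annotations: DJL's MultiBoxTarget calls MXNet's multiBoxTarget operator (see the stack traces later in this thread), and MXNet expects each label row as [class, xmin, ymin, xmax, ymax] in normalized coordinates. In several sample rows the third value is smaller than the first (for example 0.7975 followed by 0.0575), which cannot be an xmin/xmax pair and instead looks like center-x, center-y, width, height (YOLO-style). Assuming that is the actual format, a minimal sketch of converting one row follows; the class and method names are hypothetical and not part of DJL.

```java
import java.util.Arrays;

/** Hypothetical helper: converts (class, cx, cy, w, h) rows to (class, xmin, ymin, xmax, ymax). */
final class YoloToCorner {

    private YoloToCorner() {}

    /**
     * Assumes item = {classIndex, centerX, centerY, width, height}, normalized to [0, 1].
     * Returns {classIndex, xmin, ymin, xmax, ymax}, still normalized.
     */
    static float[] toCornerLabel(String[] item) {
        float cls = Float.parseFloat(item[0]);
        float cx = Float.parseFloat(item[1]);
        float cy = Float.parseFloat(item[2]);
        float w = Float.parseFloat(item[3]);
        float h = Float.parseFloat(item[4]);
        return new float[] {
            cls,
            clamp(cx - w / 2f), // xmin
            clamp(cy - h / 2f), // ymin
            clamp(cx + w / 2f), // xmax
            clamp(cy + h / 2f)  // ymax
        };
    }

    // Keeps a coordinate inside the normalized image area [0, 1].
    private static float clamp(float v) {
        return Math.max(0f, Math.min(1f, v));
    }

    public static void main(String[] args) {
        String[] row = {"0", "0.7975", "0.4602272727272727", "0.0575", "0.03583916083916084"};
        System.out.println(Arrays.toString(toCornerLabel(row)));
    }
}
```

If the index.file really is in this format, calling toCornerLabel(item) in the preparation loop would replace the five Float.parseFloat lines.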

@androuino If I understand correctly, you have multiple bounding boxes in a single image:

  • Each image is mapped to an NDArray as data. Since you have only one input, the data NDList contains only that single image.
  • Each bounding box and its category are mapped to a double[5]; this is what you are currently doing.
  • Since one image can have multiple bounding boxes, the label of the image should be double[N][5]. In the Pikachu example there is always a single bounding box, so it uses the following code to convert the label to double[1][5]:
NDList l = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
  • The number of data entries must match the number of labels so that they can be combined into Records.

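In your code, the number of labels equals the total number of bounding boxes (number of images × boxes per image), not the number of images, so images and labels get misaligned. Here is what your dataset currently contains:
1st record -> 1st image : 1st bounding box of the 1st image
2nd record -> 2nd image : 2nd bounding box of the 1st image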

Here is the code in CocoDetection dataset which is similar to what you need: https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156

frankfliu avatar Oct 07 '20 05:10 frankfliu
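Editor's sketch of the one-record-per-image layout described above, in plain Java. It assumes the metadata map has already been parsed (for example with Gson, as in the question) and keeps image names and labels in two parallel lists, adding to both in the same step so index i always refers to the same image. The class name is hypothetical. In get(), DJL's NDManager.create(float[][]) would then produce a float32 label of shape (N, 5), so the Pikachu-style reshape to (1, 5) should no longer be needed.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: one record per image, all of its boxes in one float[N][5] label. */
final class PerImageLabels {

    final List<String> imageNames = new ArrayList<>();
    final List<float[][]> labels = new ArrayList<>();

    /** metadata: image file name -> rows of {class, v1, v2, v3, v4} as strings. */
    void load(Map<String, List<String[]>> metadata) {
        for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
            List<String[]> rows = entry.getValue();
            if (rows.isEmpty()) {
                continue; // skip images without annotations
            }
            float[][] label = new float[rows.size()][5];
            for (int i = 0; i < rows.size(); i++) {
                String[] item = rows.get(i); // assumes at least 5 values per row
                for (int j = 0; j < 5; j++) {
                    label[i][j] = Float.parseFloat(item[j]); // convert coordinates here if needed
                }
            }
            // Added together so both lists stay aligned: one entry per image.
            imageNames.add(entry.getKey());
            labels.add(label);
        }
    }
}
```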

Thanks @frankfliu for the response. However, running the CocoTest gives me this error:

java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:226)
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
	at com.google.gson.internal.bind.TypeAdapterRuntimeTypeWrapper.read(TypeAdapterRuntimeTypeWrapper.java:41)
	at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:82)
	at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:61)
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
	at com.google.gson.Gson.fromJson(Gson.java:932)
	at com.google.gson.Gson.fromJson(Gson.java:870)
	at ai.djl.basicdataset.CocoUtils.prepare(CocoUtils.java:56)
	at ai.djl.basicdataset.CocoDetection.prepare(CocoDetection.java:109)
	at ai.djl.training.dataset.Dataset.prepare(Dataset.java:40)
	at ai.djl.training.dataset.RandomAccessDataset.getData(RandomAccessDataset.java:83)
	at ai.djl.training.Trainer.iterateDataset(Trainer.java:130)
	at ai.djl.basicdataset.CocoTest.testCocoRemote(CocoTest.java:40)
	... (TestNG/Gradle test-runner frames omitted)
Caused by: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
	at com.google.gson.stream.JsonReader.beginObject(JsonReader.java:386)
	at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:215)
	... 68 more

I did not change any code in the CocoDetection class; I only downloaded the COCO dataset to run the test.

androuino avatar Oct 07 '20 07:10 androuino

Here is the code in CocoDetection dataset which is similar to what you need: https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156

I checked the CocoDetection class and modified my code accordingly, so it now resembles CocoDetection:

usagePath = root.resolve(usagePath);
Path indexFile = usagePath.resolve("index.file");
try (Reader reader = Files.newBufferedReader(indexFile)) {
    Type mapType = new TypeToken<Map<String, List<String[]>>>() {}.getType();
    Map<String, List<String[]>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
    for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
        List<double[]> labelOfImage = getLabels(entry.getValue());
        if (!labelOfImage.isEmpty()) {
            imagePaths.add(usagePath.resolve(entry.getKey()));
            labels.add(labelOfImage.toArray(new double[0][]));
        }
    }
}

Then the helper functions, modeled on the CocoDetection class:

private double[] convertRecToList(String[] anno) {
    double[] list = new double[5];
    list[1] = Double.parseDouble(anno[1]);
    list[2] = Double.parseDouble(anno[2]);
    list[3] = Double.parseDouble(anno[3]);
    list[4] = Double.parseDouble(anno[4]);
    return list;
}

private List<double[]> getLabels(List<String[]> arr) {
    List<double[]> label = new ArrayList<>();
    for (String[] item : arr) {
        double[] list = convertRecToList(item);
        // add the category label
        // map the original one to incremental index
        list[0] = Double.parseDouble(item[0]);
        label.add(list);
    }
    return label;
}

But I am getting this error:

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

	at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
	at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
	at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
	at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
	at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
	at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
	at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
	at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
	at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
	at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
	at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
	at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
	at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
	... (TestNG/Gradle test-runner frames omitted)

Sorry if I misunderstood what you were suggesting in your last comment, but please help me get past this error. Thank you in advance.

androuino avatar Oct 07 '20 08:10 androuino

@androuino I confirmed that CocoTest is failing; we will take a look and fix the CocoDetection dataset bug.

Training SSD with COCO is not as straightforward as with Pikachu; we will try to create an example.

frankfliu avatar Oct 08 '20 16:10 frankfliu

Thanks for the response @frankfliu. I confirmed that CocoDetection has been fixed. However, have you had a chance to look at the error I am getting?

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

So, as of now, is training on a custom dataset with multiple bounding boxes and classes annotated in one image simply not supported yet when following the Pikachu example?

androuino avatar Oct 09 '20 01:10 androuino
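Editor's note on multiple boxes per image: with one float[N][5] label per image, N differs from image to image, while DJL's default stack batchifier can only stack arrays of equal shape into a batch. A common workaround in SSD pipelines (an assumption here, not something this thread or DJL prescribes) is to pad every image's label to the same number of rows, filling the extra rows with -1; MXNet's MultiBoxTarget is expected to treat a class id of -1 as padding rather than as a real box. The helper name and the maxBoxes parameter are hypothetical; maxBoxes must be at least the largest box count in the dataset (the sample IMG_5237.jpg entry alone has 10 boxes).

```java
import java.util.Arrays;

/** Hypothetical helper: pads an image's (N, 5) label to (maxBoxes, 5) so labels can be stacked. */
final class LabelPadding {

    private LabelPadding() {}

    static float[][] pad(float[][] label, int maxBoxes) {
        if (label.length > maxBoxes) {
            throw new IllegalArgumentException(
                    "image has " + label.length + " boxes, more than maxBoxes=" + maxBoxes);
        }
        float[][] padded = new float[maxBoxes][5];
        for (int i = 0; i < maxBoxes; i++) {
            if (i < label.length) {
                padded[i] = Arrays.copyOf(label[i], 5); // real box, copied as-is
            } else {
                Arrays.fill(padded[i], -1f); // padding row: class -1 means "no object"
            }
        }
        return padded;
    }
}
```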

@androuino It looks like you are passing data to an operator that doesn't support the float64 data type. Do you know where you use float64?

stu1130 avatar Oct 09 '20 02:10 stu1130
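Editor's note on why the error points to float64: the failing check requires each input's type flag to equal mshadow::default_type_flag (float32) or -1 (type not yet known), and the number in "Unsupported data type 1" is an mshadow type flag. The lookup below lists the first values of mshadow's TypeFlag enum; the class itself is only an illustration, not part of DJL or MXNet.

```java
/** Illustrative lookup of mshadow type flags, as they appear in MXNet error messages. */
final class MxTypeFlags {

    private MxTypeFlags() {}

    static String name(int flag) {
        switch (flag) {
            case 0: return "float32"; // mshadow::default_type_flag
            case 1: return "float64"; // the flag reported in this thread's error
            case 2: return "float16";
            case 3: return "uint8";
            case 4: return "int32";
            case 5: return "int8";
            case 6: return "int64";
            default: return "unknown(" + flag + ")";
        }
    }
}
```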

Hi @stu1130, this is the class I wrote, which is essentially a combination of the PikachuDetection and CocoDetection classes: https://gist.github.com/androuino/00095ba5be3d10cab765bd2447d236cf And this is my TrainCustomModel class: https://gist.github.com/androuino/6e8e014e3e70b107f4ebb9cf6c9387ac I don't see anywhere that I am using float64. My dataset structure is very similar to the Pikachu dataset, and so are the annotation values:

{"IMG_5237.jpg":[["0","0.085625","0.07473776223776224","0.043750000000000004","0.030594405594405596"],
["0","0.35125","0.28496503496503495","0.055","0.04020979020979021"],
["1","0.26875","0.3618881118881119","0.0575","0.033216783216783216"],
["0","0.7975","0.4602272727272727","0.0575","0.03583916083916084"],
["0","0.08750000000000001","0.6831293706293706","0.065","0.039335664335664336"],
["2","0.114375","0.5607517482517482","0.07875","0.028846153846153848"],
["2","0.115","0.40384615384615385","0.085","0.027972027972027972"],
["2","0.1125","0.3395979020979021","0.08","0.028846153846153848"],
["2","0.11375","0.23251748251748253","0.085","0.02972027972027972"],
["2","0.106875","0.04020979020979021","0.08125","0.026223776223776224"]], ...}

androuino avatar Oct 09 '20 02:10 androuino

@androuino Do you have the complete error stack trace? The label of the Pikachu dataset is float, but the label in CocoDetection is double. I suspect this is where float64 comes in; could you cast the label to float32 with the toType method?

stu1130 avatar Oct 09 '20 03:10 stu1130
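Editor's sketch: instead of casting the NDArray afterwards (in DJL, label.toType(DataType.FLOAT32, false)), the double[][] labels built by the CocoDetection-style code can be narrowed to float[][] before the NDArray is created, because NDManager.create(float[][]) produces a float32 array. The helper name is hypothetical.

```java
/** Hypothetical helper: narrows a double[N][5] label to float[N][5] before creating the NDArray. */
final class LabelPrecision {

    private LabelPrecision() {}

    static float[][] toFloat(double[][] label) {
        float[][] out = new float[label.length][];
        for (int i = 0; i < label.length; i++) {
            out[i] = new float[label[i].length];
            for (int j = 0; j < label[i].length; j++) {
                // Narrowing cast; precision loss is negligible for normalized coordinates and class ids.
                out[i][j] = (float) label[i][j];
            }
        }
        return out;
    }
}
```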

@stu1130 So far this is the only stack trace that I could get:

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

	at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
	at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
	at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
	at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
	at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
	at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
	at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
	at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
	at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
	at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
	at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
	at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
	at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
	... (TestNG/Gradle test-runner frames omitted)

I will try to cast the label to float32 and get back to you. Thanks.

androuino avatar Oct 09 '20 03:10 androuino

@stu1130 I tried your suggestion, and now I get a different error. This is what the functions look like after I changed the labels to float:

private float[] convertRecToFloatList(String[] anno) {
    float[] list = new float[5];
    list[1] = Float.parseFloat(anno[1]);
    list[2] = Float.parseFloat(anno[2]);
    list[3] = Float.parseFloat(anno[3]);
    list[4] = Float.parseFloat(anno[4]);
    return list;
}

private List<float[]> getLabelsAsFloat(List<String[]> arr) {
    List<float[]> label = new ArrayList<>();
    for (String[] item : arr) {
        float[] box = convertRecToFloatList(item);
        logger.info(Arrays.toString(box));
        // add the category label
        // map the original one to incremental index
        float[] list = new float[5];
        System.arraycopy(box, 1, list, 1, 4);
        list[0] = Float.parseFloat(item[0]);
        label.add(list);
    }
    return label;
}

Stack trace:

MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
  File "../include/mxnet/./tensor_blob.h", line 311

ai.djl.engine.EngineException: MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
  File "../include/mxnet/./tensor_blob.h", line 311

	at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
	at ai.djl.mxnet.jna.JnaUtils.syncCopyToCPU(JnaUtils.java:475)
	at ai.djl.mxnet.engine.MxNDArray.toByteBuffer(MxNDArray.java:280)
	at ai.djl.ndarray.NDArray.toLongArray(NDArray.java:300)
	at ai.djl.ndarray.NDArray.getLong(NDArray.java:558)
	at ai.djl.training.evaluator.AbstractAccuracy.lambda$updateAccumulator$1(AbstractAccuracy.java:85)
	at java.base/java.util.concurrent.ConcurrentHashMap.compute(ConcurrentHashMap.java:1932)
	at ai.djl.training.evaluator.AbstractAccuracy.updateAccumulator(AbstractAccuracy.java:85)
	at ai.djl.training.listener.EvaluatorTrainingListener.updateEvaluators(EvaluatorTrainingListener.java:147)
	at ai.djl.training.listener.EvaluatorTrainingListener.onTrainingBatch(EvaluatorTrainingListener.java:114)
	at ai.djl.training.EasyTrain.lambda$trainBatch$1(EasyTrain.java:92)
	at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
	at ai.djl.training.Trainer.notifyListeners(Trainer.java:263)
	at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:92)
	at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
	at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:81)
	at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
	at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
	at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
	at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
	at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
	at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
	at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
	at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
	at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
	at org.testng.TestRunner.privateRun(TestRunner.java:766)
	at org.testng.TestRunner.run(TestRunner.java:587)
	at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
	at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
	at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
	at org.testng.SuiteRunner.run(SuiteRunner.java:286)
	at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
	at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
	at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
	at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
	at org.testng.TestNG.runSuites(TestNG.java:1039)
	at org.testng.TestNG.run(TestNG.java:1007)
	at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
	at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
	at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
	at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
	at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
	at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
	at com.sun.proxy.$Proxy5.stop(Unknown Source)
	at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
	at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
	at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
	at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
	at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
	at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
	at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
	at java.base/java.lang.Thread.run(Thread.java:834)

androuino, Oct 09 '20 04:10

@stu1130 I have an update on the training attempt. I changed the arguments like this:

args = new String[] {"-e", "8", "-m", "1", "-b", "1"};

The training runs; however, the boundingBoxError for both training and validation does not seem to improve. Here is the output for the full 8-epoch run:

[INFO ] - Load MXNet Engine Version 1.7.0 in 0.227 ms.
[INFO ] - Epoch 1 finished.
[INFO ] - Train: classAccuracy: 0.33, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 1.68
[INFO ] - Validate: classAccuracy: 0.78, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Epoch 2 finished.
[INFO ] - Train: classAccuracy: 0.61, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Validate: classAccuracy: 0.87, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.81
[INFO ] - Epoch 3 finished.
[INFO ] - Train: classAccuracy: 0.86, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.50
[INFO ] - Validate: classAccuracy: 0.93, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.77
[INFO ] - Epoch 4 finished.
[INFO ] - Train: classAccuracy: 0.94, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.29
[INFO ] - Validate: classAccuracy: 0.96, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.72
[INFO ] - Epoch 5 finished.
[INFO ] - Train: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.17
[INFO ] - Validate: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.66
[INFO ] - Epoch 6 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.09
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.58
[INFO ] - Epoch 7 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.06
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.49
[INFO ] - Epoch 8 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.03
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.38
[INFO ] - forward P50: 849.174 ms, P90: 878.440 ms
[INFO ] - training-metrics P50: 0.005 ms, P90: 0.034 ms
[INFO ] - backward P50: 6.235 ms, P90: 11.146 ms
[INFO ] - step P50: 25.011 ms, P90: 52.742 ms
[INFO ] - epoch P50: 6.627 s, P90: 8.692 s

Should I be concerned about the boundingBoxError? Is this normal, or is something wrong? Thanks for the help.

Update: I tested the trained model, but it does not detect any objects.

androuino, Oct 09 '20 05:10

It looks like it's working now with these argument settings: args = new String[] {"-e", "8", "-b", "1"};

androuino, Oct 09 '20 05:10

"-m 1" a.k.a "max-batches" means we train the model with only 1 max-batches for each epoch, which is usually for sanity test.

stu1130, Oct 09 '20 07:10
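To make the effect of the max-batches cap concrete: with "-b 1 -m 1", each epoch only sees a single image, so the metrics in the log above describe one sample per epoch. The plain-Java sketch below is not DJL code; the `samplesPerEpoch` helper and the convention that `maxBatches <= 0` means "no cap" are assumptions made for illustration only.

```java
// Hypothetical sketch (not a DJL API): how a per-epoch batch cap limits the data seen.
public class MaxBatchesSketch {

    /**
     * Number of samples used in one epoch when iteration stops after maxBatches batches.
     * Assumption of this sketch: maxBatches <= 0 means "no cap".
     */
    public static long samplesPerEpoch(long datasetSize, long batchSize, long maxBatches) {
        long totalBatches = (datasetSize + batchSize - 1) / batchSize; // ceiling division
        long batches = maxBatches > 0 ? Math.min(totalBatches, maxBatches) : totalBatches;
        // The last batch may be partial, so never count more samples than exist.
        return Math.min(datasetSize, batches * batchSize);
    }

    public static void main(String[] args) {
        long datasetSize = 100; // hypothetical number of training images
        // Like "-b 1 -m 1": one image per epoch, only useful as a sanity check.
        System.out.println(samplesPerEpoch(datasetSize, 1, 1)); // prints 1
        // Like "-b 1" without "-m": every image is visited once per epoch.
        System.out.println(samplesPerEpoch(datasetSize, 1, 0)); // prints 100
    }
}
```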

I see, thanks for the info @stu1130. However, how can I determine when it is a good time to stop training? Which values should I watch during training? Thanks.

androuino, Oct 09 '20 07:10
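One common heuristic for this question, not one given in this thread, is early stopping: watch the validation loss (here the validation SingleShotDetectionLoss) and stop once it has failed to improve for a few consecutive epochs. The sketch below is plain Java, not a DJL API; `stopEpoch`, `patience`, and `minDelta` are hypothetical names. The first input is copied from the "-m 1" sanity run above, so it is only illustrative; the plateau series is made up.

```java
// Hypothetical patience-based early-stopping check on per-epoch validation loss.
public class EarlyStoppingSketch {

    /**
     * Returns the 1-based epoch at which training would stop, or -1 if the loss never
     * went `patience` consecutive epochs without improving by more than minDelta.
     */
    public static int stopEpoch(double[] valLoss, int patience, double minDelta) {
        double best = Double.POSITIVE_INFINITY;
        int epochsWithoutImprovement = 0;
        for (int i = 0; i < valLoss.length; i++) {
            if (valLoss[i] < best - minDelta) {
                best = valLoss[i];            // new best loss: reset the counter
                epochsWithoutImprovement = 0;
            } else if (++epochsWithoutImprovement >= patience) {
                return i + 1;                 // patience exhausted at this epoch
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // Validation SingleShotDetectionLoss per epoch from the "-m 1" log above.
        double[] logged = {0.90, 0.81, 0.77, 0.72, 0.66, 0.58, 0.49, 0.38};
        System.out.println(stopEpoch(logged, 2, 0.0)); // prints -1 (still improving)

        // Made-up plateau, for illustration only.
        double[] plateau = {0.90, 0.70, 0.65, 0.66, 0.67};
        System.out.println(stopEpoch(plateau, 2, 0.0)); // prints 5
    }
}
```

Note that in the "-m 1" run the classAccuracy of 1.00 comes from a single batch per epoch, so the validation loss over the full validation set is the more informative signal.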

@stu1130, could you also explain the values in Pikachu's index.file?

"img_0.jpg": [4.0, 5.0, 512.0, 512.0, 0.0, 0.604744553565979, 0.40195202827453613, 0.6948338747024536, 0.5354305505752563],

I suppose that 512.0 is the image's height and width. What about 4.0 and 5.0? What do these values mean? Thanks.

androuino, Oct 09 '20 13:10
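The Pikachu label array looks like the MXNet/GluonCV .lst detection layout: `A B [extra header] [object0] [object1] ...`, where A is the header width (2 plus the extra-header length, so 4 here), B is the width of each object label (5 here: a class id plus four box values), and the extra header holds helper info such as the image size (the two 512.0 values). That DJL's PikachuDetection reads the file exactly this way is an assumption, not something confirmed in this thread. The hypothetical `DetLabelSketch` parser below splits one such array.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical parser, assuming the MXNet/GluonCV .lst detection label layout:
// [headerWidth, objectWidth, extraHeader..., object0..., object1..., ...]
public class DetLabelSketch {

    /** One object: a class id followed by the remaining objectWidth - 1 values (the box). */
    public static final class DetObject {
        public final float classId;
        public final float[] box;

        DetObject(float classId, float[] box) {
            this.classId = classId;
            this.box = box;
        }
    }

    /** Values between the two width fields and the first object (e.g. image width, height). */
    public static float[] extraHeader(float[] label) {
        int headerWidth = (int) label[0];
        return Arrays.copyOfRange(label, 2, headerWidth);
    }

    /** Splits the objects that follow the header into class id + box values. */
    public static List<DetObject> objects(float[] label) {
        int headerWidth = (int) label[0]; // 4.0 -> objects start at index 4
        int objectWidth = (int) label[1]; // 5.0 -> class id + 4 box values
        List<DetObject> result = new ArrayList<>();
        for (int start = headerWidth; start + objectWidth <= label.length; start += objectWidth) {
            float[] box = Arrays.copyOfRange(label, start + 1, start + objectWidth);
            result.add(new DetObject(label[start], box));
        }
        return result;
    }

    public static void main(String[] args) {
        // The "img_0.jpg" entry quoted above.
        float[] label = {4.0f, 5.0f, 512.0f, 512.0f, 0.0f,
            0.604744553565979f, 0.40195202827453613f, 0.6948338747024536f, 0.5354305505752563f};
        System.out.println(Arrays.toString(extraHeader(label))); // prints [512.0, 512.0]
        List<DetObject> objs = objects(label);
        System.out.println(objs.size() + " object(s), class " + objs.get(0).classId); // prints 1 object(s), class 0.0
    }
}
```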

@androuino The Pikachu example is a simplified version; it doesn't require image augmentations since it only has one class. However, training a proper SSD model requires many image augmentations when loading each Record from the dataset. You can find the Python code here: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/data/transforms/presets/ssd.py#L98

The full Python training code can be found here: https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/ssd/train_ssd.py

frankfliu, Oct 09 '20 21:10
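One of the augmentations in the GluonCV SSD training preset linked above is a random horizontal flip. Whenever the image is flipped, its boxes must be flipped too, or the labels no longer match the pixels. The sketch below shows only the label side, in plain Java rather than any DJL API. It assumes each box is `{classId, xmin, ymin, xmax, ymax}` with coordinates normalized to [0, 1]; labels in another layout (for example a center/width/height format) need different index arithmetic. `FlipBoxesSketch` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical label-side half of a random horizontal flip.
// Assumes each box is {classId, xmin, ymin, xmax, ymax}, normalized to [0, 1].
public class FlipBoxesSketch {

    /** Returns flipped copies: x' = 1 - x, so the new xmin comes from the old xmax and vice versa. */
    public static float[][] flipHorizontal(float[][] boxes) {
        float[][] out = new float[boxes.length][];
        for (int i = 0; i < boxes.length; i++) {
            float[] b = boxes[i];
            // Class id and y coordinates are unchanged by a horizontal flip.
            out[i] = new float[] {b[0], 1f - b[3], b[2], 1f - b[1], b[4]};
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] boxes = {{0f, 0.1f, 0.4f, 0.3f, 0.6f}};
        Random rng = new Random();
        // Flip with probability 0.5; the same decision must be applied to the image pixels.
        boolean flip = rng.nextBoolean();
        float[][] labels = flip ? flipHorizontal(boxes) : boxes;
        System.out.println("flipped=" + flip + " " + Arrays.toString(labels[0]));
    }
}
```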

Closing due to inactivity. Please reopen if you still require help.

siddvenk, Nov 09 '22 20:11