Paddle YoloV3 inference with DJL is very slow
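When the input image size is 416 × 416, detecting one image takes about 7 seconds. When I resize the input to 224 × 224, it takes about 2 seconds per image. Both are far slower than the YOLO model's performance in Python. Is there any way to improve inference performance?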
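As a back-of-the-envelope check on the two timings above (an inference from the reported approximate numbers, not a profiling result), the slowdown tracks the ratio of input pixel counts almost exactly:

```latex
\frac{416^2}{224^2} = \frac{173\,056}{50\,176} \approx 3.45,
\qquad
\frac{7\ \text{s}}{2\ \text{s}} = 3.5
```

A ratio this close suggests most of the time goes into per-pixel network computation rather than a fixed per-image overhead.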
PS: the inference results are correct.
My inference code is as follows:
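// Imports needed by the snippet (main() lives inside the application class)
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public static void main(String[] args) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
    // Load test images 0001.jpg ... 0029.jpg (file names are zero-padded to four digits)
    List<Image> images = new ArrayList<>();
    for (int i = 1; i < 30; i++) {
        String path = String.format("C:\\Users\\mj\\Desktop\\djldemo1\\src\\main\\resources\\test\\%04d.jpg", i);
        images.add(ImageFactory.getInstance().fromFile(Paths.get(path)));
    }
    Criteria<Image, DetectedObjects> criteria = Criteria.builder()
            .optEngine("PaddlePaddle")
            .setTypes(Image.class, DetectedObjects.class)
            .optModelUrls("file:/C:/Users/mj/Desktop/djldemo1/model/yolov3.zip")
            .optTranslator(new FaceTranslator(1F, 0.3F)) // custom translator (not shown in this post)
            .optProgress(new ProgressBar())
            .build();
    // try-with-resources releases the native model and predictor memory when done
    try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
         Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
        long begin = System.currentTimeMillis();
        List<Long> times = new ArrayList<>();
        times.add(begin);
        System.out.println(begin);
        for (Image img : images) {
            DetectedObjects inferenceResult = predictor.predict(img);
            long time = System.currentTimeMillis(); // timestamp after each prediction
            System.out.println(time);
            times.add(time);
        }
        long end = System.currentTimeMillis();
        System.out.println(end - begin); // total milliseconds for all images
        System.out.println(times);
    }
}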
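The loop above records a start timestamp followed by one timestamp per image, so the raw list it prints is hard to read. A minimal sketch with a hypothetical helper (the class and method names are illustrative, not part of DJL) turns such a list into per-image latencies and averages them while skipping warm-up images, since the first predictions are often slower:

```java
import java.util.ArrayList;
import java.util.List;

public class LatencyReport {
    // Turns [start, t1, t2, ...] into per-image durations [t1 - start, t2 - t1, ...].
    static List<Long> perImageMillis(List<Long> timestamps) {
        List<Long> durations = new ArrayList<>();
        for (int i = 1; i < timestamps.size(); i++) {
            durations.add(timestamps.get(i) - timestamps.get(i - 1));
        }
        return durations;
    }

    // Mean latency, skipping the first `warmup` images.
    static double meanMillis(List<Long> durations, int warmup) {
        return durations.stream()
                .skip(warmup)
                .mapToLong(Long::longValue)
                .average()
                .orElse(Double.NaN);
    }

    public static void main(String[] args) {
        // Hypothetical timestamps in milliseconds (not measured values)
        List<Long> times = List.of(1000L, 9000L, 16000L, 23000L);
        List<Long> durations = perImageMillis(times);
        System.out.println(durations);                 // [8000, 7000, 7000]
        System.out.println(meanMillis(durations, 1));  // 7000.0
    }
}
```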
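My test environment: IntelliJ IDEA + Windows 10 + CPU (AMD Ryzen 7 5800H).
Many thanks!
In the same environment, the face detection model provided in the DJL demo takes about 400 ms per image on the same 30 images.
The face detection model is about 7 MB compressed, while the YOLOv3 model I am using is about 230 MB compressed. Is this normal?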
@frankfliu
- Try the OpenCV extension; it improves image-processing speed: https://github.com/deepjavalibrary/djl/tree/master/extensions/opencv. Just add the package to your project; no code changes are required.
- Your code measures only the overall prediction time, so the slowness may not be related to image processing. You can use the following code to get the pre-processing and post-processing times:
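import ai.djl.metric.Metrics;

Metrics metrics = new Metrics();
predictor.setMetrics(metrics); // the predictor records per-stage timings into `metrics`
for (Image img : images) {
    predictor.predict(img);
}
// Median (p50) time of each stage across all predictions
System.out.println(metrics.percentile("Total", 50));
System.out.println(metrics.percentile("Preprocess", 50));
System.out.println(metrics.percentile("Inference", 50));
System.out.println(metrics.percentile("Postprocess", 50));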
- You can also use djl-bench to benchmark your model's performance: https://docs.djl.ai/master/docs/serving/benchmark/index.html
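metrics.percentile(name, 50) reports the median of the values recorded for that stage. The plain-Java sketch below uses a nearest-rank percentile written for illustration (DJL's own Metrics implementation may round differently) to show that computation and how a stage's share of the total can then be expressed as a percentage. The sample timings are made up:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class StageBreakdown {
    // Nearest-rank percentile: smallest sample with at least p% of samples <= it.
    static long percentile(List<Long> samples, double p) {
        List<Long> sorted = new ArrayList<>(samples);
        Collections.sort(sorted);
        int rank = (int) Math.ceil(p / 100.0 * sorted.size());
        return sorted.get(Math.max(rank, 1) - 1);
    }

    // A stage's median latency as a percentage of the median total latency.
    static double sharePercent(long stageMedian, long totalMedian) {
        return 100.0 * stageMedian / totalMedian;
    }

    public static void main(String[] args) {
        // Hypothetical per-prediction timings (not measured values)
        List<Long> inference = List.of(1900L, 2000L, 2100L, 1950L, 2050L);
        List<Long> total = List.of(2000L, 2100L, 2200L, 2050L, 2150L);
        long inferenceP50 = percentile(inference, 50);
        long totalP50 = percentile(total, 50);
        System.out.println(inferenceP50); // 2000
        System.out.println(totalP50);     // 2100
        System.out.printf("Inference share: %.1f%%%n", sharePercent(inferenceP50, totalP50));
    }
}
```

Because each stage's median is taken separately, the stage percentages need not add up to exactly 100%.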
Thank you! I have tried the benchmark script. It works well on the face detection model with this command: .\benchmark -c 80 -s 1,3,224,224 -u "https://aias-home.oss-cn-beijing.aliyuncs.com/models/face_mask/face_detection.zip" -e PaddlePaddle
However, my own model requires three inputs, [im_shape, image, scale_factor], so I changed the command as the tutorial describes:
.\benchmark -u "file:/C:/Users/mj/Desktop/djldemo1/model/fall.zip" -e PaddlePaddle -s (1,2),(1,3,224,224),(1,2) -c 80
Then the following error occurred:
Exception in thread "main" java.lang.NumberFormatException: For input string: "1 2"
at java.base/java.lang.NumberFormatException.forInputString(NumberFormatException.java:65)
at java.base/java.lang.Long.parseLong(Long.java:692)
at java.base/java.lang.Long.parseLong(Long.java:817)
at java.base/java.util.stream.ReferencePipeline$5$1.accept(ReferencePipeline.java:229)
at java.base/java.util.Spliterators$ArraySpliterator.forEachRemaining(Spliterators.java:948)
at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:484)
at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:474)
at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:550)
at java.base/java.util.stream.AbstractPipeline.evaluateToArrayNode(AbstractPipeline.java:260)
at java.base/java.util.stream.LongPipeline.toArray(LongPipeline.java:521)
at ai.djl.benchmark.NDListGenerator.parseShape(NDListGenerator.java:136)
at ai.djl.benchmark.Arguments.
How should I write the command correctly?
Also, I measured the pre-processing and post-processing times using the snippet you provided. About 95% of the time is spent on inference and 4% on pre-processing, so I wonder whether this is simply the real performance of YOLOv3 on CPU.
If I want to test on a GPU, which dependencies in pom.xml should I change, besides installing CUDA and cuDNN? Right now, running nvcc -V on my computer prints the CUDA compiler version, but DJL does not seem to detect the GPU. I suspect the installed CUDA version does not match the one DJL's engine build expects. How can I check that?
Thanks for your reply. @frankfliu
The command looks correct; however, the stack trace indicates that your shape input contains a space (something like (1 2)). Can you check whether you typed it incorrectly?
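The trace shows the failure inside Long.parseLong, called from NDListGenerator.parseShape. The sketch below is a hypothetical re-implementation (not DJL's actual code) of parsing a multi-input shape string: it accepts "(1,2),(1,3,224,224),(1,2)" but throws the same NumberFormatException once a group contains "1 2". A likely source of the space, assuming the command was run in PowerShell: PowerShell evaluates an unquoted (1,2) as an array and hands it to the program as the text 1 2, so quoting the whole value, for example -s "(1,2),(1,3,224,224),(1,2)", may avoid the error.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ShapeParser {
    // Matches one parenthesized group such as "(1,3,224,224)"
    private static final Pattern GROUP = Pattern.compile("\\(([^)]*)\\)");

    // Parses "(1,2),(1,3,224,224),(1,2)" into one long[] per model input.
    static List<long[]> parseShapes(String arg) {
        List<long[]> shapes = new ArrayList<>();
        Matcher m = GROUP.matcher(arg);
        while (m.find()) {
            long[] dims = Arrays.stream(m.group(1).split(","))
                    .mapToLong(Long::parseLong) // throws NumberFormatException on "1 2"
                    .toArray();
            shapes.add(dims);
        }
        return shapes;
    }

    public static void main(String[] args) {
        for (long[] shape : parseShapes("(1,2),(1,3,224,224),(1,2)")) {
            System.out.println(Arrays.toString(shape)); // [1, 2], then [1, 3, 224, 224], then [1, 2]
        }
        try {
            parseShapes("(1 2)");
        } catch (NumberFormatException e) {
            System.out.println(e.getMessage()); // For input string: "1 2"
        }
    }
}
```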
@JiMa98 Hi, when using the PaddlePaddle engine, did you ever run into the following error?
org.springframework.web.util.NestedServletException: Handler dispatch failed; nested exception is java.lang.UnsatisfiedLinkError: C:\Users\wangkedong\.djl.ai\paddle\2.0.2-cpu-win-x86_64\0.26.0-djl_paddle.dll: The specified procedure could not be found.
at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1082)
at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:963)
at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1006)
at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:909)
Below is my pom: