yolov5

How to preprocess images in C++

Open fatejzz opened this issue 2 years ago • 8 comments

Search before asking

  • [X] I have searched the YOLOv5 issues and discussions and found no similar questions.

Question

At present, I am trying to deploy my YOLO model in ROS, using C++ to run the predictions. However, it doesn't work well. Given the same image, detect.py makes correct predictions, but my own code predicts wrong boxes. I think the cause is the image preprocessing: I imitated the letterbox in detect.py, but then I run into the following error.

File "code/torch/models/yolo.py", line 71, in forward
    _35 = (_20).forward(_34, )
    _36 = (_22).forward((_21).forward(_35, ), _29, )
    _37 = (_24).forward(_33, _35, (_23).forward(_36, ), )
          ~~~~~~~~~~~~ <--- HERE
    return (_37,)
class Detect(Module):
File "code/torch/models/yolo.py", line 101, in forward
    _19 = torch.split_with_sizes(torch.sigmoid(_18), [2, 2, 6], 4)
    xy, wh, conf, = _19
    _20 = torch.add(torch.mul(xy, CONSTANTS.c0), CONSTANTS.c1)
          ~~~~~~~~~ <--- HERE
    xy0 = torch.mul(_20, torch.select(CONSTANTS.c2, 0, 0))
    _21 = torch.pow(torch.mul(wh, CONSTANTS.c0), 2)

Traceback of TorchScript, original code (most recent call last):
/home/jzz/yolov5/models/yolo.py(71): forward
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/nn/modules/module.py(1098): _slow_forward
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/nn/modules/module.py(1110): _call_impl
/home/jzz/yolov5/models/yolo.py(158): _forward_once
/home/jzz/yolov5/models/yolo.py(135): forward
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/nn/modules/module.py(1098): _slow_forward
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/nn/modules/module.py(1110): _call_impl
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/jit/_trace.py(965): trace_module
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/jit/_trace.py(750): trace
export.py(98): export_torchscript
export.py(520): run
/data/jzz/envs/jzz/lib/python3.7/site-packages/torch/autograd/grad_mode.py(27): decorate_context
export.py(602): main
export.py(607):
RuntimeError: The size of tensor a (48) must match the size of tensor b (80) at non-singleton dimension 2

cv::Mat resized_frame = letterbox(frame);
cvtColor(resized_frame, resized_frame, CV_BGR2RGB);
torch::Tensor in_tensor = torch::from_blob(resized_frame.data, {resized_frame.rows, resized_frame.cols, 3}, torch::kByte);

But if the third line is written as

torch::Tensor in_tensor = torch::from_blob(resized_frame.data, {640, 640, 3}, torch::kByte);

it runs, but there are still problems. On the one hand, it predicts wrong boxes that do not occur with Python detect.py. On the other hand, it sometimes predicts different boxes for the same image.

Additional

Here is my letterbox code:

cv::Mat letterbox(const cv::Mat& src) {
    int in_w = src.cols;
    int in_h = src.rows;
    int tar_w = kIMAGE_W_;
    int tar_h = kIMAGE_H_;
    float r = min(float(tar_h) / in_h, float(tar_w) / in_w);
    r = min(r, float(1));
    int inside_w = round(in_w * r);
    int inside_h = round(in_h * r);
    int padd_w = tar_w - inside_w;
    int padd_h = tar_h - inside_h;
    padd_w = padd_w % 64;
    padd_h = padd_h % 64;
    cv::Mat resize_img;
    cv::resize(src, resize_img, cv::Size(inside_w, inside_h));
    // cvtColor(resize_img, resize_img, COLOR_BGR2RGB);

    padd_w = padd_w / 2;
    padd_h = padd_h / 2;
    // std::cout << "padd_w" << padd_w << "padd_h" << padd_h << std::endl;
    int top = int(round(padd_h - 0.1));
    int bottom = int(round(padd_h + 0.1));
    int left = int(round(padd_w - 0.1));
    int right = int(round(padd_w + 0.1));
    cv::copyMakeBorder(resize_img, resize_img, top, bottom, left, right, cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));
    return resize_img;
}

fatejzz Jul 01 '22 08:07

👋 Hello @fatejzz, thank you for your interest in YOLOv5 🚀! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue; otherwise we cannot help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://ultralytics.com.

Requirements

Python>=3.7.0 with all requirements.txt installed including PyTorch>=1.7. To get started:

git clone https://github.com/ultralytics/yolov5  # clone
cd yolov5
pip install -r requirements.txt  # install


github-actions[bot] Jul 01 '22 08:07

@fatejzz I'm sorry, we don't have resources to review custom code, but we have a few YOLOv5 C++ Inference examples on ONNX and OpenVINO exported models here:

C++ Inference

YOLOv5 OpenCV DNN C++ inference on exported ONNX model examples:

  • https://github.com/Hexmagic/ONNX-yolov5/blob/master/src/test.cpp
  • https://github.com/doleron/yolov5-opencv-cpp-python

YOLOv5 OpenVINO C++ inference examples:

  • https://github.com/dacquaviva/yolov5-openvino-cpp-python

See Export tutorial for details:

YOLOv5 Tutorials

Good luck 🍀 and let us know if you have any other questions!

glenn-jocher Jul 01 '22 11:07

Hi @fatejzz, you can also refer to the following implementation, which we rewrote with OpenCV's C++ API.

https://github.com/zhiqwang/yolov5-rt-stack/blob/293b378fa2c7d1bc76fac75309b62a951680ac35/deployment/tensorrt/main.cpp#L80-L123

float letterbox(
    const cv::Mat& image,
    cv::Mat& out_image,
    const cv::Size& new_shape = cv::Size(640, 640),
    int stride = 32,
    const cv::Scalar& color = cv::Scalar(114, 114, 114),
    bool fixed_shape = false,
    bool scale_up = true) {
  cv::Size shape = image.size();
  float r = std::min(
      (float)new_shape.height / (float)shape.height, (float)new_shape.width / (float)shape.width);
  if (!scale_up) {
    r = std::min(r, 1.0f);
  }

  int newUnpad[2]{
      (int)std::round((float)shape.width * r), (int)std::round((float)shape.height * r)};

  cv::Mat tmp;
  if (shape.width != newUnpad[0] || shape.height != newUnpad[1]) {
    cv::resize(image, tmp, cv::Size(newUnpad[0], newUnpad[1]));
  } else {
    tmp = image.clone();
  }

  float dw = new_shape.width - newUnpad[0];
  float dh = new_shape.height - newUnpad[1];

  if (!fixed_shape) {
    dw = (float)((int)dw % stride);
    dh = (float)((int)dh % stride);
  }

  dw /= 2.0f;
  dh /= 2.0f;

  int top = int(std::round(dh - 0.1f));
  int bottom = int(std::round(dh + 0.1f));
  int left = int(std::round(dw - 0.1f));
  int right = int(std::round(dw + 0.1f));
  cv::copyMakeBorder(tmp, out_image, top, bottom, left, right, cv::BORDER_CONSTANT, color);

  return 1.0f / r;
}

zhiqwang Jul 01 '22 15:07
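
The padding arithmetic in the function above can be checked without OpenCV. The following is only a minimal sketch and not part of the linked repository: LetterboxGeometry and letterbox_geometry are hypothetical names, and the code reproduces just the size calculations so that the effect of fixed_shape on the output size is visible.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Output geometry of a letterbox operation (no pixels are touched here).
struct LetterboxGeometry {
  int resized_w, resized_h;      // size after the aspect-preserving resize
  int top, bottom, left, right;  // border widths that would go to copyMakeBorder
  int out_w, out_h;              // final padded size
};

// Hypothetical helper mirroring the arithmetic of the letterbox function above.
LetterboxGeometry letterbox_geometry(int in_w, int in_h, int new_w = 640, int new_h = 640,
                                     int stride = 32, bool fixed_shape = false,
                                     bool scale_up = true) {
  assert(in_w > 0 && in_h > 0 && stride > 0);
  float r = std::min((float)new_h / (float)in_h, (float)new_w / (float)in_w);
  if (!scale_up) {
    r = std::min(r, 1.0f);
  }

  LetterboxGeometry g{};
  g.resized_w = (int)std::round((float)in_w * r);
  g.resized_h = (int)std::round((float)in_h * r);

  float dw = (float)(new_w - g.resized_w);
  float dh = (float)(new_h - g.resized_h);
  if (!fixed_shape) {
    // Minimal rectangle: pad only up to the next multiple of the stride.
    dw = (float)((int)dw % stride);
    dh = (float)((int)dh % stride);
  }
  dw /= 2.0f;
  dh /= 2.0f;

  g.top = (int)std::round(dh - 0.1f);
  g.bottom = (int)std::round(dh + 0.1f);
  g.left = (int)std::round(dw - 0.1f);
  g.right = (int)std::round(dw + 0.1f);
  g.out_w = g.resized_w + g.left + g.right;
  g.out_h = g.resized_h + g.top + g.bottom;
  return g;
}
```

For a hypothetical 1280x720 frame, the default arguments give a 640x384 result (padding only up to a multiple of the stride), while fixed_shape = true pads to the full 640x640.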

@zhiqwang @glenn-jocher When I build the tensor from the actual image shape and pass it to the model, I get errors like the one above. But when I create a 640x640x3 tensor, the model runs.

fatejzz Jul 04 '22 07:07
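
A note on the hard-coded 640x640x3 shape: torch::from_blob neither copies nor validates the buffer it is given, so declaring 640x640x3 over a smaller letterboxed image would read past the pixel data. That is one plausible explanation for the unstable boxes, though the thread does not confirm it. The sketch below uses plain C++ without libtorch; hwc_bgr_to_chw_rgb is a hypothetical helper name. It refuses images whose size differs from the exported input size and applies the layout and scaling steps of detect.py (HWC to CHW, BGR to RGB, division by 255). It assumes the pixel data is continuous and still in BGR order, so the separate cvtColor call would not be needed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Converts an interleaved 8-bit BGR image (H x W x 3) into a planar RGB
// float buffer (3 x H x W) scaled to [0, 1].
// Throws if the image is not the size the exported model expects.
std::vector<float> hwc_bgr_to_chw_rgb(const std::uint8_t* data, int h, int w,
                                      int expected_h = 640, int expected_w = 640) {
  assert(data != nullptr);
  if (h != expected_h || w != expected_w) {
    throw std::invalid_argument("image size does not match the exported input size");
  }
  const std::size_t plane = static_cast<std::size_t>(h) * static_cast<std::size_t>(w);
  std::vector<float> out(3 * plane);
  for (std::size_t i = 0; i < plane; ++i) {
    for (int c = 0; c < 3; ++c) {
      // Source channel order is B, G, R; destination planes are R, G, B.
      out[static_cast<std::size_t>(2 - c) * plane + i] = data[3 * i + c] / 255.0f;
    }
  }
  return out;
}
```

The returned buffer could then be wrapped with torch::from_blob using the shape {1, 3, 640, 640} and torch::kFloat, provided the vector outlives the tensor.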

I just reused the two code samples, and it has worked well so far. But the 'letterbox' differs somewhat from the others in its mechanism: padd_w=padd_w%64; padd_h=padd_h%64;

fatejzz Jul 05 '22 09:07

@glenn-jocher After discussing it with others, I wonder whether the cause is that exporting the model with TorchScript fixes the model's input size, e.g. to [640, 640].

fatejzz Jul 06 '22 06:07

@fatejzz yes, most exports require fixed input sizes. I think only PyTorch and ONNX --dynamic support dynamic input sizes.

glenn-jocher Jul 07 '22 11:07
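
The numbers in the RuntimeError at the top of the thread fit this explanation. Assuming the standard YOLOv5 detection strides of 8, 16 and 32, a 640-pixel side maps to 80 grid cells at stride 8, while a 384-pixel side maps to 48. The traced graph therefore appears to hold constants for a 640x640 input while the runtime tensor had a 384-pixel side. This is an inference from the error message rather than something stated in the thread. The sketch below only illustrates the arithmetic; grid_sizes is a hypothetical helper.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Feature-map side lengths for one input side length, one entry per stride.
// Strides 8, 16 and 32 are assumed (standard YOLOv5 P3-P5 detection heads).
std::array<int, 3> grid_sizes(int input_side) {
  const std::array<int, 3> strides{{8, 16, 32}};
  // The input side must be a positive multiple of the largest stride.
  assert(input_side > 0 && input_side % strides[2] == 0);
  std::array<int, 3> out{};
  for (std::size_t i = 0; i < strides.size(); ++i) {
    out[i] = input_side / strides[i];
  }
  return out;
}
```

Calling grid_sizes(640) yields 80, 40 and 20, whereas grid_sizes(384) yields 48, 24 and 12, which matches the 48 versus 80 mismatch reported at dimension 2.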

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.


github-actions[bot] Aug 07 '22 00:08