
Detect.py supports running against a Triton container

gaziqbal opened this issue 2 years ago

This PR enables detect.py to use a Triton server for inference. The Triton Inference Server (https://github.com/triton-inference-server/server) is open-source inference serving software that streamlines AI inferencing.

The user can now provide a "--triton-url" argument to detect.py to use a local or remote Triton server for inference. For example, http://localhost:8000 will use HTTP over port 8000 and grpc://localhost:8001 will use gRPC over port 8001. Note that it is not necessary to specify a weights file to detect.py when using Triton for inference.

A Triton container can be created by first exporting the YOLOv5 model to a Triton-supported runtime. ONNX, TorchScript, and TensorRT are supported by both Triton and the export.py script.

The exported model can then be containerized via the OctoML CLI. See https://github.com/octoml/octo-cli#getting-started for a guide.

python export.py --include onnx                     # exports the ONNX model as yolov5s.onnx
mkdir octoml && cd octoml && mv ../yolov5s.onnx .   # create an octoml folder and move the ONNX model into it
octoml init && octoml package && octoml deploy
python ../detect.py --triton-url http://localhost:8000

gaziqbal avatar Aug 30 '22 18:08 gaziqbal

@glenn-jocher , @AyushExel - here is a PR against the yolov5 repo.

gaziqbal avatar Aug 30 '22 18:08 gaziqbal

@glenn-jocher , @AyushExel - here is a PR against the yolov5 repo.

Please let me know if you need anything more here.

gaziqbal avatar Sep 07 '22 17:09 gaziqbal

@gaziqbal thanks, we should be reviewing this soon, no changes required ATM

glenn-jocher avatar Sep 10 '22 15:09 glenn-jocher

@gaziqbal thanks for your patience.

I think I'm going to try to refactor this to not treat Triton backends differently. There's a tendency for new contributors to introduce more code than is required because they treat their feature as special compared to existing ones. With 12 different inference types all using a single --weights argument, I'd rather not introduce additional command-line arguments and function arguments for one more.

Just like --source and --weights are multi-purpose, I think we can extend them to Triton inference as well. I'll see what I can do here today.

glenn-jocher avatar Sep 21 '22 12:09 glenn-jocher

@gaziqbal ok I've made all my updates!

This moves Triton URL passing to --weights http:// or grpc:// and generalizes Triton support to all CLI predictors, i.e. detect.py, classify/predict.py and segment/predict.py. It assumes that only Triton URLs will contain 'http' or 'grpc', so it's possible that local weights containing those strings will cause an error, but I think that would be an edge case.

Do you know if Triton URLs are ever https?

Can you test to make sure I haven't broken anything? I don't have a Triton server set up locally to test. Thanks!

glenn-jocher avatar Sep 21 '22 17:09 glenn-jocher
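For reference, a minimal sketch of how the "http:// or grpc:// weights are a Triton URL" heuristic could be expressed with the standard library. This is an illustration of the idea under that assumption, not the actual check in models/common.py:

import warnings
from urllib.parse import urlparse

def is_triton_url(weights: str) -> bool:
    # Assumes only Triton endpoints carry a URL scheme; a local weights
    # file such as 'yolov5s.pt' parses with an empty scheme.
    scheme = urlparse(weights).scheme
    if scheme and scheme not in ("http", "https", "grpc"):
        warnings.warn(f"Unrecognized scheme '{scheme}' in weights argument")
    return scheme in ("http", "https", "grpc")

print(is_triton_url("yolov5s.pt"))              # False -> load local weights
print(is_triton_url("http://localhost:8000"))   # True  -> HTTP Triton endpoint
print(is_triton_url("grpc://localhost:8001"))   # True  -> gRPC Triton endpoint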

@gaziqbal pinging you to see if you could re-test after my updates (I hope I didn't break anything)!

glenn-jocher avatar Sep 22 '22 22:09 glenn-jocher

@glenn-jocher - the Triton server detection broke because it was using the Path.name property for matching, which strips the http:// or grpc:// prefix. I also needed to change the Triton server class to query the model name, because the weights parameter is now used for the URL. Can you please take a look again? I have verified both HTTP and gRPC on my end.

gaziqbal avatar Sep 23 '22 16:09 gaziqbal
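Two details in this comment are easy to check independently: pathlib collapses the double slash in a URL, so a Path(...).name match never sees the scheme, and tritonclient can report which models a server is serving, which is one way to discover the model name when --weights only carries the URL. A rough sketch, assuming a Triton server reachable on localhost:8000; the exact handling in utils/triton.py may differ:

from pathlib import Path
import tritonclient.http as httpclient

# pathlib treats the URL as a filesystem path and collapses '//' -> '/',
# so the http:// / grpc:// prefix does not survive a Path.name check.
print(Path("http://localhost:8000").name)   # 'localhost:8000' on POSIX -- the scheme is gone

# Ask the server which models it is serving instead of hard-coding a name.
client = httpclient.InferenceServerClient(url="localhost:8000")
models = client.get_model_repository_index()   # list of dicts with 'name', 'version', 'state'
model_name = models[0]["name"]                  # first served model (see the discussion below)
print(model_name)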

@gaziqbal understood. Is there a public server URL I could temporarily use for debugging? I see an error from Vanessa that I'm working on now.

glenn-jocher avatar Sep 23 '22 21:09 glenn-jocher

@gaziqbal I took a look, everything looks good to merge over here. Do your updates fix Vanessa's issue?

glenn-jocher avatar Sep 23 '22 21:09 glenn-jocher

@gaziqbal PR is merged. Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐

glenn-jocher avatar Sep 23 '22 22:09 glenn-jocher

@gaziqbal @glenn-jocher I tried this, but when Triton is serving a series of models, the code defaults to the first model rather than the one named "yolov5". I think a model_name parameter should be added to TritonRemoteModel.

kingkong135 avatar Oct 04 '22 02:10 kingkong135

Good point. That's fairly straightforward to do for TritonRemoteModel. Are you invoking it via detect.py? If so, we'll need a way to relay that.

gaziqbal avatar Oct 04 '22 02:10 gaziqbal

I'm thinking there are two ways: one is to add a new model_name parameter, but that's a bit redundant; the other is to append the model name to "weights", like "grpc://localhost:8001/yolov5", and have TritonRemoteModel handle it.

kingkong135 avatar Oct 04 '22 02:10 kingkong135

My concern with the latter is that it would be a contrived URI scheme that doesn't match canonical Triton URIs, which may be confusing. That said, the approach is worth exploring more.

gaziqbal avatar Oct 04 '22 02:10 gaziqbal

Stupid question here. Could we use the URL question-mark structure for passing variables, i.e. something like this to allow more arguments into the Triton server?

grpc://localhost:8001/?model=yolov5s.pt&conf=0.25&imgsz=640

glenn-jocher avatar Oct 04 '22 20:10 glenn-jocher
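A quick sketch of what parsing either variant could look like with the standard library, covering both the path form suggested above (grpc://localhost:8001/yolov5) and the query-string form. Both URL styles are proposals from this discussion, not an existing YOLOv5 convention:

from urllib.parse import urlparse, parse_qs

def parse_triton_weights(weights: str):
    # Split a Triton-style weights URL into endpoint, optional model name,
    # and optional extra arguments passed as a query string.
    parsed = urlparse(weights)
    endpoint = f"{parsed.scheme}://{parsed.netloc}"
    model_name = parsed.path.strip("/") or None                     # 'grpc://host:8001/yolov5' -> 'yolov5'
    params = {k: v[0] for k, v in parse_qs(parsed.query).items()}   # '?model=...&conf=...' -> dict
    return endpoint, model_name, params

print(parse_triton_weights("grpc://localhost:8001/yolov5"))
# ('grpc://localhost:8001', 'yolov5', {})
print(parse_triton_weights("grpc://localhost:8001/?model=yolov5s.pt&conf=0.25&imgsz=640"))
# ('grpc://localhost:8001', None, {'model': 'yolov5s.pt', 'conf': '0.25', 'imgsz': '640'})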

Hi! Where can I find any info on how exactly Triton should be configured to work with this solution? I have used Triton with a custom client before. When I tried to use my Triton backend with detect.py, I got this issue: tritonclient.utils.InferenceServerException: got unexpected numpy array shape [1, 3, 640, 640], expected [-1, 3, 640, 640]

Here is my config:

name: "yolov5"
platform: "tensorrt_plan"
max_batch_size: 1
input [
  {
    name: "images"
    data_type: TYPE_FP32
    dims: [ 3, 640, 640 ]
  }
]
output [
  {
    name: "output0"
    data_type: TYPE_FP32
    dims: [ 25200, 85 ]
  }
]

ArgoHA avatar Jan 08 '23 11:01 ArgoHA

@ArgoHA, I am having the same problem here. Were you able to solve it?

Traceback (most recent call last):
  File "detect.py", line 259, in <module>
    main(opt)
  File "detect.py", line 254, in main
    run(**vars(opt))
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "detect.py", line 113, in run
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
  File "/usr/src/app/models/common.py", line 597, in warmup
    self.forward(im)  # warmup
  File "/usr/src/app/models/common.py", line 558, in forward
    y = self.model(im)
  File "/usr/src/app/utils/triton.py", line 60, in __call__
    inputs = self._create_inputs(*args, **kwargs)
  File "/usr/src/app/utils/triton.py", line 80, in _create_inputs
    input.set_data_from_numpy(value.cpu().numpy())
  File "/opt/conda/lib/python3.8/site-packages/tritonclient/grpc/__init__.py", line 1831, in set_data_from_numpy
    raise_error(
  File "/opt/conda/lib/python3.8/site-packages/tritonclient/utils/__init__.py", line 35, in raise_error
    raise InferenceServerException(msg=msg) from None
tritonclient.utils.InferenceServerException: got unexpected numpy array shape [1, 3, 640, 640], expected [-1, 3, 640, 640]

fabito avatar Jan 18 '23 00:01 fabito

@ArgoHA ,

I solved it using this configuration:

name: "yolov5"
platform: "tensorrt_plan"
max_batch_size: 0
input [
  {
    name: "images"
    data_type: TYPE_FP32
    dims: [1, 3, 640, 640 ]
  }
]
output [
  {
    name: "output0"
    data_type: TYPE_FP32
    dims: [1, 25200, 85 ]
  }
]

fabito avatar Jan 18 '23 01:01 fabito
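For context on why this change helps: when max_batch_size is greater than 0, Triton's model metadata reports a variable batch dimension, i.e. a shape of [-1, 3, 640, 640], and the wrapper appears to build its inputs from that metadata, so a concrete [1, 3, 640, 640] array fails the client-side shape check. With max_batch_size: 0 and the batch dimension listed explicitly, the declared and actual shapes match. A standalone client sketch that agrees with the working config above, assuming a model named "yolov5" served over HTTP on localhost:8000 (this is a direct tritonclient call for illustration, not part of detect.py):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Batch dimension included explicitly, matching dims [1, 3, 640, 640] above.
image = np.zeros((1, 3, 640, 640), dtype=np.float32)

infer_input = httpclient.InferInput("images", list(image.shape), "FP32")
infer_input.set_data_from_numpy(image)   # shapes agree, so no InferenceServerException

result = client.infer(model_name="yolov5", inputs=[infer_input])
detections = result.as_numpy("output0")
print(detections.shape)                  # (1, 25200, 85) with this configuration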