Large model output is copied when working from a BLS
Description: When querying an ONNX model from a Python BLS, performance seems to suffer with large outputs. We suspect the output data is being copied.
Triton Information: 22.06 + a patch for the Python backend with code from https://github.com/triton-inference-server/python_backend/pull/179
Are you using the Triton container or did you build it yourself? Triton container
Describe the models (framework, inputs, outputs), ideally include the model configuration file (if using an ensemble include the model configuration file for that as well).
We have an ONNX model that outputs a large amount of data. We run it on an A100 GPU in GCP. At first it seemed this data was being output on the CPU instead of the GPU. Thanks to the discussion here, we were able to apply a patch that significantly reduced the inference times.
Now the data is being output on the GPU, as expected.
However, it seems the performance can still be improved.
```python
import triton_python_backend_utils as pb_utils

# BLS call made from our Python model's execute() function
inference_request = pb_utils.InferenceRequest(
    model_name=target_model,
    requested_output_names=output_names,
    inputs=inputs,
)
inference_response = inference_request.exec()
```
Running this with all the large output names results in ~45ms inference time. However, removing the "heavy" output names and leaving just the one with less data results in ~25ms inference times.
Note that these outputs are still being calculated; they are just not passed back to the BLS script.
This leads us to believe that some copy of the outputs is still being made.
Any advice/assistance will be much appreciated.
Thanks
An update on this - we've been compiling from source and adding debug prints to figure out where the time is spent.
We see a lot of time spent here -
https://github.com/triton-inference-server/python_backend/blob/0d24fda0d2459536e1f4e8d5c368aaeda8ce838f/src/infer_request.cc#L432
We have many large inputs on GPU, and it seems they are being copied before being sent downstream to the other backend.
Can anything be done to prevent this copy?
Thanks
@Tabrizian Can this be resolved by the enhancement we were talking about, i.e. sharing tensor data from system shared memory itself among the BLS/ensemble models? This is GPU memory however, so CUDA shared memory in this case?
@tanmayv25 @Tabrizian An important thing I forgot to mention - the inputs I supply are actually outputs from a previous inference on the same model, so they are probably already in shared memory.
Maybe some optimization can take place in that case?
We kept debugging and found some more copies being made.
This time it is inside the onnxruntime backend.
We've tracked some of the lost time to this line:
https://github.com/triton-inference-server/onnxruntime_backend/blob/52b4e6204ad9ced804e43bc1a4a2626c61fef971/src/onnxruntime.cc#L2237
Which ultimately leads to this line
https://github.com/triton-inference-server/backend/blob/0bbbb3e8bfae32fa86e68235387dd9161488669a/src/backend_output_responder.cc#L333
I think this results in a copy being made. Probably not a big deal for small outputs, but it accumulates for large outputs.
In our case, the total inference time was 22ms, out of which 10ms was spent on these lines.
This adds up when you run 50 such inferences in a tight loop from a BLS script.
Any advice on this? Is there something we can do to prevent these copies?
> We have many large inputs on GPU, and it seems they are being copied before being sent downstream to the other backend.
Can you fully describe your inference pipeline? Do you have a single Python model in which you perform BLS? The Triton backend API requires the outputs provided by the backend to be copied to a separate buffer. Currently, there is no way to avoid this copy in the Python/ONNX backends.
> Sharing tensor data from system shared memory itself among the BLS/ensemble models?
This optimization is already applied in the Python backend. If you have a tensor that is already in shared memory, it is not going to be copied again.
> We have many large inputs on GPU, and it seems they are being copied before being sent downstream to the other backend.
The input/output copies in the Python backend work as follows:
1. Copy the inputs to shared memory.
2. Retrieve the inputs from shared memory.
3. Copy the outputs to shared memory.
4. Copy the outputs from shared memory to the output buffers provided by Triton.
We discussed an optimization that might be possible to implement in Triton for steps 1 and 3 but it has not been scheduled yet.
> The inputs I supply are actually outputs from a previous inference on the same model, so they are probably already in shared memory.
In this case, I think you might also be interested in the implicit state management feature. Unfortunately, this has not been implemented in Python backend yet but is available in ONNX and TensorRT backends. If you are able to use those backends, you can leverage this feature for better performance. If you store the input in an internal attribute and the tensor is already stored in shared memory, Python backend will not make another copy.
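To make the shared-memory reuse concrete, here is a minimal sketch (not from this thread) of a BLS loop that feeds each `exec()` output straight back as the next request's input via the Python backend's DLPack utilities, instead of converting to NumPy, which would require moving the data to the host. The model and tensor names (`my_onnx_model`, `INPUT0`, `OUTPUT0`) are hypothetical, and it assumes a Python backend build whose `to_dlpack`/`from_dlpack` support GPU tensors; the shared-memory copies described in the steps above still apply.

```python
import triton_python_backend_utils as pb_utils


def run_chained(inputs, target_model="my_onnx_model", steps=50):
    """Run back-to-back BLS calls, reusing each output as the next input.

    `inputs` is a list of pb_utils.Tensor objects (hypothetical names).
    """
    out = None
    for _ in range(steps):
        request = pb_utils.InferenceRequest(
            model_name=target_model,
            requested_output_names=["OUTPUT0"],
            inputs=inputs,
        )
        response = request.exec()
        if response.has_error():
            raise pb_utils.TritonModelException(response.error().message())

        out = pb_utils.get_output_tensor_by_name(response, "OUTPUT0")
        # Rewrap the output under the model's input name via DLPack rather than
        # calling as_numpy(), so the data is not pulled back to the host.
        inputs = [pb_utils.Tensor.from_dlpack("INPUT0", out.to_dlpack())]
    return out
```

Storing the last `inputs` list on the model object (e.g. in `self._cached_inputs`) between `execute()` calls follows the same idea as the internal-attribute suggestion above.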
@amircodota I was curious what profiling tool you used for measuring the performance and identifying the bottlenecks. It looks like you have found all the right spots.
@Tabrizian Thank you for the detailed answer.
No profiler was used. Just plain old "printf debugging" :-)
As a workaround for now, we will try to implement a Python backend model that uses the ONNX Runtime Python API to load the ONNX model directly in the Python backend.
This approach has some drawbacks, but it seems like a reasonable compromise for our use case for now.
Please let us know if you decide to implement the above optimization. We can also try it out before release and let you know whether it solves the issues we have in our use case.
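For reference, a minimal sketch of that workaround, with hypothetical tensor names (`INPUT0`/`OUTPUT0`) and a `model.onnx` file assumed to sit next to `model.py`; it is not the implementation from this thread:

```python
import os

import onnxruntime as ort
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # Load the ONNX model with ONNX Runtime inside the Python model, so its
        # outputs never cross the Triton backend boundary and are not copied there.
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model.onnx")
        self._session = ort.InferenceSession(
            model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )

    def execute(self, requests):
        responses = []
        for request in requests:
            # Hypothetical tensor names "INPUT0" / "OUTPUT0".
            input0 = pb_utils.get_input_tensor_by_name(request, "INPUT0").as_numpy()
            (output0,) = self._session.run(["OUTPUT0"], {"INPUT0": input0})
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor("OUTPUT0", output0)]
                )
            )
        return responses
```

This basic version goes through host NumPy arrays; keeping the data on the GPU would additionally require ONNX Runtime I/O binding together with the DLPack utilities mentioned earlier.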
@Tabrizian
> This optimization is already applied in the Python backend. If you have a tensor that is already in shared memory, it is not going to be copied again.
I'm confused about how to ensure data goes into CUDA shared memory. Or is this the default behavior?
Closing this issue due to lack of activity. If this issue needs follow-up, please let us know and we can reopen it for you.
@zhaozhiming37 you can read up about how to use CUDA shared memory here: https://github.com/triton-inference-server/client#cuda-shared-memory and https://github.com/triton-inference-server/client#download-docker-image-from-ngc
You can find examples here: https://github.com/triton-inference-server/client/blob/main/src/python/examples/simple_http_shm_client.py
When you send new requests, you specify how the memory is stored (as shown in the examples).
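In case a condensed example helps, here is a sketch along the lines of the linked clients that uses CUDA shared memory for an input; the region, model, and tensor names and sizes are placeholders:

```python
import numpy as np
import tritonclient.http as httpclient
import tritonclient.utils.cuda_shared_memory as cudashm

client = httpclient.InferenceServerClient("localhost:8000")

# Create a CUDA shared memory region on GPU 0, copy the input into it, and
# register it with the server.
data = np.random.rand(1, 16).astype(np.float32)  # placeholder input
byte_size = data.nbytes
shm_handle = cudashm.create_shared_memory_region("input_region", byte_size, 0)
cudashm.set_shared_memory_region(shm_handle, [data])
client.register_cuda_shared_memory(
    "input_region", cudashm.get_raw_handle(shm_handle), 0, byte_size
)

# Point the request's input at the registered region instead of sending the
# data in the request body.
infer_input = httpclient.InferInput("INPUT0", list(data.shape), "FP32")
infer_input.set_shared_memory("input_region", byte_size)
result = client.infer("my_model", inputs=[infer_input])

# Clean up when done.
client.unregister_cuda_shared_memory("input_region")
cudashm.destroy_shared_memory_region(shm_handle)
```

Outputs can be mapped to a registered region in the same way by calling `set_shared_memory` on the requested output.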