
ROCM EP convolution fails due to missing algo search call

dmnieto opened this issue 1 year ago · 6 comments

Describe the issue

"MIOpen Error: No invoker was registered for convolution forward." happens when trying to use any model for inference with convolution codes. This is because the caching system in the Update() call will check for previous used algorithms. But the ROCM API (contrary to CUDA) requires the algo search call as described here: https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html

To reproduce

This has been reproduced with the latest version of Immich, on all the models being used.

Urgency

Relatively urgent, as the ROCm EP is broken.

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Other / Unknown

Execution Provider Library Version

ROCm, any version

dmnieto avatar Feb 19 '24 19:02 dmnieto

@PeixuanZuo Any ideas? I think we didn't observe this conv issue in SD.

@jeffdaily Could you please contact the MIOpen devs about this? I think we introduced the manual cache to work around some older version of MIOpen, but it seems we need to remove the manual caching logic again?

cloudhan avatar Feb 20 '24 10:02 cloudhan

The logic in Immich for the model inference is quite simple. I have done several tests, and it does not look like the bug could be specific to it (I have a bug open in parallel to track the integration of the ROCm EP). I think this just may be a case of landing in an untested code path; it does not look like the ROCm EP gets used that much. Regardless, the porting guide makes clear we can't cache the algos.

dmnieto avatar Feb 20 '24 13:02 dmnieto

I have reached out to our MIOpen team.

jeffdaily avatar Feb 20 '24 16:02 jeffdaily

Thanks Jeff,

In the code review there are mentions that TensorFlow uses algo caching. I am going to instrument the test code a bit more to see if there is any way we could end up in a situation where we launch the convolutions without calling findalgo first... Maybe something is broken in the caching itself instead... but so far I have not been successful in identifying the corner case.

dmnieto avatar Feb 20 '24 16:02 dmnieto

Can you provide more details? Docker container? With reproducer steps?

jeffdaily avatar Feb 20 '24 17:02 jeffdaily

You can see the discussion (including the Dockerfile) here: https://github.com/immich-app/immich/discussions/7169. You likely don't want to go through the mess of building the project, so I'll provide a more concise repro setup. But for completeness' sake: the issue happens when setting up the machine_learning Docker container after building onnxruntime with ROCm 6.0.2 against the main branch. The error occurs on the second call to the face detection model (buffalo_l) when instantiating the predict function. I will grab the logs from an earlier run in a few days.

dmnieto avatar Feb 20 '24 21:02 dmnieto

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions[bot] avatar Mar 22 '24 15:03 github-actions[bot]

The root cause is that the cache key is the x shape alone (not a combination of x shape + w shape), while the w shape is not constant. With multiple threads:

  • Thread 1 uses w shape w1 and x shape x1; it runs the algo search, then adds a cache entry <x1, algo1> for w1.
  • Thread 2 uses w shape w2 and x shape x1; it clears the cache since the w shape changed. Right after this, thread 1 inserts <x1, algo1> before thread 2's lookup. Thread 2 then looks up the cache with key x1, finds the <x1, algo1> entry inserted by thread 1, and chooses algo1; but algo1 cannot be applied to w2, so thread 2 raises a runtime error.
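The interleaving can be sketched with a plain dict standing in for the per-node cache. This is a hypothetical simulation; the names and shapes are illustrative, not onnxruntime's actual code:

```python
# Minimal simulation of the race described above.
cache = {}  # keyed by x shape only -- the w shape is invisible to lookups

def algo_search(x_shape, w_shape):
    # Stand-in for MIOpen's find call: the chosen algo depends on BOTH shapes.
    return ("algo", x_shape, w_shape)

x1 = (1, 3, 224, 224)
w1 = (64, 3, 7, 7)
w2 = (32, 3, 3, 3)

# Thread 1: searches for (x1, w1), then inserts under key x1 alone.
cache[x1] = algo_search(x1, w1)

# Thread 2: cleared the cache on the w-shape change, but thread 1's insert
# landed before its lookup, so the stale entry is still found under x1.
hit = cache.get(x1)
valid = (hit == algo_search(x1, w2))  # False: the algo was chosen for w1, not w2
```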

The solution is to use x shape + w shape as the cache key and never clear the cache, and to add a mutex guarding the cache so that only one thread at a time is looking up or updating it.

The current cache is applied per node, so the Conv algo search runs multiple times when a model has multiple Conv nodes. An ideal solution is a global cache (like the PyTorch code) keyed on all conv parameters (including device id), which would avoid duplicated algo searches.
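A minimal sketch of the proposed per-node fix, assuming a `search_fn` callback stands in for the MIOpen algo search. The class and method names here are hypothetical, not onnxruntime's actual implementation:

```python
import threading

class ConvAlgoCache:
    """Cache keyed on (x shape, w shape), guarded by a mutex, never cleared."""

    def __init__(self):
        self._lock = threading.Lock()
        self._algos = {}  # (x_shape, w_shape) -> algo

    def get_or_search(self, x_shape, w_shape, search_fn):
        key = (x_shape, w_shape)
        with self._lock:  # only one thread looks up or inserts at a time
            if key not in self._algos:
                self._algos[key] = search_fn(x_shape, w_shape)
            return self._algos[key]

# Two convolutions sharing an x shape but differing in w shape now get
# independent entries instead of colliding on the x-shape-only key:
cache = ConvAlgoCache()
x1 = (1, 3, 224, 224)
w1 = (64, 3, 7, 7)
w2 = (32, 3, 3, 3)
a1 = cache.get_or_search(x1, w1, lambda x, w: ("algo", x, w))
a2 = cache.get_or_search(x1, w2, lambda x, w: ("algo", x, w))
# a1 and a2 are distinct entries, so neither thread sees the other's algo.
```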

tianleiwu avatar Jul 11 '24 04:07 tianleiwu

That makes a lot of sense, since the issues were discovered when running two different models in parallel. I can try that on my setup.

dmnieto avatar Jul 11 '24 16:07 dmnieto