Bug: Converting from ONNX to ORT fails when setting Device=Direct ML [C++] [ONNX2ORT converter] [Direct ML]

Open gineshidalgo99 opened this issue 3 years ago • 24 comments

Describe the bug: ONNX to ORT conversion works when device=CPU but fails when device=DirectML (exact same code)

Low-level details:

  • I tried with 7 networks and it happens in all of them (including MNIST and ResNet)
  • I tried this and it affects both v1.7.1 (the version we are using) and the very latest GitHub code.
  • ORT doesn't crash during conversion but rather later, when loading/using the new ORT models. However, when checking the ORT files, the one for CPU is similar in size to the ONNX file (~MBs), while the GPU one is only a few KB, so this definitely looks like a converter bug. The reported crash says something about operators not being implemented, but the ORT file is simply too small (error shown in https://github.com/microsoft/onnxruntime/discussions/7931)

Also, all models run (and we checked that they match the original PyTorch model accuracies) when loaded from ONNX and set to DML. What works:

  • ONNX file loaded, set to CPU and running inference on it
  • ONNX file loaded, set to GPU and running inference on it
  • ONNX file loaded, set to CPU, converted to ORT, loaded as ORT file, set to CPU, and running inference on it

What does not work:

  • ONNX file loaded, set to GPU, converted to ORT, loaded as ORT file, set to GPU, and running inference on it
  • ONNX file loaded, set to CPU, converted to ORT, loaded as ORT file, set to GPU, and running inference on it --> This one does not crash, but it is clearly running on CPU because its runtime timings match those of the CPU version (not the GPU version). So whatever session options were used when generating the ORT file seem to be what is used, regardless of trying to set it to another kind of device

Urgency: Urgent --> It blocks ORT file deployment for DirectML networks. We have an internal deadline in August to release this project

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
  • ONNX Runtime installed from (source or binary): Tried both, but we care mostly about the source one
  • ONNX Runtime version: v1.7.1 and also tested in latest GitHub code
  • Python version: None, using C++
  • Visual Studio version (if applicable): VS 2019 Professional
  • GCC/Compiler version (if compiling from source): VS 2019 Professional
  • CUDA/cuDNN version: None (DirectML)
  • GPU model and memory: Nvidia 3080

To Reproduce

  • Describe steps/code to reproduce the behavior.
// Conversion step
{
    // Set up ORT and create an environment
    Ort::InitApi();
    const char* const ModelRelativeFilePathCharPtr = TCHAR_TO_ANSI(*InModelRelativeFilePath);
    Environment = MakeUnique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, ModelRelativeFilePathCharPtr);
    Allocator = MakeUnique<Ort::AllocatorWithDefaultOptions>();
    SessionOptions = MakeUnique<Ort::SessionOptions>();
    if (Device == GPU)
    {
        SessionOptions->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
        OrtSessionOptionsAppendExecutionProvider_DML(*SessionOptions, 0);
    }
    else
    {
        SessionOptions->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    }

    // Generate ORT file
    SessionOptions->SetOptimizedModelFilePath(*OutputORTOptimizedModelPath);
    Session = MakeUnique<Ort::Session>(*Environment, *FullModelFilePath, *SessionOptions);

    // Result --> ORT file on disk at OutputORTOptimizedModelPath, which is correct if Device = CPU, but smaller than it should be if Device = GPU
}

// Running step
{
    // Same setting code

    // Load/run ORT file
    // Note the lack of "SetOptimizedModelFilePath()"
    Session = MakeUnique<Ort::Session>(*Environment, *FullModelFilePath, *SessionOptions);

    // Result --> ORT file works fine as long as it's on CPU, but crashes with DirectML, giving the error shown in https://github.com/microsoft/onnxruntime/discussions/7931
}
  • Attach the ONNX model to the issue (where applicable) to expedite investigation. Here is a zip file with 3 models (SqueezeNet, MNIST, and ResNet), and for each one: the original ONNX model, the ORT model when device=CPU, and the ORT model when device=DirectML: https://drive.google.com/file/d/1F13H3HW4PoEZqLBg2sJoXfBB-zcfQGwq

Expected behavior: I expect both ORT files to have approximately the same size, and the DirectML one not to crash when used later

gineshidalgo99 avatar Jul 20 '21 16:07 gineshidalgo99

Did you check if the original model (non-ORT format) runs on DML?

faxu avatar Jul 20 '21 20:07 faxu

Yes, I forgot to mention that: all models run (and we checked that they match the original PyTorch model accuracies) when loaded from ONNX and set to DML

What works:

  • ONNX file loaded, set to CPU and running inference on it
  • ONNX file loaded, set to GPU and running inference on it
  • ONNX file loaded, set to CPU, converted to ORT, loaded as ORT file, set to CPU, and running inference on it

What does not work:

  • ONNX file loaded, set to GPU, converted to ORT, loaded as ORT file, set to GPU, and running inference on it

We also tried this:

  • ONNX file loaded, set to CPU, converted to ORT, loaded as ORT file, set to GPU, and running inference on it --> This one does not crash, but it is clearly running on CPU because its runtime timings match those of the CPU version (not the GPU version). So whatever session options were used when generating the ORT file seem to be what is used, regardless of trying to set it to another kind of device

gineshidalgo99 avatar Jul 20 '21 20:07 gineshidalgo99

Did you try saving the optimized ONNX model as foo.onnx (where foo is the name of your model) without making a call to session_options.AddConfigEntry("session.save_model_format", "ORT"); and then running the saved model with DML?
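For reference, a minimal C++ sketch of the two save variants being compared here (an illustration only, not code from the report; the file paths and the DML device index are placeholders): one writes the optimized model as plain ONNX, the other adds the session.save_model_format=ORT config entry.

#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

void SaveOptimizedModels(Ort::Env& env)
{
    // Variant A: save the optimized model in ONNX format (no extra config entry).
    {
        Ort::SessionOptions so;
        so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
        OrtSessionOptionsAppendExecutionProvider_DML(so, 0);
        so.SetOptimizedModelFilePath(L"foo_optimized.onnx");
        Ort::Session session(env, L"foo.onnx", so);  // writes foo_optimized.onnx during session creation
    }

    // Variant B: save in ORT format (the path that fails with DML in this issue).
    {
        Ort::SessionOptions so;
        so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
        OrtSessionOptionsAppendExecutionProvider_DML(so, 0);
        so.AddConfigEntry("session.save_model_format", "ORT");
        so.SetOptimizedModelFilePath(L"foo.ort");
        Ort::Session session(env, L"foo.onnx", so);  // writes foo.ort during session creation
    }
}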

pranavsharma avatar Jul 20 '21 20:07 pranavsharma

Please set the logging level to ORT_LOGGING_LEVEL_VERBOSE in Environment = MakeUnique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, ModelRelativeFilePathCharPtr); and attach the logs.

guoyu-wang avatar Jul 20 '21 21:07 guoyu-wang

From the ORT files uploaded, the one converted with DML has no initializers in the graph (the model has no weights at all). It seems the initializers of the graph are cleared out by the DML EP before the saving to ORT format happens, and this is probably the main reason why the execution fails.

guoyu-wang avatar Jul 20 '21 22:07 guoyu-wang

What's the reason for attempting to use the ORT file format in the GPU scenarios?

ORT format targets mobile/edge scenarios where binary size is critical, so the current expected usage is with CPU kernels and optionally things like the NNAPI or CoreML EP to utilize the NPU on a device. CUDA kernels are massive, so any binary size saving from using the ORT format is meaningless. Not sure how large the DML kernels are, although I know there's no infrastructure set up to exclude them in a minimal build, so a build with DML enabled would include all the kernels and not just the required ones. Based on that, there doesn't seem to be a binary size benefit, so it's not clear why you'd want/need to use an ORT format model.

ONNX file loaded, set to CPU, converted to ORT, loaded as ORT file, set to GPU, and running inference on it --> This one does not crash, but it is clearly running on CPU because its runtime timings match those of the CPU version (not the GPU version). So whatever session options were used when generating the ORT file seem to be what is used, regardless of trying to set it to another kind of device

ORT format doesn't support changing the static kernel assigned to a node at runtime. If you generated the ORT format model with CPU enabled, it will only use CPU at runtime. It does allow dynamic kernels (e.g. NNAPI and CoreML) to take nodes at runtime (the node is executed as a CoreML or NNAPI model, so the statically assigned kernel is ignored), but that doesn't seem applicable to your usage.

skottmckay avatar Jul 21 '21 00:07 skottmckay

PS: I will reply tomorrow with the results of the tests suggested by @pranavsharma and @gwang-msft (foo.onnx and ORT_LOGGING_LEVEL_VERBOSE)

Answering @skottmckay: it is critical for us to be able to use a single, unified ORT API:

  • Using ORT files is important for us because we need to support Android, iOS, and CPU.
  • Supporting DML is also crucial because we need to support Windows/XBox machines.

Another hard requirement we have is that we cannot let the file sit on disk; we have to feed it to ORT at runtime. And ORT files/FlatBuffers are much simpler to serialize than Protobuf/ONNX ones.

  • I.e., it is extremely easy to send a std::vector to the ORT API and have it read that instead of a foo.ort file (we already have this working).
  • ONNX/Protobuf: we tried doing this with the ONNX file, but feeding a buffer of Protobuf data to the ORT API is not easy at all; the ORT API seems to open the ONNX file in many places rather than simply reading it as a vector, as it does with the ORT file.

Given these 2 reasons, having ORT files working with DML is very important for us in the short term.

gineshidalgo99 avatar Jul 21 '21 01:07 gineshidalgo99

ONNX format files are supported on all platforms. It's just that the binary size of the ORT library will be bigger (by a few MB) vs. a minimal build that only supports ORT format models. In exchange you get a lot more flexibility, such as the ability to use CPU or GPU depending on what's available at runtime.

Can you provide more details on how you were trying to feed the ONNX format file at runtime? InferenceSession has an API where raw bytes can be provided, which can be used for both ONNX and ORT format models. Given that, I'm not quite following how 'the ORT API seems to open the onnx file in many places', since it only sees bytes and not a filename if that API is used.

https://github.com/microsoft/onnxruntime/blob/894fc828587c919d815918c4da6cde314e5d54ed/onnxruntime/core/session/inference_session.cc#L686

I did a quick test using the python API and it seemed to work fine with the ONNX format model being provided as bytes.

import onnxruntime as ort
import numpy as np

model_path = r'my_test_model.onnx'

so = ort.SessionOptions()
s = ort.InferenceSession(model_path, so)

# random input matching what the model requires
input_data = np.zeros((1, 5, 512, 867), dtype=np.float32)
inputs = { 'input': input_data }

# run with filename
o1 = s.run(None, inputs)

with open(model_path, 'rb') as infile:
    bytes = infile.read()
    # run with bytes
    s2 = ort.InferenceSession(bytes, so)
    o2 = s2.run(None, inputs)

    # this model produces a single output so compare the run via filename with the run with bytes
    print(np.array_equal(o1[0],o2[0]))

skottmckay avatar Jul 21 '21 01:07 skottmckay

@gineshidalgo99 Our public C API already provides a unified way to create sessions by passing the bytes associated with both ORT and ONNX models. Take a look at this function. This way you can use the ORT format models on ios and android and ONNX format on desktop/server.

pranavsharma avatar Jul 21 '21 02:07 pranavsharma

We are happy to try this solution; it would solve the problem for us in the short term (getting Windows fully working)!

But we are working in C++, and I could not find any C++ example of this InferenceSession::Load(const void* model_data, int model_data_len). How can it be used with an ONNX file in C++? Do I read the file into a vector? Into a std::string? What exactly does that void* take? Any minimal C++ code snippet showing how to turn the ONNX file into that void* would help a lot here!

(Less important in the short term) Also, regarding why we care about ORT files and DML: we need a solution that also works for our custom GPU EPs (for platforms like Nintendo Switch and PlayStation 5), where we also need to minimize build size (e.g. on PS5). Given the ORT file issue with DML, we are concerned that this might also occur if we create our own GPU EP for PS5/Switch. Is this the case?

gineshidalgo99 avatar Jul 21 '21 02:07 gineshidalgo99

Example of reading bytes from file: https://github.com/microsoft/onnxruntime/blob/894fc828587c919d815918c4da6cde314e5d54ed/onnxruntime/test/shared_lib/test_model_loading.cc#L21-L31

The bytes are just passed directly when creating the inference session.

https://github.com/microsoft/onnxruntime/blob/894fc828587c919d815918c4da6cde314e5d54ed/onnxruntime/test/shared_lib/test_model_loading.cc#L41
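To spell that out, here is a minimal C++ sketch of the same idea (an illustration with a placeholder file name, not the linked test code itself): read the file into a byte buffer and pass the bytes to the Ort::Session constructor that takes model data and its length. The same constructor accepts both ONNX and ORT format bytes.

#include <onnxruntime_cxx_api.h>

#include <fstream>
#include <vector>

int main()
{
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "from-bytes");
    Ort::SessionOptions so;

    // Read the whole model file (.onnx or .ort) into memory.
    std::ifstream file("foo.onnx", std::ios::binary | std::ios::ate);
    const std::streamsize size = file.tellg();
    file.seekg(0, std::ios::beg);
    std::vector<char> model_bytes(static_cast<size_t>(size));
    file.read(model_bytes.data(), size);

    // Create the session directly from the bytes instead of a file path.
    Ort::Session session(env, model_bytes.data(), model_bytes.size(), so);
    return 0;
}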

We'll look into the DML issue as it should be possible to use that with an ORT format model.

skottmckay avatar Jul 21 '21 02:07 skottmckay

One example in our repo is here.

pranavsharma avatar Jul 21 '21 02:07 pranavsharma

Or you can look at this past issue, https://github.com/microsoft/onnxruntime/issues/6475#issuecomment-768787689

guoyu-wang avatar Jul 21 '21 02:07 guoyu-wang

Thanks to those last answers we were able to feed the ONNX buffer into ORT directly, which is a working workaround for us!

We will keep an eye on this post to know when the DML-ORT file issue is solved, as we'd need to switch to it once it's working, but we are no longer blocked.

Thanks for the quick answers and the great work!

gineshidalgo99 avatar Jul 22 '21 16:07 gineshidalgo99

Regarding the DML support, the DML EP has two different ways of handling parts of the graph: one with statically registered kernels, and one with dynamically created kernels. The static ones should work out of the box with the ORT format. The dynamically created ones, however, make some changes to the graph earlier than expected, so parts of the graph aren't available to be saved in the ORT format model. As that's done somewhat unofficially (there's a const_cast to get access to initializers), we'd need to look into restructuring it to make sure that doesn't happen when we're creating the ORT format model.

skottmckay avatar Aug 10 '21 21:08 skottmckay

POC for adding support for DML when using an ORT format model: https://github.com/microsoft/onnxruntime/compare/skottmckay/ORT_model_support_with_DML_EP

Technically we could create the ORT format model with just basic optimizations and DML disabled to not require the changes in the DML graph partitioning. At runtime, if DML was enabled it could still execute the same nodes.
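As a rough illustration of that split (a sketch with placeholder paths, assuming the DML-with-ORT-format support discussed above is in place; not a tested recipe): generate the ORT format model with only the EP-agnostic 'basic' optimizations and no DML EP registered, then register the DML EP when loading the .ort file at runtime.

#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

// Offline step: create the ORT format model without any EP-specific optimizations.
void CreateOrtFormatModel(Ort::Env& env)
{
    Ort::SessionOptions so;
    so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
    so.AddConfigEntry("session.save_model_format", "ORT");
    so.SetOptimizedModelFilePath(L"foo.ort");
    Ort::Session convert_session(env, L"foo.onnx", so);  // writes foo.ort
}

// Runtime step: load the .ort file with the DML EP enabled.
Ort::Session LoadWithDml(Ort::Env& env)
{
    Ort::SessionOptions so;
    OrtSessionOptionsAppendExecutionProvider_DML(so, 0);
    return Ort::Session(env, L"foo.ort", so);
}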

skottmckay avatar Aug 11 '21 10:08 skottmckay

I think I have the same or highly related issue.

  1. onnx runtime 1.12 with DML ep
  2. squeezenet1.0-7.onnx from Microsoft git repo; filesize = 4,952,222 bytes
  3. SetOptimizedModelFilePath(thepath)
  4. session optimizes model and saves it to thepath with a file size = 3,756 bytes
  5. inference runs correctly
  6. shutdown

then

  1. onnx runtime 1.12 with DML ep
  2. the optimized squeezenet1.0-7.onnx from above step 4
  3. session fails with Load model from C:\**redacted**\squeezenet1.0-7.onnx failed:D:\a\_work\1\s\onnxruntime\core\graph\graph.cc:1203 onnxruntime::Graph::Graph This is an invalid model. Tensor does not have type information.

If not the same issue, then please tell me and I'll open a new issue

diablodale avatar Jul 29 '22 01:07 diablodale

@diablodale I am with him on this one. I am getting the same error on multiple models, and the resulting ONNX files are not viewable in Netron. I also tried setting ORT_DISABLE_ALL optimizations in case ops are fused for DML, but the model is still broken.

gedoensmax avatar Sep 09 '22 10:09 gedoensmax

The DML EP makes some changes to the model during partitioning that are not really expected by ORT. Essentially it does a const_cast and steals initializers for memory usage reasons, but that means ORT doesn't have the initializers to write to the optimized file. @fdwr would your PR (still open I note) help with that?

@diablodale @gedoensmax can you elaborate on your use case where you want/need DML to be enabled when creating an optimized model vs. doing that at runtime?

skottmckay avatar Sep 11 '22 22:09 skottmckay

I have been looking into session creation time in ORT. For some models it decreases quite drastically if the shape is known for each tensor. With a fixed-size input model and simplification, these shapes are usually saved, but sometimes only up to some stage within the model. If I understand ORT correctly, it runs the model to "really" know all shapes if the input has a fixed size.

I am aware that these models might get complete shape inference with some graph-surgeon magic. Nonetheless, some applications either have fixed-size engines that are used on demand but still have this problem (it would be great to cache this to disk for later use), or use a dynamic-size model where, once one size is used, it is used multiple times, so you might want to save that fixed-shape ONNX file after first use. Something like TensorRT engine caching, but for DML. Or would the better way be to save to ORT format?

gedoensmax avatar Sep 11 '22 22:09 gedoensmax

@gedoensmax If you have a model with dynamic dimensions and want to make them fixed, you could use this tool: https://onnxruntime.ai/docs/reference/mobile/make-dynamic-shape-fixed.html

I don't quite understand how model load time would be affected by having fixed shapes. If anything, I would expect more optimizations to be possible when shapes are fixed.

I would suggest running the 'basic' level optimizations on the model with just the CPU EP enabled to do those optimizations ahead of time. They are not specific to any EP, only use official ONNX operators, and cover things like constant folding and common subexpression elimination.

Beyond the 'basic' level you get into EP-specific optimizations, which may involve compiling nodes or fusing nodes that will use a custom operator. Currently there's no general-purpose way to save a compiled node the way TensorRT engine caching does. An inference session is intended to be re-used though, so this cost during loading is not per-inference.

skottmckay avatar Sep 12 '22 01:09 skottmckay

@skottmckay 🤔 I should abandon that PR, as @sumitsays is working on a more complete solution after discussing with Cheng Tang about the EP interface refactor. Currently the DML EP fuses partitions of DML nodes into a single DML_GRAPH node, which is an IDMLOperator that contains all the operators for that partition, but if you attempt to reload the .ort graph containing a "DmlFusedGraph" node, ORT won't know how to map that to any operator because context is lost (there is no such ONNX operator with that name, and the internal subgraph only existed in memory).

However, beware that even after Sumit's changes, it will generally not be robust to optimize the graph with one GPU and run the same graph on a different GPU, as differences between GPUs (e.g. which data types are supported) can actually change the optimized graph. Replaying on the same machine, or on a specific device (e.g. a gaming console), would be more robust.

fdwr avatar Sep 12 '22 19:09 fdwr

@diablodale @gedoensmax can you elaborate on your use case where you want/need DML to be enabled when creating an optimized model vs. doing that at runtime?

I create a DLL plugin for the Cycling74 Max runtime patching system. My customers are educators, researchers, artists, musicians, etc. I provide one ONNX model for a specific use case plus the ability to run any ONNX model. My DLL transforms inputs/outputs to and from native Max data. My plugin allows running the model on the CPU, DirectML, CUDA, or TensorRT providers with a single setting change. I hide all the technical complexities so my customers can focus on their art/research/education.

The Max environment is always running; it is a graphical hack/patch environment where nodes are connected by patchcords. Patchcords and nodes are reshaped/connected hundreds of times a day as customers experiment and try ideas. This realtime iteration necessitates caching and reuse. The time burden of running the ONNX optimization process every time they connect a patchcord or click "go" hampers their creativity and kills their "flow".

I know when hardware, models, or settings change... therefore I can cache models after they go through the optimization process. I already do this successfully with the TensorRT provider. A similar ability with DirectML is desired, and I attempted it with SetOptimizedModelFilePath() but ran into the same issue as this OP... the saved DirectML model is unusable.

diablodale avatar Sep 15 '22 00:09 diablodale

Unfortunately ORT doesn't have a general way to save a compiled node. The TensorRT EP does that via TensorRT's ability to save, but AFAIK that is the only place it's possible. For CPU and CUDA you could save the fully optimized model, as neither of those compiles nodes. The saved model would contain internal operators that are specific to the CPU/CUDA EPs, but that should be fine for local caching.

skottmckay avatar Sep 15 '22 04:09 skottmckay

@diablodale / @gedoensmax:

  • This pending change allows exporting/reimporting the optimized model (recently enabled after Sumit's refactoring): https://github.com/microsoft/onnxruntime/pull/13913.
  • It will be in ORT 1.14.
  • Note the caveat remains that the same .ort file cannot reliably be replayed on different execution providers, and that the same file may not be replayable on the same execution provider across different GPUs, due to different graph partitioning assignments made based on GPU data type support.

fdwr avatar Dec 09 '22 05:12 fdwr

Got it. I already have code in place to invalidate a persisted optimized model if any config changes.

A question: in #13913 I saw the comment "This transformer applies DML-specific fusions that go beyond what ORT offers by default." The following is some guessing... When we persist with SetOptimizedModelFilePath(), it will not fuse partitions of DML nodes into a single DML_GRAPH. It will instead persist a slightly less optimized model lacking that fusion. When this persisted model is loaded, will ORT do that final fusion optimization? Or is this the tradeoff for a faster load?

diablodale avatar Dec 09 '22 14:12 diablodale

When we persist with SetOptimizedModelFilePath(), it will not fuse partitions of DML nodes into a single DML_GRAPH.

@diablodale Correct, nodes will remain distinct operators (or fused operators).

It will instead persist a slightly less optimized model lacking that fusion.

Yes, it will have operator fusions (e.g. Conv + Relu -> ConvRelu), but not whole-graph-fusion.


When this persisted model is loaded, will ORT do that final fusion optimization?

Yes, that final whole-graph-fusion will be done upon reload.

Or is this the tradeoff for a faster load?

That final fusion happens in either case, whether loading the original model or loading the pre-operator-fused model. Exporting to a .onnx file and reloading, I noticed a session-load time saving of around 5-15% depending on the model, and run time is the same. Exporting to the .ort file format and reloading, I noticed a substantial time saving in session load, from 2-7x depending on the model. But as enticing as that is, beware that .ort support was only recently enabled by https://github.com/microsoft/onnxruntime/pull/13913, and I can't yet vouch for its robustness without further, more exhaustive testing (I've only tried it with a few models), because interaction with the DML EP might hit new code paths. Also, we should verify whether this issue still applies: https://github.com/microsoft/onnxruntime/issues/13535.

Got it. I already have code in place to invalidate a persisted optimized model if any config changes.

Great. I'd also include the driver version in your hash, just in case updating the driver changes registered data type support.
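For example (my own sketch rather than anything from this thread; adapter index 0 is a placeholder and error handling is omitted), the adapter identity and user-mode driver version can be queried through DXGI and folded into the cache key:

#include <dxgi.h>          // link against dxgi.lib
#include <wrl/client.h>
#include <sstream>
#include <string>

std::wstring BuildGpuCacheKeyPrefix()
{
    using Microsoft::WRL::ComPtr;

    ComPtr<IDXGIFactory1> factory;
    CreateDXGIFactory1(IID_PPV_ARGS(&factory));

    ComPtr<IDXGIAdapter1> adapter;
    factory->EnumAdapters1(0, &adapter);

    DXGI_ADAPTER_DESC1 desc = {};
    adapter->GetDesc1(&desc);

    // CheckInterfaceSupport reports the user-mode driver version for the adapter.
    LARGE_INTEGER umdVersion = {};
    adapter->CheckInterfaceSupport(__uuidof(IDXGIDevice), &umdVersion);

    // Combine vendor/device/revision and driver version; hash this together with
    // the model bytes and session settings to decide whether a cached model is stale.
    std::wostringstream key;
    key << desc.VendorId << L'-' << desc.DeviceId << L'-' << desc.Revision
        << L'-' << umdVersion.QuadPart;
    return key.str();
}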

fdwr avatar Jan 11 '23 01:01 fdwr

Closing as resolved.

nums11 avatar Jul 27 '23 17:07 nums11