optimum
Exporting tinyllama-1.1b using onnxruntime bf16 crashes
System Info
System information:
Container is Debian12 (mambaorg/micromamba)
Host is RHEL9 / ppc64le
$ cat /etc/os-release
PRETTY_NAME="Debian GNU/Linux 12 (bookworm)"
NAME="Debian GNU/Linux"
VERSION_ID="12"
VERSION="12 (bookworm)"
VERSION_CODENAME=bookworm
ID=debian
HOME_URL="https://www.debian.org/"
SUPPORT_URL="https://www.debian.org/support"
BUG_REPORT_URL="https://bugs.debian.org/"
$ uname -a
Linux b8e04f1032bc 5.14.0-362.18.1.el9_3.ppc64le #1 SMP Mon Jan 29 03:48:20 PST 2024 ppc64le GNU/Linux
Python, Optimum & PyTorch version:
$ python3 -V
Python 3.10.9
$ pip3 list installed | grep optimum
optimum 1.18.1
$ micromamba list pytorch
List of packages in environment: "/opt/conda"
Name Version Build Channel
─────────────────────────────────────────────────────────────
_pytorch_select 1.0 cpu_2 rocketce
pytorch-base 2.0.1 cpu_py310_pb4.21.12_1 rocketce
pytorch-cpu 2.0.1 py310_1 rocketce
$ micromamba list onnx
List of packages in environment: "/opt/conda"
Name Version Build Channel
──────────────────────────────────────────────────────────────────
onnx 1.13.1 h25d5be3_py310_pb4.21.12_1 rocketce
onnxruntime 1.15.1 hd867603_cpu_py310_pb4.21.12_1 rocketce
Who can help?
@michaelbenayoun @JingyaHuang @echarlaix
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction (minimal, reproducible, runnable)
Converting to fp32 works without issues. fp16 is not possible since I'm on a CPU-only system, and bf16 throws the following error:
$ optimum-cli export onnx --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 tinyllama-1b_onnx/ --dtype bf16
[...]
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
================ Diagnostic Run torch.onnx.export version 2.0.1 ================
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Saving external data to one file...
Traceback (most recent call last):
File "/opt/conda/bin/optimum-cli", line 8, in <module>
sys.exit(main())
File "/opt/conda/lib/python3.10/site-packages/optimum/commands/optimum_cli.py", line 163, in main
service.run()
File "/opt/conda/lib/python3.10/site-packages/optimum/commands/export/onnx.py", line 261, in run
main_export(
File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/__main__.py", line 351, in main_export
onnx_export_from_model(
File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 1157, in onnx_export_from_model
_, onnx_outputs = export_models(
File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 768, in export_models
export(
File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/convert.py", line 902, in export
config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
File "/opt/conda/lib/python3.10/site-packages/optimum/exporters/onnx/base.py", line 306, in fix_dynamic_axes
session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)
File "/opt/conda/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 383, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/opt/conda/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 424, in _create_inference_session
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from tinyllama-1b_onnx/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(bfloat16)' of input parameter (/model/Constant_34_output_0) of operator (Where) in node (/model/Where_3) is invalid.
Expected behavior
Convert the model properly to bf16
Thank you @mgiessing! It is possible that the ONNX model is valid but ORT is missing some operators for bf16. It could also be a bug; I will have a look shortly.
Thank you for having a look! This also happened on my Mac M1 with a more recent ORT version (v1.17.1), and with a different model (deepset/roberta-base-squad2).
@mgiessing The Where operator (used in https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/models/llama/modeling_llama.py#L1086) is not implemented for the bf16 dtype in ORT: https://github.com/microsoft/onnxruntime/blob/v1.17.1/docs/OperatorKernels.md
However, it is valid in the ONNX standard: https://github.com/onnx/onnx/blob/main/docs/Operators.md#where
I suggest opening a feature request in the ONNX Runtime repository to add bf16 support for Where. In the meantime, we could patch the Transformers code so that the export works in bf16 (by avoiding the Where op in bf16).
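One possible shape for such a patch, sketched here as an assumption rather than the actual change in Transformers: route the select through float32 and cast back, so the traced graph should contain Cast → Where(float32) → Cast instead of a bf16 Where. `where_fp32` is a hypothetical helper; it assumes both branches are tensors, and it assumes Cast to and from bf16 is implemented by the ORT build in use (worth checking in the same OperatorKernels.md table).

```python
import torch


def where_fp32(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: torch.where computed in float32, cast back to the inputs' dtype.

    Intended for bf16/fp16 inputs (float64 inputs would lose precision here).
    Both branches must be tensors in this sketch.
    """
    out_dtype = torch.promote_types(x.dtype, y.dtype)
    # Every bfloat16 value is exactly representable in float32, so nothing is lost.
    selected = torch.where(condition, x.to(torch.float32), y.to(torch.float32))
    return selected.to(out_dtype)


if __name__ == "__main__":
    cond = torch.tensor([True, False, True])
    x = torch.tensor([1.5, 2.0, -3.0], dtype=torch.bfloat16)
    y = torch.full((3,), torch.finfo(torch.bfloat16).min, dtype=torch.bfloat16)
    print(where_fp32(cond, x, y))
```

Because the bf16 → float32 → bf16 round trip is exact, the result matches a plain bf16 `torch.where`; the cost is a temporary float32 buffer.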
See also https://github.com/huggingface/optimum/issues/1720#issuecomment-1963838333, a related issue that you are likely to hit as well.
If you install optimum from source, a warning about this is displayed:
Exporting the model LlamaForCausalLM in bfloat16 float dtype. After the export, ONNX Runtime InferenceSession with CPU/CUDA execution provider likely does not implement all operators for the bfloat16 data type, and the loading is likely to fail.
Thanks for looking into this :) I'll try to open a feature request next week to address the issue!