ONNX Model Inference Operator
Description
ONNX (Open Neural Network Exchange) provides cross-platform compatibility for machine learning models.
An operator that runs inference using ONNX models, which is ideal for deploying models in a standardized format, would give us direct model invocation from a DAG.
This can of course be solved with a PythonOperator, since onnxruntime can be called from any Python callable, but it could also be built into Airflow to minimize boilerplate. A simple ONNX operator structure would be something like:
import numpy as np
import onnxruntime as ort
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime


def run_onnx_inference():
    # Load the ONNX model
    model_path = '/path/to/your/model.onnx'
    session = ort.InferenceSession(model_path)

    # Prepare input data (onnxruntime expects numpy arrays keyed by input name)
    input_name = session.get_inputs()[0].name
    input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)

    # Run inference
    result = session.run(None, {input_name: input_data})
    print(result)


# Define the DAG
with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',
) as dag:
    # Define the task
    inference_task = PythonOperator(
        task_id='onnx_inference_task',
        python_callable=run_onnx_inference,
    )
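As a side note, before wiring up the input data it helps to inspect what the model actually expects. A minimal sketch using onnxruntime's session metadata (the model path is a placeholder):

import onnxruntime as ort

session = ort.InferenceSession('/path/to/your/model.onnx')
for inp in session.get_inputs():
    # Each model input exposes its name, shape, and element type,
    # which tells you how to build the feed dictionary for session.run()
    print(inp.name, inp.shape, inp.type)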
Looking forward to any suggestions.
Use case/motivation
Direct support for ONNX, combined with Airflow's DAG-based orchestration, would let users manage the entire lifecycle of data processing and model inference in one place, providing a more cohesive and manageable workflow.
Related issues
No response
Are you willing to submit a PR?
- [ ] Yes I am willing to submit a PR!
Code of Conduct
- [X] I agree to follow this project's Code of Conduct
Thanks for opening your first issue here! Be sure to follow the issue template! If you are willing to raise PR to address this issue please do so, no need to wait for approval.
Hi, I have worked on this issue for the past two days and came up with a solution. I made some changes to the existing code and added an execute method inside the operator class that does the same work as your run_onnx_inference(). Please take a look at the code and tell me whether it matches your expectations.
import numpy as np
import onnxruntime as ort
from airflow.models import BaseOperator
from airflow import DAG
from datetime import datetime


class ONNXInferenceOperator(BaseOperator):
    def __init__(self, model_path: str, input_data: dict, **kwargs):
        super().__init__(**kwargs)
        self.model_path = model_path
        self.input_data = input_data

    def execute(self, context):
        # input_data must map the model's input names to numpy arrays
        # of the dtype the model expects
        session = ort.InferenceSession(self.model_path)
        result = session.run(None, self.input_data)
        self.log.info("Inference result: %s", result)
        # Convert to plain lists so the result can be serialized to XCom
        return [r.tolist() for r in result]


with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',
    catchup=False,
) as dag:
    inference_task = ONNXInferenceOperator(
        task_id='onnx_inference_task',
        model_path='/path/to/your/model.onnx',
        # The key must match the model's actual input name
        input_data={"your_input_key": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)},
    )
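As a quick sanity check, the operator logic can also be exercised outside the scheduler by calling execute() directly, which is a common pattern in operator unit tests. A minimal sketch, assuming a valid model file exists at the given path and that "your_input_key" is replaced by the model's real input name:

import numpy as np

op = ONNXInferenceOperator(
    task_id='onnx_inference_smoke_test',
    model_path='/path/to/your/model.onnx',
    input_data={"your_input_key": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)},
)
# execute() does not use the task context here, so an empty dict is enough
print(op.execute(context={}))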
Hi, I was expecting a reply from you; whenever you see this, do let me know. Thank you.
Hi @Rohanberiwal @Faakhir30, Airflow doesn't have an ONNX provider, so if you'd like to add one to Airflow you need to follow the protocol for adding a new provider. Most providers are managed by the community rather than by Airflow.
Yes sir, I will read that protocol and get back with a solution as soon as possible. Thank you for your reply.
ONNX Inference Operator for Apache Airflow
Description
The ONNXInferenceOperator is a custom operator designed for running inference using ONNX models within an Apache Airflow DAG. This operator leverages the onnxruntime library to load an ONNX model and perform inference on provided input data. The results of the inference are logged and returned.
Components
- ONNXInferenceOperator: A custom Airflow operator that initializes with the path to the ONNX model and the input data. It performs inference in the execute method and logs the results.
- run_onnx_inference: A helper function that demonstrates how to run inference using the onnxruntime library directly within a PythonOperator. This function is provided as an alternative approach to using the custom operator.
- DAG Definition: Defines an Airflow DAG named onnx_inference_dag that schedules the inference task to run once.
Code
import numpy as np
import onnxruntime as ort
from airflow import DAG
from airflow.models import BaseOperator
from airflow.operators.python import PythonOperator
from datetime import datetime


class ONNXInferenceOperator(BaseOperator):
    def __init__(self, model_path: str, input_data: dict, **kwargs):
        super().__init__(**kwargs)
        self.model_path = model_path
        self.input_data = input_data

    def execute(self, context):
        # input_data must map the model's input names to numpy arrays
        session = ort.InferenceSession(self.model_path)
        result = session.run(None, self.input_data)
        self.log.info("Inference result: %s", result)
        # Convert to plain lists so the result can be serialized to XCom
        return [r.tolist() for r in result]


def run_onnx_inference():
    # Alternative approach: run inference directly inside a PythonOperator
    model_path = '/path/to/your/model.onnx'
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
    result = session.run(None, {input_name: input_data})
    print(result)


with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',
    catchup=False,
) as dag:
    inference_task = ONNXInferenceOperator(
        task_id='onnx_inference_task',
        model_path='/path/to/your/model.onnx',
        # The key must match the model's actual input name
        input_data={"your_input_key": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)},
    )
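Since execute() returns the inference result, it is pushed to XCom by default and can be consumed downstream. A hedged sketch of a follow-up task (these lines would go inside the same with DAG(...) block; consume_predictions and postprocess_predictions are illustrative names, not part of the proposal):

def consume_predictions(ti):
    # Pull the result that the ONNX task returned; Airflow stores it under
    # the default XCom key "return_value"
    predictions = ti.xcom_pull(task_ids='onnx_inference_task')
    print(f"Received predictions: {predictions}")

postprocess_task = PythonOperator(
    task_id='postprocess_predictions',
    python_callable=consume_predictions,
)

inference_task >> postprocess_task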
Sir, I would like to know more about the official process of getting accepted to work on Airflow. I have made the solution and read the protocol, but where should I raise a vote so I can have a conversation with the community and they can accept it?
Should I add the label for a new use case to the above proposed solution?
It's all explained there - including (as of recently) links to examples where others attempted to propose their providers: https://github.com/apache/airflow/blob/main/PROVIDERS.rst#accepting-new-community-providers
Note that it's rather unll