kedro-plugins
[spike] Investigate support for TaskFlowAPI in kedro-airflow
Description
Since Airflow 2.0, the simpler TaskFlow API has been available as an alternative to the traditional Operator API for defining DAGs. At the moment kedro-airflow supports only the Operator API, but it's worth keeping an eye on TaskFlow:
https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html
Possible Alternatives
Apparently both approaches can coexist, so TaskFlow doesn't aim to replace the traditional operators: https://www.astronomer.io/blog/apache-airflow-taskflow-api-vs-traditional-operators/
Outcome: an example of what this API would look like, plus some investigation into the community's sentiment towards this new API (with respect to more modern Python operators, such as the virtualenv operator, name TBC).
I attempted to update the spaceflights-pandas DAG from its current form to the TaskFlow API, as seen in this commit: https://github.com/DimedS/kedro-taskFlowAPI/commit/98b69f575a80332bc67868f8d1796346a90860e9 It works, but the result is not very elegant and I believe we should redesign some aspects. With the TaskFlow API we no longer seem to need the KedroOperator, and we should also improve the way we describe the execution order.
TaskFlow API
The TaskFlow API was introduced in Apache Airflow 2.0. It's an alternative way to define your DAG files for orchestration by Airflow.
Basic Structure
Example TaskFlow DAG
import json

import pendulum

from airflow.decorators import dag, task


@dag(
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
)
def tutorial_taskflow_api():
    """
    ### TaskFlow API Tutorial Documentation
    This is a simple data pipeline example which demonstrates the use of
    the TaskFlow API using three simple tasks for Extract, Transform, and Load.
    Documentation that goes along with the Airflow TaskFlow API tutorial is
    located
    [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
    """

    @task()
    def extract():
        """
        #### Extract task
        A simple Extract task to get data ready for the rest of the data
        pipeline. In this case, getting data is simulated by reading from a
        hardcoded JSON string.
        """
        data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
        order_data_dict = json.loads(data_string)
        return order_data_dict

    @task(multiple_outputs=True)
    def transform(order_data_dict: dict):
        """
        #### Transform task
        A simple Transform task which takes in the collection of order data and
        computes the total order value.
        """
        total_order_value = 0
        for value in order_data_dict.values():
            total_order_value += value
        return {"total_order_value": total_order_value}

    @task()
    def load(total_order_value: float):
        """
        #### Load task
        A simple Load task which takes in the result of the Transform task and
        instead of saving it to end user review, just prints it out.
        """
        print(f"Total order value is: {total_order_value:.2f}")

    order_data = extract()
    order_summary = transform(order_data)
    load(order_summary["total_order_value"])


tutorial_taskflow_api()
Compared with the same code in the traditional DAG format:
Same Example in the traditional API
import json
from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.operators.python import PythonOperator


def extract(ti=None, **kwargs):
    """
    Pushes the estimated population (in millions) of
    various cities into XCom for the ETL pipeline.
    Obviously in reality this would be fetching this
    data from some source, not hardcoded values.
    """
    sample_data = {"Tokyo": 3.7, "Jakarta": 3.3, "Delhi": 2.9}
    ti.xcom_push("city_populations", json.dumps(sample_data))


def transform(ti=None, **kwargs):
    """
    Pulls the provided raw data from XCom and pushes
    the name of the largest city in the set to XCom.
    """
    raw_data = ti.xcom_pull(task_ids="extract", key="city_populations")
    data = json.loads(raw_data)
    largest_city = max(data, key=data.get)
    ti.xcom_push("largest_city", largest_city)


def load(ti=None, **kwargs):
    """
    Loads and prints the name of the largest city in
    the set as determined by the transform.
    """
    largest_city = ti.xcom_pull(task_ids="transform", key="largest_city")
    print(largest_city)


with DAG(
    dag_id="city_pop_etl_pythonoperator",
    schedule=None,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example"],
) as dag:
    extract_task = PythonOperator(
        task_id="extract",
        python_callable=extract,
    )
    transform_task = PythonOperator(
        task_id="transform",
        python_callable=transform,
    )
    load_task = PythonOperator(
        task_id="load",
        python_callable=load,
    )

    chain(
        extract_task,
        transform_task,
        load_task,
    )
- Python functions are decorated with the `@task` decorator
- The DAG function is decorated with the `@dag` decorator
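To make the decorator mechanics concrete, here is a minimal, Airflow-free sketch of the idea behind TaskFlow: calling a decorated function does not execute it but records a node, and passing one call's result into another call creates a dependency edge. All names here (`task`, `TaskCall`, `topo_order`) are invented for illustration; Airflow's real implementation is far more involved.

```python
# Toy model of TaskFlow-style composition (not Airflow code).

class TaskCall:
    def __init__(self, name, upstream):
        self.name = name          # task identifier
        self.upstream = upstream  # TaskCall objects this task depends on

def task(fn):
    def wrapper(*args):
        # Record a node; any TaskCall argument becomes an upstream edge.
        upstream = [a for a in args if isinstance(a, TaskCall)]
        return TaskCall(fn.__name__, upstream)
    return wrapper

def topo_order(final):
    """Depth-first post-order: each task appears after its dependencies."""
    order, seen = [], set()
    def visit(t):
        if t.name in seen:
            return
        seen.add(t.name)
        for up in t.upstream:
            visit(up)
        order.append(t.name)
    visit(final)
    return order

@task
def extract():
    ...  # bodies never run while the graph is being composed

@task
def transform(order_data):
    ...

@task
def load(total):
    ...

# Composing the calls builds the dependency graph instead of running anything:
graph = load(transform(extract()))
print(topo_order(graph))  # ['extract', 'transform', 'load']
```

The point is that the execution order falls out of the data flow, which is exactly what Airflow exploits when it schedules TaskFlow tasks.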
Kedro Spaceflights DAGs example
Consider the Spaceflights tutorial (only the data processing part, to keep the code snippets short):
Just Python code
If we wanted to perform the data processing steps on the `companies`, `shuttles`, and `reviews` data from scratch (without using Kedro), we could simply define Python functions as `@task`s and pass data between them:
Data processing DAG without Kedro
import logging

import pandas as pd

from airflow.decorators import dag, task


def _is_true(x: pd.Series) -> pd.Series:
    return x == "t"


def _parse_percentage(x: pd.Series) -> pd.Series:
    x = x.str.replace("%", "")
    x = x.astype(float) / 100
    return x


def _parse_money(x: pd.Series) -> pd.Series:
    x = x.str.replace("$", "").str.replace(",", "")
    x = x.astype(float)
    return x


@task
def preprocess_companies(companies: pd.DataFrame) -> pd.DataFrame:
    """Preprocesses the data for companies.

    Args:
        companies: Raw data.
    Returns:
        Preprocessed data, with `company_rating` converted to a float and
        `iata_approved` converted to boolean.
    """
    companies["iata_approved"] = _is_true(companies["iata_approved"])
    companies["company_rating"] = _parse_percentage(companies["company_rating"])
    logging.info("Preprocessing companies data")
    return companies


@task
def preprocess_shuttles(shuttles: pd.DataFrame) -> pd.DataFrame:
    """Preprocesses the data for shuttles.

    Args:
        shuttles: Raw data.
    Returns:
        Preprocessed data, with `price` converted to a float and `d_check_complete`,
        `moon_clearance_complete` converted to boolean.
    """
    shuttles["d_check_complete"] = _is_true(shuttles["d_check_complete"])
    shuttles["moon_clearance_complete"] = _is_true(shuttles["moon_clearance_complete"])
    shuttles["price"] = _parse_money(shuttles["price"])
    logging.info("Preprocessing shuttles data")
    return shuttles


@task
def create_model_input_table(
    shuttles: pd.DataFrame, companies: pd.DataFrame, reviews: pd.DataFrame
) -> pd.DataFrame:
    """Combines all data to create a model input table.

    Args:
        shuttles: Preprocessed data for shuttles.
        companies: Preprocessed data for companies.
        reviews: Raw data for reviews.
    Returns:
        Model input table.
    """
    logging.info("Preprocessing model input table data")
    rated_shuttles = shuttles.merge(reviews, left_on="id", right_on="shuttle_id")
    rated_shuttles = rated_shuttles.drop("id", axis=1)
    model_input_table = rated_shuttles.merge(
        companies, left_on="company_id", right_on="id"
    )
    model_input_table = model_input_table.dropna()
    return model_input_table


shuttle_data = pd.read_excel("shuttles.xlsx")
company_data = pd.read_csv("companies.csv")
reviews = pd.read_csv("reviews.csv")


@dag(dag_id="space_no_kedro")
def space_dag():
    companies = preprocess_companies(company_data)
    shuttles = preprocess_shuttles(shuttle_data)
    create_model_input_table(shuttles, companies, reviews)


space_dag()
Airflow can infer the order of these tasks from their inputs and outputs:

@dag(dag_id="space_no_kedro")
def space_dag():
    companies = preprocess_companies(company_data)
    shuttles = preprocess_shuttles(shuttle_data)
    create_model_input_table(shuttles, companies, reviews)
With Kedro, using the traditional API
Now, if we were already using Kedro, the steps to deploy the Spaceflights tutorial are currently as follows:
- Make sure your project doesn't have any `MemoryDataset`s: add explicit entries for the in-memory datasets, or add a default dataset factory pattern to the catalog
- Generate a DAG using kedro-airflow with `kedro airflow create`
- Move the generated DAG to `<AIRFLOW_HOME>/dags`
- Package and install the spaceflights project in the Airflow environment
- Point to the `conf_source` (in the DAG) or move/copy the `conf/` folder to the Airflow directory
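For the first step, a catch-all dataset factory pattern in the catalog is one way to ensure every otherwise-in-memory dataset is persisted. A sketch of such an entry in `conf/base/catalog.yml` (the dataset type and filepath layout here are assumptions; adjust to your project):

```yaml
# Catch-all pattern: any dataset not explicitly declared in the catalog
# is persisted as a pickle file instead of defaulting to MemoryDataset.
"{default_dataset}":
  type: pickle.PickleDataset
  filepath: data/{default_dataset}.pkl
```

Explicit entries always take precedence over the pattern, so this only affects datasets that would otherwise be `MemoryDataset`s.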
This is the DAG generated by `kedro airflow create`:
Traditional DAG generated by `kedro-airflow`
from __future__ import annotations

from datetime import datetime, timedelta
from pathlib import Path

from airflow import DAG
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

from kedro.framework.project import configure_project
from kedro.framework.session import KedroSession


class KedroOperator(BaseOperator):
    @apply_defaults
    def __init__(
        self,
        package_name: str,
        pipeline_name: str,
        node_name: str | list[str],
        project_path: str | Path,
        env: str,
        *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.package_name = package_name
        self.pipeline_name = pipeline_name
        self.node_name = node_name
        self.project_path = project_path
        self.env = env

    def execute(self, context):
        configure_project(self.package_name)
        with KedroSession.create(self.project_path, env=self.env) as session:
            if isinstance(self.node_name, str):
                self.node_name = [self.node_name]
            session.run(self.pipeline_name, node_names=self.node_name)


# Kedro settings required to run your pipeline
env = "local"
pipeline_name = "__default__"
project_path = Path.cwd()
package_name = "space"

# Using a DAG context manager, you don't have to specify the dag property of each task
with DAG(
    dag_id="space-project",
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
    schedule_interval="@once",
    catchup=False,
    # Default settings applied to all tasks
    default_args=dict(
        owner="airflow",
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
) as dag:
    tasks = {
        "preprocess-companies-node": KedroOperator(
            task_id="preprocess-companies-node",
            package_name=package_name,
            pipeline_name=pipeline_name,
            node_name="preprocess_companies_node",
            project_path=project_path,
            env=env,
        ),
        "preprocess-shuttles-node": KedroOperator(
            task_id="preprocess-shuttles-node",
            package_name=package_name,
            pipeline_name=pipeline_name,
            node_name="preprocess_shuttles_node",
            project_path=project_path,
            env=env,
        ),
        "create-model-input-table-node": KedroOperator(
            task_id="create-model-input-table-node",
            package_name=package_name,
            pipeline_name=pipeline_name,
            node_name="create_model_input_table_node",
            project_path=project_path,
            env=env,
        ),
    }

    tasks["preprocess-companies-node"] >> tasks["create-model-input-table-node"]
    tasks["preprocess-shuttles-node"] >> tasks["create-model-input-table-node"]
To run Kedro nodes on Airflow, we define a `KedroOperator`, a subclass of `BaseOperator`. It creates a `KedroSession` and runs specific nodes on each call:
class KedroOperator(BaseOperator):
    @apply_defaults
    def __init__(
        self,
        package_name: str,
        pipeline_name: str,
        node_name: str | list[str],
        project_path: str | Path,
        env: str,
        *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.package_name = package_name
        self.pipeline_name = pipeline_name
        self.node_name = node_name
        self.project_path = project_path
        self.env = env

    def execute(self, context):
        configure_project(self.package_name)
        with KedroSession.create(self.project_path, env=self.env) as session:
            if isinstance(self.node_name, str):
                self.node_name = [self.node_name]
            session.run(self.pipeline_name, node_names=self.node_name)
We then define `tasks` as a dict mapping each node name to its `KedroOperator` object:
tasks = {
    "preprocess-companies-node": KedroOperator(
        ...
    ),
    "preprocess-shuttles-node": KedroOperator(
        ...
    ),
    "create-model-input-table-node": KedroOperator(
        ...
    ),
}
And finally, we define the order of execution of these tasks:

tasks["preprocess-companies-node"] >> tasks["create-model-input-table-node"]
tasks["preprocess-shuttles-node"] >> tasks["create-model-input-table-node"]
A few things to note:
- We assume that the intermediate datasets are saved to and loaded from file, because each `KedroSession` runs independently
- We need to define the task order explicitly
With Kedro using the TaskFlow API
I tried to convert the traditional spaceflights DAG to the TaskFlow API:
TaskFlow format DAG
from datetime import datetime, timedelta
from pathlib import Path

from airflow.decorators import dag, task

from kedro.framework.project import configure_project
from kedro.framework.session import KedroSession


def run_kedro_node(package_name, pipeline_name, node_name, project_path, env):
    configure_project(package_name)
    with KedroSession.create(project_path, env=env) as session:
        return session.run(pipeline_name, node_names=[node_name])


# Kedro settings required to run your pipeline
env = "local"
pipeline_name = "__default__"
project_path = Path.cwd()
package_name = "space"


@dag(
    dag_id="space_tf",
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
    schedule_interval="@once",
    catchup=False,
    # Default settings applied to all tasks
    default_args=dict(
        owner="airflow",
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
)
def kedro_dag():
    @task(task_id="preprocess-companies-node")
    def preprocess_companies_node():
        run_kedro_node(package_name, pipeline_name, "preprocess_companies_node", project_path, env)

    @task(task_id="preprocess-shuttles-node")
    def preprocess_shuttles_node():
        run_kedro_node(package_name, pipeline_name, "preprocess_shuttles_node", project_path, env)

    @task(task_id="create-model-input-table-node")
    def create_model_input_table_node():
        run_kedro_node(package_name, pipeline_name, "create_model_input_table_node", project_path, env)

    # Call each task once and reuse the returned task object; calling a
    # decorated task twice would register two separate task instances.
    mit = create_model_input_table_node()
    preprocess_shuttles_node() >> mit
    preprocess_companies_node() >> mit


kedro_dag()
Key Differences + Similarities

- Get rid of `KedroOperator` in favour of a function that does the same thing as `KedroOperator.execute()`:

    def run_kedro_node(package_name, pipeline_name, node_name, project_path, env):
        configure_project(package_name)
        with KedroSession.create(project_path, env=env) as session:
            return session.run(pipeline_name, node_names=[node_name])

- Each task is a function which calls `run_kedro_node` with the specific node name:

    @task(task_id="preprocess-companies-node")
    def preprocess_companies_node():
        run_kedro_node(package_name, pipeline_name, "preprocess_companies_node", project_path, env)

- We still have to define a task order, because we don't pass `MemoryDataset`s between tasks:

    mit = create_model_input_table_node()
    preprocess_shuttles_node() >> mit
    preprocess_companies_node() >> mit

- There is a call to the DAG function at the end:

    kedro_dag()
Benefits of TaskFlow DAGs
These are the benefits that the new TaskFlow API offers:
Reduce boilerplate code
The TaskFlow API greatly benefits you if you've been using the `PythonOperator` in your code: you can just decorate your Python functions with `@task`. In the case of Kedro projects, it can still reduce some boilerplate:
- We don't need a `KedroOperator` just to create and run a session; it can be a simple Python function
- The lines of code go down from 119 to 67 for the spaceflights tutorial, for example
- Not much of a difference though: each task still needs a function, and we still need to define the task order
Intuitive data transfer between tasks
If you are working with Python functions as tasks, the TaskFlow API makes it very easy to pass data between them; with traditional DAGs this had to be done through `xcom_push` and `xcom_pull`.
This is not directly relevant to the way Kedro works right now, since we expect users not to have `MemoryDataset`s in their project: they can add explicit entries to their catalog or add a default catch-all dataset factory pattern.
NOTE (out of scope): It is possible to pass datasets between different tasks, but it isn't very easy to inject data into a `KedroSession` (related ticket: https://github.com/kedro-org/kedro/issues/2169). See the comment below 👇🏾
No need to define task order
This again is not relevant to the way Kedro works right now, because the TaskFlow API infers the order of the tasks from the inputs and outputs passed between them. For Kedro, we create an individual `KedroSession` for each node and no data passes between them, so we still need to define the task order.
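Even though Airflow cannot infer the order here, kedro-airflow can derive it from the pipeline structure at template-rendering time, since every Kedro node declares its inputs and outputs. A rough, self-contained sketch of that derivation (plain dicts stand in for Kedro `Node` objects; `derive_edges` and the field names are invented for illustration):

```python
# Derive Airflow ">>" edges from a Kedro-style pipeline description.
# Node B depends on node A whenever B consumes a dataset A produces.

def derive_edges(nodes):
    producers = {}
    for node in nodes:
        for out in node["outputs"]:
            producers[out] = node["name"]
    edges = set()
    for node in nodes:
        for inp in node["inputs"]:
            if inp in producers:
                edges.add((producers[inp], node["name"]))
    return sorted(edges)

pipeline = [
    {"name": "preprocess_companies_node",
     "inputs": ["companies"], "outputs": ["preprocessed_companies"]},
    {"name": "preprocess_shuttles_node",
     "inputs": ["shuttles"], "outputs": ["preprocessed_shuttles"]},
    {"name": "create_model_input_table_node",
     "inputs": ["preprocessed_shuttles", "preprocessed_companies", "reviews"],
     "outputs": ["model_input_table"]},
]

# Emit the dependency lines a DAG template could render:
for upstream, downstream in derive_edges(pipeline):
    print(f'tasks["{upstream}"] >> tasks["{downstream}"]')
```

This is why the generated DAG can hard-code the `>>` lines even though no data flows between the Airflow tasks themselves.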
Conclusion
- The benefits of the TaskFlow API are not very relevant to the way we define Kedro DAGs, though it does reduce the boilerplate code a little.
- It is possible to use the traditional and the TaskFlow APIs together.
- It is possible to leverage the TaskFlow API to pass `MemoryDataset`s between tasks, with some modifications to the `KedroSession`. See the comment following this for further discussion.
Decisions
- Should we offer this at all? The TaskFlow API is not replacing traditional DAGs; it's just an alternative way to define them.
- If yes:
  - Fully? Replace the current DAG Jinja template with the TaskFlow one? Yes/No
  - Partially? Modify the current DAG Jinja template to incorporate some elements of the TaskFlow API (like the `@dag` decorator)
  - Offer it as an option?
    - If an option, what would the design be? e.g. `--api=<traditional/taskflow>`
Extra: Leveraging the TaskFlow API
One of the benefits of the TaskFlow API is that it makes it simpler to pass data between tasks without explicit `xcom_push`/`xcom_pull` calls.
Currently we recommend that users make sure their projects do not rely on `MemoryDataset`s: either put in explicit entries for the in-memory datasets or a catch-all pattern in their catalog files.
The ways to get around this are:
- Grouping nodes that share `MemoryDataset`s into one task. This is now possible with #241. It also has the additional benefit of reducing the overhead of mapping each node to a task on distributed systems. We can also expand the grouping strategies to address the pain point of mapping mentioned in https://github.com/kedro-org/kedro/issues/3094
- Leveraging the TaskFlow API (with some changes on the Kedro side) to make it possible to pass intermediate datasets between tasks without the overhead of saving and loading them from file.
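The first option above, grouping nodes connected by in-memory datasets, can be sketched as a connected-components computation: treat any dataset missing from the catalog as a `MemoryDataset` and merge its producer and consumer into one task group. This is a simplified illustration with invented names and plain dicts instead of Kedro objects, not the #241 implementation:

```python
# Group nodes connected by in-memory (uncatalogued) datasets, so each
# group can run as a single Airflow task inside one KedroSession.

def group_nodes(nodes, catalogued):
    parent = {n["name"]: n["name"] for n in nodes}

    def find(x):
        # Union-find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        parent[find(a)] = find(b)

    producers = {}
    for node in nodes:
        for out in node["outputs"]:
            producers[out] = node["name"]
    for node in nodes:
        for inp in node["inputs"]:
            # A dataset missing from the catalog defaults to MemoryDataset,
            # so its producer and consumer must share a task (and session).
            if inp in producers and inp not in catalogued:
                union(node["name"], producers[inp])

    groups = {}
    for node in nodes:
        groups.setdefault(find(node["name"]), []).append(node["name"])
    return sorted(sorted(g) for g in groups.values())

nodes = [
    {"name": "preprocess_companies_node", "inputs": ["companies"],
     "outputs": ["preprocessed_companies"]},
    {"name": "preprocess_shuttles_node", "inputs": ["shuttles"],
     "outputs": ["preprocessed_shuttles"]},
    {"name": "create_model_input_table_node",
     "inputs": ["preprocessed_shuttles", "preprocessed_companies"],
     "outputs": ["model_input_table"]},
]
# "preprocessed_shuttles" is not catalogued, so its producer and consumer
# are merged into one group; "preprocessed_companies" is persisted.
catalogued = {"companies", "shuttles", "preprocessed_companies", "model_input_table"}
print(group_nodes(nodes, catalogued))
```

Each resulting group maps to one task, which removes the need to persist the datasets that stay inside a group.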
Link to the DAGs - https://github.com/ankatiyar/kedro-airflow-dags/blob/main/space_dag_traditional_memoryds.py & https://github.com/ankatiyar/kedro-airflow-dags/blob/main/space_dag_tf_memoryds.py
DAG that allows for `MemoryDataset` passing
from __future__ import annotations

from datetime import datetime, timedelta
from pathlib import Path

from airflow.decorators import dag, task

from kedro.framework.project import configure_project, pipelines
from kedro.framework.session import KedroSession
from kedro.io import MemoryDataset
from kedro.runner import SequentialRunner


def run_kedro_node(package_name, pipeline_name, node_name, project_path, env, *args):
    configure_project(package_name)
    with KedroSession.create(project_path, env) as session:
        # Collect the dataset dicts passed in from upstream tasks
        datasets = {}
        for x in args:
            datasets.update(x)
        if isinstance(node_name, str):
            node_name = [node_name]
        context = session.load_context()
        catalog = context.catalog
        for j, ds in datasets.items():
            catalog.add(j, MemoryDataset(ds))
        runner = SequentialRunner()
        pipeline = pipelines[pipeline_name].filter(node_names=node_name)
        output = runner.run(pipeline, catalog)
        return output


# Kedro settings required to run your pipeline
env = "local"
pipeline_name = "__default__"
project_path = Path.cwd()
package_name = "space"


@dag(
    dag_id="space-memory",
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
    schedule_interval="@once",
    catchup=False,
    # Default settings applied to all tasks
    default_args=dict(
        owner="airflow",
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
    ),
)
def kedro_dag():
    @task(task_id="preprocess-companies-node")
    def preprocess_companies_node(*args):
        return run_kedro_node(package_name, pipeline_name, "preprocess_companies_node", project_path, env, *args)

    @task(task_id="preprocess-shuttles-node")
    def preprocess_shuttles_node(*args):
        return run_kedro_node(package_name, pipeline_name, "preprocess_shuttles_node", project_path, env, *args)

    @task(task_id="create-model-input-table-node")
    def create_model_input_table_node(*args):
        return run_kedro_node(package_name, pipeline_name, "create_model_input_table_node", project_path, env, *args)

    @task(task_id="split-data-node")
    def split_data_node(*args):
        return run_kedro_node(package_name, pipeline_name, "split_data_node", project_path, env, *args)

    @task(task_id="train-model-node")
    def train_model_node(*args):
        return run_kedro_node(package_name, pipeline_name, "train_model_node", project_path, env, *args)

    @task(task_id="evaluate-model-node")
    def evaluate_model_node(*args):
        return run_kedro_node(package_name, pipeline_name, "evaluate_model_node", project_path, env, *args)

    ds1 = preprocess_companies_node()
    ds2 = preprocess_shuttles_node()
    mit = create_model_input_table_node(ds1, ds2)
    x = split_data_node(mit)
    y = train_model_node(x)
    evaluate_model_node(x, y)


kedro_dag()
Key Differences
The run_kedro_node() function
def run_kedro_node(package_name, pipeline_name, node_name, project_path, env, *args):
    configure_project(package_name)
    with KedroSession.create(project_path, env) as session:
        datasets = {}
        for x in args:
            datasets.update(x)
        if isinstance(node_name, str):
            node_name = [node_name]
        context = session.load_context()
        catalog = context.catalog
        for j, ds in datasets.items():
            catalog.add(j, MemoryDataset(ds))
        runner = SequentialRunner()
        pipeline = pipelines[pipeline_name].filter(node_names=node_name)
        output = runner.run(pipeline, catalog)
        return output
The `run_kedro_node()` function has been modified to:
- Accept additional `*args`, which are the input memory datasets
- Create a session
- Within the session, load the context and the catalog
- Update the catalog with the received datasets
- Instantiate a runner
- Use `runner.run()` to run the node

This means that the user-defined hooks don't run; we essentially rewrite what happens in `session.run()`. It is not straightforward to inject data into a `KedroSession`.
POTENTIAL SOLUTION: https://github.com/kedro-org/kedro/issues/2169 would allow us to simply write:

def run_kedro_node(package_name, pipeline_name, node_name, project_path, env, *args):
    configure_project(package_name)
    datasets = {}
    for x in args:
        datasets.update(x)
    with KedroSession.create(project_path, env, datasets=datasets) as session:
        return session.run()
The tasks

@task(task_id="preprocess-companies-node")
def preprocess_companies_node(*args):
    return run_kedro_node(package_name, pipeline_name, "preprocess_companies_node", project_path, env, *args)

The tasks now take datasets as arguments and pass them on to `run_kedro_node()`, which injects them into the session.
Task order definition
ds1 = preprocess_companies_node()
ds2 = preprocess_shuttles_node()
mit = create_model_input_table_node(ds1, ds2)
x = split_data_node(mit)
y = train_model_node(x)
evaluate_model_node(x, y)
This follows the TaskFlow API more closely: the order of node execution is inferred from the inputs and outputs of the tasks.
Fantastic! This is a very detailed and high-quality analysis, thank you @ankatiyar! I believe this topic would be excellent for discussion in our Technical Design session. In my opinion:
- For the first step, we should partially implement the TaskFlow API: remove the KedroOperator and use the @task and @dag decorators, while maintaining the task order from the traditional API. This approach, as I understand it, significantly reduces the size of the generated DAG and should be relatively easy to implement.
- In the second step, we could consider supporting passing MemoryDatasets between tasks using the TaskFlow API.
- I think we should retain the task order from the traditional API for tasks that do not involve Memory Datasets
Based on the discussions, this is not worth pursuing with the current state of things. Closing this.
Upon further discussion, we will adopt certain aspects of the TaskFlow API, like the `@dag` decorator, without transitioning to it fully: #705