Add `--from-env` CLI flag [PoC implementation available]
Description
Kedro's environment mechanism allows environment-specific configuration to be extracted into different envs. We're using this to make a distinction between local runs (plugging into local systems, and sometimes swapping systems for local datasets to abstract away complexity) and cloud runs.
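For reference, this assumes the standard Kedro layout where each env is simply a folder under `conf/` with its own catalog, roughly:

conf/
├── base/
│   └── catalog.yml
├── local/    # local systems / local or synthetic datasets
│   └── catalog.yml
└── cloud/    # cloud-native storage systems
    └── catalog.yml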
Example

Imagine the following pipeline:

- `node_a`: (`dataset_1`) -> (`dataset_2`)
- `node_b`: (`dataset_2`) -> (`dataset_3`)
Where `node_a` is a very expensive model training step that takes 5+ hours and should run on specialized hardware in the cloud, while `node_b` is some report building.

The pipeline has two environments set up, `local` and `cloud`. The former is a run on the local computer, whereas `cloud` is configured to load/store data in cloud-native storage systems. Given the duration of the pipeline, our `local` environment operates on synthetically generated data. We use local `SparkDataset`s for the `local` run, and `BigQueryTable` datasets for the `cloud` run, for both `node_a` and `node_b`.
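For concreteness, a minimal sketch of what such a pipeline definition could look like (the node functions `train_model` and `build_report` are hypothetical stand-ins):

from kedro.pipeline import node, pipeline


def train_model(raw_data):       # stand-in for the 5+ hour training step
    ...


def build_report(model_output):  # stand-in for the report-building step
    ...


def create_pipeline(**kwargs):
    return pipeline(
        [
            node(train_model, inputs="dataset_1", outputs="dataset_2", name="node_a"),
            node(build_report, inputs="dataset_2", outputs="dataset_3", name="node_b"),
        ]
    )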
Now imagine trying to implement a new feature in `node_b`. Our `local` env allows for rapid iteration and prototyping, with the localized `SparkDataset`s and synthetic data. Occasionally, however, we'd like to execute new code on the production data that we have in the `cloud` environment. Kedro currently falls short here, because the only options I have are:

- Running in the `cloud` environment, which will overwrite the contents of `dataset_3` in the cloud
- Doing some ad-hoc commenting to ensure `dataset_3` is not materialized in the cloud

(Note that the problem gets worse as the size of the pipeline increases.)
Alternatively, we could introduce a `--from-env` flag, which takes the pipeline run and overrides all input catalog entries of the pipeline with those of the given env. Specifically:

kedro run --env local --from-env cloud

This run will take all datasets from BigQuery as produced by the `cloud` run of `node_a`, and materialize the outputs locally so the `cloud` environment stays untouched.
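At its core, all the flag needs to do is swap the dataset definitions for the free inputs of the (filtered) pipeline. A minimal sketch of the idea, using the same private `_get_dataset` call the PoC below relies on (the helper name `override_free_inputs` is purely illustrative):

from kedro.io import DataCatalog
from kedro.pipeline import Pipeline


def override_free_inputs(pipeline: Pipeline, catalog: DataCatalog, from_catalog: DataCatalog) -> None:
    """Point every free input of `pipeline` at the dataset defined in `from_catalog`."""
    for dataset_name in pipeline.inputs():
        catalog.add(dataset_name, from_catalog._get_dataset(dataset_name), replace=True)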
Proposal

For debugging purposes, it would be nice to have a `--from-env` option that allows overriding input datasets with those from another env, as in the PoC below (this was also tested with dynamic pipelines):
# session.py
"""Custom kedro session."""
from typing import Any, Iterable, Dict
from kedro import __version__ as kedro_version
from kedro.framework.session import KedroSession
from kedro.framework.session.session import KedroSessionError
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.framework.project import pipelines
from kedro.io import DataCatalog
class KedroSessionWithFromCatalog(KedroSession):
"""Custom Kedro Session.
Custom Kedro Session that allows an additional `from-catalog` to
be specified. The from catalog overrides the catalog entry for all input
datasets.
NOTE: This module has some code duplication due to Kedros' complex
config setup. We should cleanup based on:
https://github.com/kedro-org/kedro/issues/4155
"""
def run( # noqa: PLR0913
self,
from_catalog: DataCatalog,
pipeline_name: str | None = None,
tags: Iterable[str] | None = None,
runner: AbstractRunner | None = None,
node_names: Iterable[str] | None = None,
from_nodes: Iterable[str] | None = None,
to_nodes: Iterable[str] | None = None,
from_inputs: Iterable[str] | None = None,
to_outputs: Iterable[str] | None = None,
load_versions: dict[str, str] | None = None,
namespace: str | None = None,
) -> dict[str, Any]:
"""Runs the pipeline with a specified runner.
Args:
from_catalog: From catalog to use, if set will override input datasets.
from_params: From params to set, will override params.
pipeline_name: Name of the pipeline that is being run.
tags: An optional list of node tags which should be used to
filter the nodes of the ``Pipeline``. If specified, only the nodes
containing *any* of these tags will be run.
runner: An optional parameter specifying the runner that you want to run
the pipeline with.
node_names: An optional list of node names which should be used to
filter the nodes of the ``Pipeline``. If specified, only the nodes
with these names will be run.
from_nodes: An optional list of node names which should be used as a
starting point of the new ``Pipeline``.
to_nodes: An optional list of node names which should be used as an
end point of the new ``Pipeline``.
from_inputs: An optional list of input datasets which should be
used as a starting point of the new ``Pipeline``.
to_outputs: An optional list of output datasets which should be
used as an end point of the new ``Pipeline``.
load_versions: An optional flag to specify a particular dataset
version timestamp to load.
namespace: The namespace of the nodes that is being run.
Raises:
ValueError: If the named or `__default__` pipeline is not
defined by `register_pipelines`.
Exception: Any uncaught exception during the run will be re-raised
after being passed to ``on_pipeline_error`` hook.
KedroSessionError: If more than one run is attempted to be executed during
a single session.
Returns:
Any node outputs that cannot be processed by the ``DataCatalog``.
These are returned in a dictionary, where the keys are defined
by the node outputs.
"""
        # Report project name
        self._logger.info("Kedro project %s", self._project_path.name)

        if self._run_called:
            raise KedroSessionError(
                "A run has already been completed as part of the"
                " active KedroSession. KedroSession has a 1-1 mapping with"
                " runs, and thus only one run should be executed per session."
            )

        session_id = self.store["session_id"]
        save_version = session_id
        extra_params = self.store.get("extra_params") or {}
        context = self.load_context()

        name = pipeline_name or "__default__"
        try:
            pipeline = pipelines[name]
        except KeyError as exc:
            raise ValueError(
                f"Failed to find the pipeline named '{name}'. "
                f"It needs to be generated and returned "
                f"by the 'register_pipelines' function."
            ) from exc

        filtered_pipeline = pipeline.filter(
            tags=tags,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            node_names=node_names,
            from_inputs=from_inputs,
            to_outputs=to_outputs,
            node_namespace=namespace,
        )

        record_data = {
            "session_id": session_id,
            "project_path": self._project_path.as_posix(),
            "env": context.env,
            "kedro_version": kedro_version,
            "tags": tags,
            "from_nodes": from_nodes,
            "to_nodes": to_nodes,
            "node_names": node_names,
            "from_inputs": from_inputs,
            "to_outputs": to_outputs,
            "load_versions": load_versions,
            "extra_params": extra_params,
            "pipeline_name": pipeline_name,
            "namespace": namespace,
            "runner": getattr(runner, "__name__", str(runner)),
        }

        catalog = context._get_catalog(
            save_version=save_version,
            load_versions=load_versions,
        )

        if from_catalog:
            # Update all pipeline inputs to read from
            # the from-catalog
            for item in filtered_pipeline.inputs():
                self._logger.info("Replacing %s", item)
                catalog.add(item, from_catalog._get_dataset(item), replace=True)

        # Run the runner
        hook_manager = self._hook_manager
        runner = runner or SequentialRunner()
        if not isinstance(runner, AbstractRunner):
            raise KedroSessionError(
                "KedroSession expect an instance of Runner instead of a class."
                "Have you forgotten the `()` at the end of the statement?"
            )

        hook_manager.hook.before_pipeline_run(
            run_params=record_data, pipeline=filtered_pipeline, catalog=catalog
        )

        try:
            run_result = runner.run(
                filtered_pipeline, catalog, hook_manager, session_id
            )
            self._run_called = True
        except Exception as error:
            hook_manager.hook.on_pipeline_error(
                error=error,
                run_params=record_data,
                pipeline=filtered_pipeline,
                catalog=catalog,
            )
            raise

        hook_manager.hook.after_pipeline_run(
            run_params=record_data,
            run_result=run_result,
            pipeline=filtered_pipeline,
            catalog=catalog,
        )
        return run_result
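The matching CLI wiring lives in the project's `cli.py`, which registers the `--from-env` option, builds a second catalog from that env, and passes it to the custom session as the from-catalog:

# cli.py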
"""Command line tools for manipulating a Kedro project.
Intended to be invoked via `kedro`.
"""
"""Command line tools for manipulating a Kedro project.
Intended to be invoked via `kedro`.
"""
from typing import List, Set, Dict, Any
import click
from kedro.framework.cli.project import (
    ASYNC_ARG_HELP,
    CONFIG_FILE_HELP,
    CONF_SOURCE_HELP,
    FROM_INPUTS_HELP,
    FROM_NODES_HELP,
    LOAD_VERSION_HELP,
    NODE_ARG_HELP,
    PARAMS_ARG_HELP,
    PIPELINE_ARG_HELP,
    RUNNER_ARG_HELP,
    TAG_ARG_HELP,
    TO_NODES_HELP,
    TO_OUTPUTS_HELP,
    project_group,
)
from kedro.framework.cli.utils import (
    CONTEXT_SETTINGS,
    _config_file_callback,
    _split_params,
    _split_load_versions,
    env_option,
    split_string,
    split_node_names,
)
from kedro.utils import load_obj
from kedro.pipeline.pipeline import Pipeline
from kedro.framework.project import pipelines, settings
from kedro.framework.context.context import _convert_paths_to_absolute_posix
from matrix.session import KedroSessionWithFromCatalog
from kedro.io import DataCatalog
@click.group(context_settings=CONTEXT_SETTINGS, name=__file__)
def cli():
"""Command line tools for manipulating a Kedro project."""


@project_group.command()
@click.option(
    "--from-inputs", type=str, default="", help=FROM_INPUTS_HELP, callback=split_string
)
@click.option(
    "--to-outputs", type=str, default="", help=TO_OUTPUTS_HELP, callback=split_string
)
@click.option(
    "--from-nodes",
    type=str,
    default="",
    help=FROM_NODES_HELP,
    callback=split_node_names,
)
@click.option(
    "--to-nodes", type=str, default="", help=TO_NODES_HELP, callback=split_node_names
)
@click.option(
    "--nodes",
    "-n",
    "node_names",
    type=str,
    multiple=False,
    help=NODE_ARG_HELP,
    callback=split_string,
    default="",
)
@click.option(
    "--runner", "-r", type=str, default=None, multiple=False, help=RUNNER_ARG_HELP
)
@click.option("--async", "is_async", is_flag=True, multiple=False, help=ASYNC_ARG_HELP)
@env_option
@click.option("--tags", "-t", type=str, multiple=True, help=TAG_ARG_HELP)
@click.option(
    "--without-tags",
    type=str,
    help=(
        "Used to filter out nodes with tags that should not be run. All dependent "
        "downstream nodes are also removed. Note that nodes need to have _all_ of "
        "the given tags to be removed."
    ),
    callback=split_string,
    default=[],
)
@click.option(
    "--load-versions",
    "-lv",
    type=str,
    multiple=True,
    help=LOAD_VERSION_HELP,
    callback=_split_load_versions,
)
@click.option("--pipeline", "-p", type=str, default=None, help=PIPELINE_ARG_HELP)
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
    help=CONFIG_FILE_HELP,
    callback=_config_file_callback,
)
@click.option(
    "--conf-source",
    type=click.Path(exists=True, file_okay=False, resolve_path=True),
    help=CONF_SOURCE_HELP,
)
@click.option(
    "--params",
    type=click.UNPROCESSED,
    default="",
    help=PARAMS_ARG_HELP,
    callback=_split_params,
)
@click.option(
    "--from-env",
    type=str,
    default=None,
    help="Custom env to read inputs from; if specified, the run reads input datasets "
    "from the `--from-env` environment and writes outputs to the `--env` environment.",
)
def run(
    tags,
    without_tags,
    env,
    runner,
    is_async,
    node_names,
    to_nodes,
    from_nodes,
    from_inputs,
    to_outputs,
    load_versions,
    pipeline,
    config,
    conf_source,
    params,
    from_env,
):
"""Run the pipeline."""
if pipeline in ["test", "fabricator"] and env in [None, "base"]:
raise RuntimeError(
"Running the fabricator in the base environment might overwrite production data! Use the test env `-e test` instead."
)
runner = load_obj(runner or "SequentialRunner", "kedro.runner")
tags = tuple(tags)
without_tags = without_tags
node_names = tuple(node_names)
with KedroSessionWithFromCatalog.create(
env=env, conf_source=conf_source, extra_params=params
) as session:
# introduced to filter out tags that should not be run
node_names = _filter_nodes_missing_tag(
tuple(without_tags), pipeline, session, node_names
)
from_catalog = None
from_params = {}
if from_env:
# Load second config loader instance
config_loader_class = settings.CONFIG_LOADER_CLASS
config_loader = config_loader_class( # type: ignore[no-any-return]
conf_source=session._conf_source,
env=from_env,
**settings.CONFIG_LOADER_ARGS,
)
conf_catalog = config_loader["catalog"]
conf_catalog = _convert_paths_to_absolute_posix(
project_path=session._project_path, conf_dictionary=conf_catalog
)
conf_creds = config_loader["credentials"]
from_catalog: DataCatalog = settings.DATA_CATALOG_CLASS.from_config(
catalog=conf_catalog, credentials=conf_creds
)
from_params = config_loader["parameters"]
from_catalog.add_feed_dict(_get_feed_dict(from_params), replace=True)
session.run(
from_catalog=from_catalog,
tags=tags,
runner=runner(is_async=is_async),
node_names=node_names,
from_nodes=from_nodes,
to_nodes=to_nodes,
from_inputs=from_inputs,
to_outputs=to_outputs,
load_versions=load_versions,
pipeline_name=pipeline,
)


def _get_feed_dict(params: Dict) -> dict[str, Any]:
    """Get parameters and return the feed dictionary."""
    feed_dict = {"parameters": params}

    def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None:
        """Add param to feed dict.

        This recursively adds parameter paths to the `feed_dict`,
        whenever `param_value` is a dictionary itself, so that users can
        specify specific nested parameters in their node inputs.

        Example:
            >>> param_name = "a"
            >>> param_value = {"b": 1}
            >>> _add_param_to_feed_dict(param_name, param_value)
            >>> assert feed_dict["params:a"] == {"b": 1}
            >>> assert feed_dict["params:a.b"] == 1
        """
        key = f"params:{param_name}"
        feed_dict[key] = param_value
        if isinstance(param_value, dict):
            for key, val in param_value.items():
                _add_param_to_feed_dict(f"{param_name}.{key}", val)

    for param_name, param_value in params.items():
        _add_param_to_feed_dict(param_name, param_value)
    return feed_dict
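With this in place, the debugging scenario from the example becomes (illustrative invocation; `--from-nodes node_b` restricts the run to the report-building step):

kedro run --env local --from-env cloud --from-nodes node_b

Here `dataset_2` is loaded via the `cloud` catalog entry (BigQuery), while `dataset_3` is written through the `local` catalog, leaving the `cloud` environment untouched.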