
Add `--from-env` CLI flag [PoC implementation available]

Open lvijnck opened this issue 5 months ago • 8 comments

Description

Kedro's env mechanism allows for extracting environment-specific configuration into different envs. We're using this to distinguish between local runs (plugging into local systems, and sometimes swapping out systems for local datasets to abstract away complexity) and cloud runs.

Example

Imagine the following pipeline:

  • node_a: (dataset_1) -> (dataset_2)
  • node_b: (dataset_2) -> (dataset_3)

Here node_a is a very expensive model-training step that takes 5+ hours and should run on specialized hardware in the cloud, while node_b is some report building.

The pipeline has two environments set up, local and cloud. The former runs on the local computer, whereas cloud is configured to load/store data in cloud-native storage systems. Given the duration of the pipeline, our local environment operates on synthetically generated data.

We use local SparkDatasets for the local run and BigQueryTable datasets for the cloud run, for both node_a and node_b.
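For concreteness, a minimal sketch of what the example pipeline could look like (function bodies, module name, and node wiring are illustrative, not taken from a real project):

# pipeline.py -- illustrative sketch of the node_a/node_b example above
from kedro.pipeline import Pipeline, node, pipeline


def train_model(dataset_1):
    """Expensive training step (5+ hours on specialized cloud hardware)."""
    ...


def build_report(dataset_2):
    """Cheap report-building step that we iterate on locally."""
    ...


def create_pipeline(**kwargs) -> Pipeline:
    return pipeline(
        [
            node(train_model, inputs="dataset_1", outputs="dataset_2", name="node_a"),
            node(build_report, inputs="dataset_2", outputs="dataset_3", name="node_b"),
        ]
    )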

Now imagine implementing a new feature in node_b. Our local env allows for rapid iteration and prototyping, with the localized SparkDatasets and synthetic data. Occasionally, however, we'd like to execute new code on the production data that we have in the cloud environment. Kedro currently falls short here, because the only options I have are:

  • Running in the cloud environment, which will overwrite the contents of dataset_3 in the cloud
  • Doing some ad-hoc commenting out to ensure dataset_3 is not materialized in the cloud

(note that the problem gets worse as the size of the pipeline increases)

Alternatively, we could introduce a --from-env flag, which overrides all input catalog entries of the pipeline with those of the given env. Specifically:

kedro run --env local --from-env cloud

This run would take all input datasets from BigQuery, as produced by the `cloud` run of `node_a`, and materialize the outputs locally, so the `cloud` env stays untouched.
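To make the intended semantics concrete, here is a self-contained sketch of the override (MemoryDataset stands in for the real SparkDataset/BigQueryTable entries of each env; it mirrors the approach used in the PoC below):

# Illustration only: re-point the pipeline's free inputs at the "from" catalog.
from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import node, pipeline

local_catalog = DataCatalog({"dataset_2": MemoryDataset(), "dataset_3": MemoryDataset()})
cloud_catalog = DataCatalog({"dataset_2": MemoryDataset(), "dataset_3": MemoryDataset()})

pipe = pipeline([node(lambda df: df, inputs="dataset_2", outputs="dataset_3", name="node_b")])

# Only the free inputs are re-pointed at the `cloud` catalog; every output keeps
# its definition from the target env, so nothing in `cloud` is overwritten.
for name in pipe.inputs():  # {"dataset_2"}
    local_catalog.add(name, cloud_catalog._get_dataset(name), replace=True)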

Proposal

For debugging purposes, it would be nice to have a --from-env option that allows overriding input datasets with those from another env, e.g.:

(this was also tested with dynamic pipelines)

# session.py
"""Custom kedro session."""
from __future__ import annotations

from typing import Any, Iterable

from kedro import __version__ as kedro_version
from kedro.framework.session import KedroSession
from kedro.framework.session.session import KedroSessionError
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.framework.project import pipelines
from kedro.io import DataCatalog


class KedroSessionWithFromCatalog(KedroSession):
    """Custom Kedro Session.

    Custom Kedro Session that allows an additional `from-catalog` to
    be specified. The from catalog overrides the catalog entry for all input
    datasets.

    NOTE: This module has some code duplication due to Kedro's complex
    config setup. We should clean this up based on:

    https://github.com/kedro-org/kedro/issues/4155
    """

    def run(  # noqa: PLR0913
        self,
        from_catalog: DataCatalog | None = None,
        pipeline_name: str | None = None,
        tags: Iterable[str] | None = None,
        runner: AbstractRunner | None = None,
        node_names: Iterable[str] | None = None,
        from_nodes: Iterable[str] | None = None,
        to_nodes: Iterable[str] | None = None,
        from_inputs: Iterable[str] | None = None,
        to_outputs: Iterable[str] | None = None,
        load_versions: dict[str, str] | None = None,
        namespace: str | None = None,
    ) -> dict[str, Any]:
        """Runs the pipeline with a specified runner.

        Args:
            from_catalog: Catalog to read input datasets from; if set, it
                overrides the catalog entries for all pipeline inputs.
            pipeline_name: Name of the pipeline that is being run.
            tags: An optional list of node tags which should be used to
                filter the nodes of the ``Pipeline``. If specified, only the nodes
                containing *any* of these tags will be run.
            runner: An optional parameter specifying the runner that you want to run
                the pipeline with.
            node_names: An optional list of node names which should be used to
                filter the nodes of the ``Pipeline``. If specified, only the nodes
                with these names will be run.
            from_nodes: An optional list of node names which should be used as a
                starting point of the new ``Pipeline``.
            to_nodes: An optional list of node names which should be used as an
                end point of the new ``Pipeline``.
            from_inputs: An optional list of input datasets which should be
                used as a starting point of the new ``Pipeline``.
            to_outputs: An optional list of output datasets which should be
                used as an end point of the new ``Pipeline``.
            load_versions: An optional flag to specify a particular dataset
                version timestamp to load.
            namespace: The namespace of the nodes that is being run.

        Raises:
            ValueError: If the named or `__default__` pipeline is not
                defined by `register_pipelines`.
            Exception: Any uncaught exception during the run will be re-raised
                after being passed to ``on_pipeline_error`` hook.
            KedroSessionError: If more than one run is attempted to be executed during
                a single session.

        Returns:
            Any node outputs that cannot be processed by the ``DataCatalog``.
            These are returned in a dictionary, where the keys are defined
            by the node outputs.
        """
        # Report project name
        self._logger.info("Kedro project %s", self._project_path.name)

        if self._run_called:
            raise KedroSessionError(
                "A run has already been completed as part of the"
                " active KedroSession. KedroSession has a 1-1 mapping with"
                " runs, and thus only one run should be executed per session."
            )

        session_id = self.store["session_id"]
        save_version = session_id
        extra_params = self.store.get("extra_params") or {}
        context = self.load_context()

        name = pipeline_name or "__default__"

        try:
            pipeline = pipelines[name]
        except KeyError as exc:
            raise ValueError(
                f"Failed to find the pipeline named '{name}'. "
                f"It needs to be generated and returned "
                f"by the 'register_pipelines' function."
            ) from exc

        filtered_pipeline = pipeline.filter(
            tags=tags,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            node_names=node_names,
            from_inputs=from_inputs,
            to_outputs=to_outputs,
            node_namespace=namespace,
        )

        record_data = {
            "session_id": session_id,
            "project_path": self._project_path.as_posix(),
            "env": context.env,
            "kedro_version": kedro_version,
            "tags": tags,
            "from_nodes": from_nodes,
            "to_nodes": to_nodes,
            "node_names": node_names,
            "from_inputs": from_inputs,
            "to_outputs": to_outputs,
            "load_versions": load_versions,
            "extra_params": extra_params,
            "pipeline_name": pipeline_name,
            "namespace": namespace,
            "runner": getattr(runner, "__name__", str(runner)),
        }

        catalog = context._get_catalog(
            save_version=save_version,
            load_versions=load_versions,
        )

        if from_catalog:
            # Update all pipeline inputs to read from
            # the from catalog
            for item in filtered_pipeline.inputs():
                self._logger.info("Replacing %s", item)
                catalog.add(item, from_catalog._get_dataset(item), replace=True)

        # Run the runner
        hook_manager = self._hook_manager
        runner = runner or SequentialRunner()
        if not isinstance(runner, AbstractRunner):
            raise KedroSessionError(
                "KedroSession expect an instance of Runner instead of a class."
                "Have you forgotten the `()` at the end of the statement?"
            )
        hook_manager.hook.before_pipeline_run(
            run_params=record_data, pipeline=filtered_pipeline, catalog=catalog
        )

        try:
            run_result = runner.run(
                filtered_pipeline, catalog, hook_manager, session_id
            )
            self._run_called = True
        except Exception as error:
            hook_manager.hook.on_pipeline_error(
                error=error,
                run_params=record_data,
                pipeline=filtered_pipeline,
                catalog=catalog,
            )
            raise

        hook_manager.hook.after_pipeline_run(
            run_params=record_data,
            run_result=run_result,
            pipeline=filtered_pipeline,
            catalog=catalog,
        )
        return run_result
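
# cli.py (assumed filename; the project's custom CLI module exposing the --from-env option)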
"""Command line tools for manipulating a Kedro project.

Intended to be invoked via `kedro`.
"""
"""Command line tools for manipulating a Kedro project.

Intended to be invoked via `kedro`.
"""
from typing import Any, Dict
import click
from kedro.framework.cli.project import (
    ASYNC_ARG_HELP,
    CONFIG_FILE_HELP,
    CONF_SOURCE_HELP,
    FROM_INPUTS_HELP,
    FROM_NODES_HELP,
    LOAD_VERSION_HELP,
    NODE_ARG_HELP,
    PARAMS_ARG_HELP,
    PIPELINE_ARG_HELP,
    RUNNER_ARG_HELP,
    TAG_ARG_HELP,
    TO_NODES_HELP,
    TO_OUTPUTS_HELP,
    project_group,
)
from kedro.framework.cli.utils import (
    CONTEXT_SETTINGS,
    _config_file_callback,
    _split_params,
    _split_load_versions,
    env_option,
    split_string,
    split_node_names,
)
from kedro.utils import load_obj
from kedro.pipeline.pipeline import Pipeline
from kedro.framework.project import pipelines, settings
from kedro.framework.context.context import _convert_paths_to_absolute_posix

from matrix.session import KedroSessionWithFromCatalog
from kedro.io import DataCatalog


@click.group(context_settings=CONTEXT_SETTINGS, name=__file__)
def cli():
    """Command line tools for manipulating a Kedro project."""


@project_group.command()
@click.option(
    "--from-inputs", type=str, default="", help=FROM_INPUTS_HELP, callback=split_string
)
@click.option(
    "--to-outputs", type=str, default="", help=TO_OUTPUTS_HELP, callback=split_string
)
@click.option(
    "--from-nodes",
    type=str,
    default="",
    help=FROM_NODES_HELP,
    callback=split_node_names,
)
@click.option(
    "--to-nodes", type=str, default="", help=TO_NODES_HELP, callback=split_node_names
)
@click.option(
    "--nodes",
    "-n",
    "node_names",
    type=str,
    multiple=False,
    help=NODE_ARG_HELP,
    callback=split_string,
    default="",
)
@click.option(
    "--runner", "-r", type=str, default=None, multiple=False, help=RUNNER_ARG_HELP
)
@click.option("--async", "is_async", is_flag=True, multiple=False, help=ASYNC_ARG_HELP)
@env_option
@click.option("--tags", "-t", type=str, multiple=True, help=TAG_ARG_HELP)
@click.option(
    "--without-tags",
    type=str,
    help="used to filter out nodes with tags that should not be run. All dependent downstream nodes are also removed. Note nodes need to have _all_ tags to be removed.",
    callback=split_string,
    default="",
)
@click.option(
    "--load-versions",
    "-lv",
    type=str,
    multiple=True,
    help=LOAD_VERSION_HELP,
    callback=_split_load_versions,
)
@click.option("--pipeline", "-p", type=str, default=None, help=PIPELINE_ARG_HELP)
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
    help=CONFIG_FILE_HELP,
    callback=_config_file_callback,
)
@click.option(
    "--conf-source",
    type=click.Path(exists=True, file_okay=False, resolve_path=True),
    help=CONF_SOURCE_HELP,
)
@click.option(
    "--params",
    type=click.UNPROCESSED,
    default="",
    help=PARAMS_ARG_HELP,
    callback=_split_params,
)
@click.option(
    "--from-env",
    type=str,
    default=None,
    help="Custom env to read from, if specified will read from the `--from-env` and write to the `--env`",
)
def run(
    tags,
    without_tags,
    env,
    runner,
    is_async,
    node_names,
    to_nodes,
    from_nodes,
    from_inputs,
    to_outputs,
    load_versions,
    pipeline,
    config,
    conf_source,
    params,
    from_env,
):
    """Run the pipeline."""
    if pipeline in ["test", "fabricator"] and env in [None, "base"]:
        raise RuntimeError(
            "Running the fabricator in the base environment might overwrite production data! Use the test env `-e test` instead."
        )

    runner = load_obj(runner or "SequentialRunner", "kedro.runner")
    tags = tuple(tags)
    node_names = tuple(node_names)

    with KedroSessionWithFromCatalog.create(
        env=env, conf_source=conf_source, extra_params=params
    ) as session:
        # Project-specific helper (not shown here) that removes nodes carrying
        # the `without_tags`, together with their downstream nodes
        node_names = _filter_nodes_missing_tag(
            tuple(without_tags), pipeline, session, node_names
        )

        from_catalog = None
        from_params = {}
        if from_env:
            # Load second config loader instance
            config_loader_class = settings.CONFIG_LOADER_CLASS
            config_loader = config_loader_class(  # type: ignore[no-any-return]
                conf_source=session._conf_source,
                env=from_env,
                **settings.CONFIG_LOADER_ARGS,
            )
            conf_catalog = config_loader["catalog"]
            conf_catalog = _convert_paths_to_absolute_posix(
                project_path=session._project_path, conf_dictionary=conf_catalog
            )
            conf_creds = config_loader["credentials"]
            from_catalog: DataCatalog = settings.DATA_CATALOG_CLASS.from_config(
                catalog=conf_catalog, credentials=conf_creds
            )
            from_params = config_loader["parameters"]
            from_catalog.add_feed_dict(_get_feed_dict(from_params), replace=True)

        session.run(
            from_catalog=from_catalog,
            tags=tags,
            runner=runner(is_async=is_async),
            node_names=node_names,
            from_nodes=from_nodes,
            to_nodes=to_nodes,
            from_inputs=from_inputs,
            to_outputs=to_outputs,
            load_versions=load_versions,
            pipeline_name=pipeline,
        )


def _get_feed_dict(params: Dict[str, Any]) -> Dict[str, Any]:
    """Get parameters and return the feed dictionary."""
    feed_dict = {"parameters": params}

    def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None:
        """Add param to feed dict.

        This recursively adds parameter paths to the `feed_dict`,
        whenever `param_value` is a dictionary itself, so that users can
        specify specific nested parameters in their node inputs.

        Example:
            >>> param_name = "a"
            >>> param_value = {"b": 1}
            >>> _add_param_to_feed_dict(param_name, param_value)
            >>> assert feed_dict["params:a"] == {"b": 1}
            >>> assert feed_dict["params:a.b"] == 1
        """
        key = f"params:{param_name}"
        feed_dict[key] = param_value
        if isinstance(param_value, dict):
            for key, val in param_value.items():
                _add_param_to_feed_dict(f"{param_name}.{key}", val)

    for param_name, param_value in params.items():
        _add_param_to_feed_dict(param_name, param_value)

    return feed_dict
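
With the custom session and run command above in place, iterating on node_b against production inputs would look something like this (node and dataset names follow the example at the top of the issue):

kedro run --env local --from-env cloud --nodes node_b

dataset_2 is then loaded via the cloud catalog (BigQuery), while dataset_3 is still written through the local catalog.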

lvijnck · Sep 10 '24 19:09