
SparkJDBCDataset not working when specifying a query instead of a table

Open DavidRetana-TomTom opened this issue 1 year ago • 7 comments

Description

When using SparkJDBCDataset you need to specify the table name as a mandatory parameter. However, when using the Spark JDBC connector directly, you can specify a query to retrieve data from the database instead of hardcoding a single table. According to the official Spark documentation for the JDBC data source:

The specified query will be parenthesized and used as a subquery in the FROM clause. Below are a couple of restrictions while using this option.

  1. It is not allowed to specify dbtable and query options at the same time.
  2. It is not allowed to specify query and partitionColumn options at the same time. When specifying partitionColumn option is required, the subquery can be specified using dbtable option instead and partition columns can be qualified using the subquery alias provided as part of dbtable. Example:
spark.read.format("jdbc")
.option("url", jdbcUrl)
.option("query", "select c1, c2 from t1")
.load()
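
For the partitioned case described in restriction 2, the dbtable subquery-alias alternative would look roughly like the following. This is only a sketch: the table, columns, and bounds are illustrative, and spark and jdbcUrl are assumed to be defined as in the snippet above. Since the existing dataset passes its table argument through to DataFrameReader.jdbc, which forwards it as dbtable, this form may also serve as a stop-gap with the current table parameter.

# Sketch: subquery passed via dbtable with an alias, so partitionColumn can be used
df = (
    spark.read.format("jdbc")
    .option("url", jdbcUrl)
    .option("dbtable", "(SELECT c1, c2 FROM t1) AS subq")
    .option("partitionColumn", "c1")   # must be numeric, date or timestamp
    .option("lowerBound", "0")
    .option("upperBound", "1000")
    .option("numPartitions", "4")
    .load()
)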

Context

This is especially important if you want to read data from multiple tables in the database, or if you want to run complex or spatial queries in the database instead of retrieving all the data and performing the computations in the cluster.

Steps to Reproduce

Source code right now (https://github.com/kedro-org/kedro-plugins/blob/main/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py):

if not table:
    raise DatasetError(
        "'table' argument cannot be empty. Please "
        "provide the name of the table to load or save "
        "data to."
    )

Expected Result

I would like to have something like the following:

weather:
  type: spark.SparkJDBCDataset
  query: SELECT field1, field2 FROM weather_table WHERE <condition>
  url: jdbc:postgresql://localhost/test
  credentials: db_credentials
  load_args:
    properties:
      driver: org.postgresql.Driver
  save_args:
    properties:
      driver: org.postgresql.Driver
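
For reference, such a catalog entry would need to translate into roughly the following raw PySpark call, which is what currently has to be written by hand inside a node. This is a sketch: db_credentials is assumed here to be a plain dict holding user and password, and the query placeholder is the one from the entry above.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
weather = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://localhost/test")
    .option("driver", "org.postgresql.Driver")
    .option("query", "SELECT field1, field2 FROM weather_table WHERE <condition>")
    .option("user", db_credentials["user"])        # assumed dict with user/password
    .option("password", db_credentials["password"])
    .load()
)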

Your Environment

Include as many relevant details about the environment in which you experienced the bug:

  • Kedro version used (pip show kedro or kedro -V): 0.19.3
  • Kedro plugin and kedro plugin version used (pip show kedro-airflow):
  • Python version used (python -V): 3.10
  • Operating system and version: aarch64 GNU/Linux

DavidRetana-TomTom avatar Apr 09 '24 12:04 DavidRetana-TomTom

@DavidRetana-TomTom Did you mean that you expect there is a query argument? I am not sure what's the feature request here.

noklam avatar Apr 09 '24 13:04 noklam

@DavidRetana-TomTom Did you mean that you expect there is a query argument? I am not sure what's the feature request here.

Yes exactly

DavidRetana-TomTom avatar Apr 09 '24 13:04 DavidRetana-TomTom

@DavidRetana-TomTom this is a great push - this dataset is quite old, so this may be newer functionality. I think it's a good idea to add this to our implementation.

There are two steps at this point:

  1. The quickest way to unblock yourself is to copy the dataset implementation (linked above) into your project, change the reference in YAML to a local class path, and then update the logic to accept a query param like you need (see the sketch after this list).
  2. We'd really appreciate a contribution back to kedro-datasets; would you be interested in doing this? We'd be here to coach you through the process.
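
For step 1, the catalog reference to a local copy could look roughly like this. The module path is hypothetical and depends on where you place the copied class in your project.

weather:
  type: my_project.datasets.spark_jdbc_dataset.SparkJDBCDataset
  url: jdbc:postgresql://localhost/test
  query: SELECT field1, field2 FROM weather_table WHERE <condition>
  credentials: db_credentials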

How do you feel about raising a PR to make this work? We can coach you through the process.

datajoely avatar Apr 09 '24 13:04 datajoely

I took what is described here, and hopefully this can be a starting point or a workaround; I only implemented the load method. The change is in noklam/sparkjdbcdataset-not-working-639

The diff: https://github.com/kedro-org/kedro-plugins/compare/noklam/sparkjdbcdataset-not-working-639?expand=1


"""SparkJDBCDataset to load and save a PySpark DataFrame via JDBC."""

from copy import deepcopy
from typing import Any

from kedro.io.core import AbstractDataset, DatasetError
from pyspark.sql import DataFrame

from kedro_datasets.spark.spark_dataset import _get_spark

class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]):
    """``SparkJDBCDataset`` loads data from a database table accessible
    via JDBC URL ``url`` and connection properties and saves the content
    of a PySpark DataFrame to an external database table via JDBC. It
    uses ``pyspark.sql.DataFrameReader`` and ``pyspark.sql.DataFrameWriter``
    internally, so it supports all allowed PySpark options on ``jdbc``.

    Example usage for the
    `YAML API <https://kedro.readthedocs.io/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        weather:
          type: spark.SparkJDBCDataset
          table: weather_table
          url: jdbc:postgresql://localhost/test
          credentials: db_credentials
          load_args:
            properties:
              driver: org.postgresql.Driver
          save_args:
            properties:
              driver: org.postgresql.Driver

    Example usage for the
    `Python API <https://kedro.readthedocs.io/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> import pandas as pd
        >>> from kedro_datasets.spark import SparkJDBCDataset
        >>> from pyspark.sql import SparkSession
        >>>
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = spark.createDataFrame(
        ...     pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
        ... )
        >>> url = "jdbc:postgresql://localhost/test"
        >>> table = "table_a"
        >>> connection_properties = {"driver": "org.postgresql.Driver"}
        >>> dataset = SparkJDBCDataset(
        ...     url=url,
        ...     table=table,
        ...     credentials={"user": "scott", "password": "tiger"},
        ...     load_args={"properties": connection_properties},
        ...     save_args={"properties": connection_properties},
        ... )
        >>>
        >>> dataset.save(data)
        >>> reloaded = dataset.load()
        >>>
        >>> assert data.toPandas().equals(reloaded.toPandas())

    """

    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
        self,
        *,
        url: str,
        table: str = None,
        credentials: dict[str, Any] = None,
        load_args: dict[str, Any] = None,
        save_args: dict[str, Any] = None,
        metadata: dict[str, Any] = None,
        query: str = None,
    ) -> None:
    """Creates a new ``SparkJDBCDataset``.

    Args:
        url: A JDBC URL of the form ``jdbc:subprotocol:subname``.
        table: The name of the table to load or save data to.
        credentials: A dictionary of JDBC database connection arguments.
            Normally at least properties ``user`` and ``password`` with
            their corresponding values.  It updates ``properties``
            parameter in ``load_args`` and ``save_args`` in case it is
            provided.
        load_args: Provided to underlying PySpark ``jdbc`` function along
            with the JDBC URL and the name of the table. To find all
            supported arguments, see here:
            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.jdbc.html
        save_args: Provided to underlying PySpark ``jdbc`` function along
            with the JDBC URL and the name of the table. To find all
            supported arguments, see here:
            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.jdbc.html
        metadata: Any arbitrary metadata.
            This is ignored by Kedro, but may be consumed by users or external plugins.

    Raises:
        DatasetError: When either ``url`` or ``table`` is empty or
            when a property is provided with a None value.
    """

        if not url:
            raise DatasetError(
                "'url' argument cannot be empty. Please "
                "provide a JDBC URL of the form "
                "'jdbc:subprotocol:subname'."
            )

        if not table and not query:
            raise DatasetError(
                "'table' and 'query' arguments cannot both be empty. Please "
                "provide either the name of the table or a query to load "
                "data from."
            )

        if table and query:
            raise DatasetError(
                "Only one of the 'table' and 'query' arguments should be used. "
                "Please provide either the name of the table or a query."
            )

        self._url = url
        self._table = table
        self._query = query

        self.metadata = metadata

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # Update properties in load_args and save_args with credentials.
        if credentials is not None:
            # Check credentials for bad inputs.
            for cred_key, cred_value in credentials.items():
                if cred_value is None:
                    raise DatasetError(
                        f"Credential property '{cred_key}' cannot be None. "
                        f"Please provide a value."
                    )

            load_properties = self._load_args.get("properties", {})
            save_properties = self._save_args.get("properties", {})
            self._load_args["properties"] = {**load_properties, **credentials}
            self._save_args["properties"] = {**save_properties, **credentials}

    def _describe(self) -> dict[str, Any]:
        load_args = self._load_args
        save_args = self._save_args

        # Remove user and password values from load and save properties.
        if "properties" in load_args:
            load_properties = load_args["properties"].copy()
            load_properties.pop("user", None)
            load_properties.pop("password", None)
            load_args = {**load_args, "properties": load_properties}
        if "properties" in save_args:
            save_properties = save_args["properties"].copy()
            save_properties.pop("user", None)
            save_properties.pop("password", None)
            save_args = {**save_args, "properties": save_properties}

        return {
            "url": self._url,
            "table": self._table,
            "load_args": load_args,
            "save_args": save_args,
        }

    def _load(self) -> DataFrame:
        if self._table:
            return _get_spark().read.jdbc(self._url, self._table, **self._load_args)
        # __init__ guarantees exactly one of table/query is set, so fall back
        # to query-based loading through the generic "jdbc" reader format.
        # Note: in this workaround branch, load_args and credential properties
        # are not yet forwarded on the query path.
        return (
            _get_spark()
            .read.format("jdbc")
            .option("url", self._url)
            .option("query", self._query)
            .load()
        )

    def _save(self, data: DataFrame) -> None:
        # Saving is unchanged and still requires a table name; the query
        # option is load-only in this branch.
        return data.write.jdbc(self._url, self._table, **self._save_args)
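
A minimal usage sketch of the patched class for a query-based load (hypothetical values). Note that in this branch the query path does not yet forward load_args or credential properties, so the JDBC driver and any authentication would have to be supplied elsewhere, for example via the Spark session configuration or the JDBC URL itself.

dataset = SparkJDBCDataset(
    url="jdbc:postgresql://localhost/test",
    query="SELECT field1, field2 FROM weather_table",
)
df = dataset.load()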

noklam avatar Apr 09 '24 13:04 noklam


That should be enough for my use case. I can't open a pull request because I am not a collaborator of this project.

DavidRetana-TomTom avatar Apr 29 '24 12:04 DavidRetana-TomTom

@DavidRetana-TomTom you can open one via the Forking workflow! We'd really appreciate it if you have a chance.

datajoely avatar Apr 29 '24 13:04 datajoely