
Support conversion from DataFrameModel to PySpark StructType

Garett601 opened this issue on Nov 28, 2023

Is your feature request related to a problem? Please describe.

When creating pandera.pyspark schemas, if you want to create a PySpark DataFrame you have to explicitly define a StructType that matches the schema definition. This duplicates work and is very time-consuming, especially when the final schema inherits from several other schemas. This enhancement was suggested in Support conversion from Pandera DataFrameSchema or DataFrameModel to PySpark StructType #1327.
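To illustrate the duplication in question, here is a minimal sketch (the UserSchema model and its columns are hypothetical): the same columns end up defined twice, once as a pandera model and once as a hand-written StructType.

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class UserSchema(DataFrameModel):
    """Hypothetical pandera.pyspark model with two columns."""

    name: T.StringType = pa.Field(coerce=True)
    age: T.IntegerType = pa.Field(coerce=True)

# Without a conversion helper, an equivalent StructType must be
# maintained by hand, mirroring the model field by field:
user_struct = StructType(
    [
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ]
)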

Describe the solution you'd like

I have written a function that could be included in the pandera package to make this process quick and easy. The code below shows the proposed solution and illustrates the intended behaviour:

import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import (
    DateType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)
from typing import Type, Set, List


def get_spark_schema(cls: Type[DataFrameModel]) -> StructType:
    """
    Generates a Spark StructType schema from a DataFrameModel class and its base classes, using their annotations.

    This function iterates over the Method Resolution Order (MRO) of the provided class. It examines each base class
    that is a subclass of DataFrameModel (excluding DataFrameModel itself). For each, it extracts annotated fields,
    maps them to Spark DataTypes, and creates a StructField for each. The function ensures uniqueness of each field in
    the resulting schema.

    Parameters
    ----------
    cls : Type[DataFrameModel]
        The class inheriting from DataFrameModel for which the Spark schema is to be generated.

    Returns
    -------
    StructType
        A Spark StructType schema with StructFields for each annotated field in the class hierarchy.

    Notes
    -----
    The function uses guard clauses for reduced nesting and improved readability. All fields in the resulting schema
    are marked as nullable.

    Examples
    --------
    >>> class MyDataFrameModel(DataFrameModel):
    ...     name: T.StringType
    ...     age: T.IntegerType
    >>> schema = get_spark_schema(MyDataFrameModel)
    >>> print(schema)
    StructType([StructField('name', StringType(), True),
                StructField('age', IntegerType(), True)])
    """
    fields: List[StructField] = []
    added_fields: Set[str] = set()
    nullable: bool = True

    # Map supported annotation types to Spark DataType instances.
    type_mapping: dict = {
        T.StringType: StringType(),
        T.IntegerType: IntegerType(),
        T.FloatType: FloatType(),
        T.DateType: DateType(),
    }

    # Walk the MRO so fields inherited from parent schemas are collected
    # too; skip classes that are not DataFrameModel subclasses, and
    # DataFrameModel itself.
    for base in cls.__mro__:
        if not issubclass(base, DataFrameModel) or base == DataFrameModel:
            continue

        for attr_name, attr_type in base.__annotations__.items():
            spark_type = type_mapping.get(attr_type)
            # Skip fields already added by an earlier class in the MRO,
            # and annotations with no known Spark mapping.
            if attr_name in added_fields or spark_type is None:
                continue

            fields.append(StructField(attr_name, spark_type, nullable))
            added_fields.add(attr_name)

    return StructType(fields)

Describe alternatives you've considered

The alternative is defining a get_spark_schema classmethod within each schema definition, as in the sketch below, but this still duplicates work and is time-consuming.
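For reference, the classmethod alternative looks roughly like this (a sketch with hypothetical column names, re-using the imports from the snippet above; note that every schema would have to repeat it):

class ExampleSchema(DataFrameModel):
    """Schema defining columns."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)

    @classmethod
    def get_spark_schema(cls) -> StructType:
        # A hand-maintained duplicate of the annotations above --
        # exactly the repetition the proposed function avoids.
        return StructType(
            [
                StructField("Column_1", StringType(), True),
                StructField("Column_2", IntegerType(), True),
            ]
        )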

Additional context

Example Usage and Outputs

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import (
    DateType,
    FloatType,
    IntegerType,
    StringType,
)

class ExampleSchema_1(DataFrameModel):
    """Schema defining columns."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)


class ExampleSchema_2(DataFrameModel):
    """Schema defining columns."""

    Column_3: T.FloatType = pa.Field(coerce=True)
    Column_4: T.DateType = pa.Field(coerce=True)

class ExampleSchema_3(DataFrameModel):
    """Schema defining columns."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)
    Column_3: T.FloatType = pa.Field(coerce=True)
    Column_4: T.DateType = pa.Field(coerce=True)
    Column_5: T.StringType = pa.Field(coerce=True)
    Column_6: T.IntegerType = pa.Field(coerce=True)
    Column_7: T.FloatType = pa.Field(coerce=True)
    Column_8: T.DateType = pa.Field(coerce=True)


class FinalExampleSchema(ExampleSchema_1, ExampleSchema_2, ExampleSchema_3):
    """Final schema that inherits from all other schemas."""

The following demonstrates get_spark_schema with each of the example schemas, together with the corresponding outputs:

# Example Schema 1 Usage
get_spark_schema(ExampleSchema_1)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True)])

# Example Schema 2 Usage
get_spark_schema(ExampleSchema_2)
# Output: StructType([StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True)])

# Example Schema 3 Usage
get_spark_schema(ExampleSchema_3)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True), StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True), StructField('Column_5', StringType(), True), StructField('Column_6', IntegerType(), True), StructField('Column_7', FloatType(), True), StructField('Column_8', DateType(), True)])

# Final Example Schema (Inheritance) Usage
get_spark_schema(FinalExampleSchema)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True), StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True), StructField('Column_5', StringType(), True), StructField('Column_6', IntegerType(), True), StructField('Column_7', FloatType(), True), StructField('Column_8', DateType(), True)])

The following code creates PySpark DataFrames from the generated schemas and prints their structure.

from pyspark.sql import SparkSession

# An active SparkSession is assumed; create one if needed
spark = SparkSession.builder.getOrCreate()

# Create an empty RDD as a placeholder for data
empty_rdd = spark.sparkContext.emptyRDD()

# Generate Spark schemas from the pandera.pyspark models
schema_1 = get_spark_schema(ExampleSchema_1)
schema_2 = get_spark_schema(ExampleSchema_2)
schema_3 = get_spark_schema(ExampleSchema_3)
schema_final = get_spark_schema(FinalExampleSchema)

# Create DataFrames with the generated schemas
example_df_1 = spark.createDataFrame(empty_rdd, schema_1)
example_df_2 = spark.createDataFrame(empty_rdd, schema_2)
example_df_3 = spark.createDataFrame(empty_rdd, schema_3)
example_df_final = spark.createDataFrame(empty_rdd, schema_final)

# Show the DataFrame structure
example_df_1.printSchema()
example_df_2.printSchema()
example_df_3.printSchema()
example_df_final.printSchema()

# Expected output for example_df_1:
# root
#  |-- Column_1: string (nullable = true)
#  |-- Column_2: integer (nullable = true)

# Expected output for example_df_2:
# root
#  |-- Column_3: float (nullable = true)
#  |-- Column_4: date (nullable = true)

# Expected output for example_df_3:
# root
#  |-- Column_1: string (nullable = true)
#  |-- Column_2: integer (nullable = true)
#  |-- Column_3: float (nullable = true)
#  |-- Column_4: date (nullable = true)
#  |-- Column_5: string (nullable = true)
#  |-- Column_6: integer (nullable = true)
#  |-- Column_7: float (nullable = true)
#  |-- Column_8: date (nullable = true)

# Expected output for example_df_final:
# root
#  |-- Column_1: string (nullable = true)
#  |-- Column_2: integer (nullable = true)
#  |-- Column_3: float (nullable = true)
#  |-- Column_4: date (nullable = true)
#  |-- Column_5: string (nullable = true)
#  |-- Column_6: integer (nullable = true)
#  |-- Column_7: float (nullable = true)
#  |-- Column_8: date (nullable = true)

Garett601 · Nov 28 '23

Watching... will it work on nested StructFields?

vladgrish · Dec 24 '23

Watching... It would be great to be able to export both the complete StructType object and the condensed DDL format, for example:

col0 INT, col1 DOUBLE, col3 STRING
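A minimal sketch of producing that condensed form from a StructType, using only DataType.simpleString() from the public PySpark API (to_ddl is a hypothetical helper name, and nested struct types would need extra handling):

from pyspark.sql.types import StructType

def to_ddl(schema: StructType) -> str:
    # Join each field as "name TYPE", e.g. "col0 INT, col1 DOUBLE, col3 STRING".
    # simpleString() returns lowercase type names such as 'int' or 'double'.
    return ", ".join(
        f"{field.name} {field.dataType.simpleString().upper()}"
        for field in schema.fields
    )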

filipeo2-mck · Apr 08 '24

@Garett601 and @vladgrish , I just opened the #1570 PR, could you take a look at it, please?

filipeo2-mck · Apr 12 '24

#1570 was merged, I believe this can be closed @cosmicBboy

filipeo2-mck · Jun 21 '24

fixed by #1570

cosmicBboy · Jun 21 '24