Support conversion from DataFrameModel to PySpark StructType
Is your feature request related to a problem? Please describe.
When creating pandera.pyspark schemas, if you want to create a PySpark DataFrame you have to explicitly define a StructType that matches the schema definition. This duplicates work and is time consuming, especially when you have a final schema that inherits from other schemas. This enhancement was suggested in Support conversion from Pandera DataFrameSchema or DataFrameModel to PySpark StructType #1327.
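To make the duplication concrete, here is a minimal sketch (UserSchema is a hypothetical model, not from the codebase): the same columns are declared once on the model and again by hand in the StructType passed to createDataFrame.

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class UserSchema(DataFrameModel):
    """Hypothetical pandera.pyspark model."""

    name: T.StringType = pa.Field(coerce=True)
    age: T.IntegerType = pa.Field(coerce=True)

# Today the matching StructType has to be written out by hand, repeating
# every column name and type already declared on the model above.
user_struct = StructType(
    [
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ]
)

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([], schema=user_struct)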
Describe the solution you'd like
I have written a function that could possibly be included in the pandera package to make this process quick and easy. See the code below for the proposed solution, which also clearly shows the intended behaviour:
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import (
DateType,
FloatType,
IntegerType,
StringType,
StructField,
StructType,
)
from typing import Type, Set, List
def get_spark_schema(cls: Type[DataFrameModel]) -> StructType:
    """
    Generate a Spark StructType schema from a DataFrameModel class and its base classes, using their annotations.

    This function iterates over the Method Resolution Order (MRO) of the provided class. It examines each base class
    that is a subclass of DataFrameModel (excluding DataFrameModel itself). For each, it extracts annotated fields,
    maps them to Spark DataTypes, and creates a StructField for each. The function ensures each field appears only
    once in the resulting schema.

    Parameters
    ----------
    cls : Type[DataFrameModel]
        The class inheriting from DataFrameModel for which the Spark schema is to be generated.

    Returns
    -------
    StructType
        A Spark StructType schema with a StructField for each annotated field in the class hierarchy.

    Notes
    -----
    The function uses guard clauses for reduced nesting and improved readability. All fields in the resulting schema
    are marked as nullable.

    Examples
    --------
    >>> class MyDataFrameModel(DataFrameModel):
    ...     name: T.StringType
    ...     age: T.IntegerType
    >>> schema = get_spark_schema(MyDataFrameModel)
    >>> print(schema)
    StructType([StructField('name', StringType(), True),
                StructField('age', IntegerType(), True)])
    """
    fields: List[StructField] = []
    added_fields: Set[str] = set()
    nullable: bool = True
    # Map the annotated PySpark type classes to concrete Spark DataType instances.
    type_mapping: dict = {
        T.StringType: StringType(),
        T.IntegerType: IntegerType(),
        T.FloatType: FloatType(),
        T.DateType: DateType(),
    }

    for base in cls.__mro__:
        # Skip anything that is not a user-defined DataFrameModel subclass.
        if not issubclass(base, DataFrameModel) or base is DataFrameModel:
            continue
        for attr_name, attr_type in base.__annotations__.items():
            spark_type = type_mapping.get(attr_type, None)
            # Skip duplicate field names and annotations without a known Spark type.
            if attr_name in added_fields or spark_type is None:
                continue
            fields.append(StructField(attr_name, spark_type, nullable))
            added_fields.add(attr_name)

    return StructType(fields)
Describe alternatives you've considered
The alternative is defining a get_spark_schema classmethod within each schema definition, but this still requires duplication of work and is time consuming, as the sketch below illustrates.
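For comparison, a rough sketch of that alternative (ExampleSchema is a hypothetical model): every schema carries its own hand-written conversion, which has to be kept in sync with the annotations.

import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class ExampleSchema(DataFrameModel):
    """Hypothetical schema carrying its own conversion classmethod."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)

    @classmethod
    def get_spark_schema(cls) -> StructType:
        # Hand-written StructType that must be kept in sync with the
        # annotations above, which is exactly the duplication the proposal avoids.
        return StructType(
            [
                StructField("Column_1", StringType(), True),
                StructField("Column_2", IntegerType(), True),
            ]
        )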
Additional context
Example Usage and Outputs
import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import DataFrameModel
from pyspark.sql.types import (
DateType,
FloatType,
IntegerType,
StringType,
)
class ExampleSchema_1(DataFrameModel):
    """Schema defining columns."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)


class ExampleSchema_2(DataFrameModel):
    """Schema defining columns."""

    Column_3: T.FloatType = pa.Field(coerce=True)
    Column_4: T.DateType = pa.Field(coerce=True)


class ExampleSchema_3(DataFrameModel):
    """Schema defining columns."""

    Column_1: T.StringType = pa.Field(coerce=True)
    Column_2: T.IntegerType = pa.Field(coerce=True)
    Column_3: T.FloatType = pa.Field(coerce=True)
    Column_4: T.DateType = pa.Field(coerce=True)
    Column_5: T.StringType = pa.Field(coerce=True)
    Column_6: T.IntegerType = pa.Field(coerce=True)
    Column_7: T.FloatType = pa.Field(coerce=True)
    Column_8: T.DateType = pa.Field(coerce=True)


class FinalExampleSchema(ExampleSchema_1, ExampleSchema_2, ExampleSchema_3):
    """Final schema that inherits from all other schemas."""
This section demonstrates the usage of the get_spark_schema function with different schema examples and shows the corresponding outputs to illustrate its functionality.
# Example Schema 1 Usage
get_spark_schema(ExampleSchema_1)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True)])
# Example Schema 2 Usage
get_spark_schema(ExampleSchema_2)
# Output: StructType([StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True)])
# Example Schema 3 Usage
get_spark_schema(ExampleSchema_3)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True), StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True), StructField('Column_5', StringType(), True), StructField('Column_6', IntegerType(), True), StructField('Column_7', FloatType(), True), StructField('Column_8', DateType(), True)])
# Final Example Schema (Inheritance) Usage
get_spark_schema(FinalExampleSchema)
# Output: StructType([StructField('Column_1', StringType(), True), StructField('Column_2', IntegerType(), True), StructField('Column_3', FloatType(), True), StructField('Column_4', DateType(), True), StructField('Column_5', StringType(), True), StructField('Column_6', IntegerType(), True), StructField('Column_7', FloatType(), True), StructField('Column_8', DateType(), True)])
The following code demonstrates how to create PySpark DataFrames using the schemas generated from the pandera.pyspark models and displays the structure of these DataFrames.
from pyspark.sql import SparkSession

# Get or create a SparkSession
spark = SparkSession.builder.getOrCreate()

# Create an empty RDD as a placeholder for data
empty_rdd = spark.sparkContext.emptyRDD()
# Generate Spark schemas from the pandera.pyspark models
schema_1 = get_spark_schema(ExampleSchema_1)
schema_2 = get_spark_schema(ExampleSchema_2)
schema_3 = get_spark_schema(ExampleSchema_3)
schema_final = get_spark_schema(FinalExampleSchema)
# Create DataFrames with the generated schemas
example_df_1 = spark.createDataFrame(empty_rdd, schema_1)
example_df_2 = spark.createDataFrame(empty_rdd, schema_2)
example_df_3 = spark.createDataFrame(empty_rdd, schema_3)
example_df_final = spark.createDataFrame(empty_rdd, schema_final)
# Show the DataFrame structure
example_df_1.printSchema()
example_df_2.printSchema()
example_df_3.printSchema()
example_df_final.printSchema()
# Expected output for example_df_1:
# root
# |-- Column_1: string (nullable = true)
# |-- Column_2: integer (nullable = true)
# Expected output for example_df_2:
# root
# |-- Column_3: float (nullable = true)
# |-- Column_4: date (nullable = true)
# Expected output for example_df_3:
# root
# |-- Column_1: string (nullable = true)
# |-- Column_2: integer (nullable = true)
# |-- Column_3: float (nullable = true)
# |-- Column_4: date (nullable = true)
# |-- Column_5: string (nullable = true)
# |-- Column_6: integer (nullable = true)
# |-- Column_7: float (nullable = true)
# |-- Column_8: date (nullable = true)
# Expected output for example_df_final:
# root
# |-- Column_1: string (nullable = true)
# |-- Column_2: integer (nullable = true)
# |-- Column_3: float (nullable = true)
# |-- Column_4: date (nullable = true)
# |-- Column_5: string (nullable = true)
# |-- Column_6: integer (nullable = true)
# |-- Column_7: float (nullable = true)
# |-- Column_8: date (nullable = true)
Watching... will it work on nested StructFields?
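For illustration, one way the type lookup might be extended to cover nested types (a hypothetical sketch, not part of the proposal above): resolve annotations that are already DataType instances directly, so ArrayType and nested StructType annotations would pass through unchanged.

import pyspark.sql.types as T

def _resolve_spark_type(attr_type):
    # DataType instances such as T.ArrayType(T.StringType()) or a nested
    # T.StructType([...]) could be used directly as the field's data type.
    if isinstance(attr_type, T.DataType):
        return attr_type
    # DataType classes such as T.StringType are instantiated with defaults.
    if isinstance(attr_type, type) and issubclass(attr_type, T.DataType):
        return attr_type()
    # Anything else is not mappable to a Spark type.
    return None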
Watching... It would be great to be able to export both the complete StructType object and the condensed DDL format, for example:
col0 INT, col1 DOUBLE, col3 STRING
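For illustration, a condensed "name type" string can already be approximated from the generated StructType via simpleString() on each field's data type (a sketch reusing get_spark_schema and ExampleSchema_3 from above, not an API the proposal defines):

schema = get_spark_schema(ExampleSchema_3)

# Join each field's name and simple type string into a DDL-like line.
ddl = ", ".join(f"{field.name} {field.dataType.simpleString()}" for field in schema.fields)
print(ddl)
# Column_1 string, Column_2 int, Column_3 float, Column_4 date, Column_5 string, Column_6 int, Column_7 float, Column_8 date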
@Garett601 and @vladgrish, I just opened PR #1570, could you take a look at it, please?
#1570 was merged; I believe this can be closed, @cosmicBboy.
Fixed by #1570.