pandantic
Support for Polars Dataframes
pandantic seemed like such a nice and simple implementation that I decided to edit your model to work with Polars DataFrames, and figured I would share the results.
I only recently began using Polars, so there might be more efficient ways, but here are the changes I had to make to your model:
- There is no index, so I replaced it with with_row_count() to get the row number for errors.
- Chunking can be handled by iter_slices(), where n_rows is the total row count divided by the number of jobs (e.g. the CPU count).
- Instead of to_dict(), we use iter_rows(named=True) to pass each row into the validator.
- We use filter() to exclude the error rows if errors is set to "filter".
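The chunk-size arithmetic can be checked in plain Python. This is a minimal sketch that assumes iter_slices(n_rows=...) yields consecutive slices of n_rows rows with a shorter final slice; chunk_bounds is a hypothetical helper, not part of Polars or pandantic. With ceiling division, the number of chunks can be smaller than n_jobs.

```python
import math


def chunk_bounds(total_rows: int, n_jobs: int) -> list[tuple[int, int]]:
    """Return (start, end) row offsets for each chunk.

    Hypothetical helper: uses math.ceil(total_rows / n_jobs) as the chunk
    size and then takes consecutive slices of that size.
    """
    if total_rows <= 0:
        return []
    chunk_size = math.ceil(total_rows / max(n_jobs, 1))
    return [
        (start, min(start + chunk_size, total_rows))
        for start in range(0, total_rows, chunk_size)
    ]


print(chunk_bounds(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
print(chunk_bounds(9, 4))   # [(0, 3), (3, 6), (6, 9)]: only 3 chunks for 4 jobs
```

This is why the collection loop waits for total_chunks = len(chunks) sentinels rather than n_jobs of them. Waiting for n_jobs sentinels would hang whenever fewer chunks exist.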
import logging
import math

import polars as pl
from multiprocess import Process, Queue, cpu_count
from pydantic import BaseModel


class PolarsModel(BaseModel):
    @classmethod
    def parse_df(
        cls,
        dataframe: pl.DataFrame,
        errors: str = "raise",
        context: dict[str, object] | None = None,
        n_jobs: int = 1,
        verbose: bool = True,
    ) -> pl.DataFrame:
        errors_index = []
        # Polars has no index, so add a "row_nr" column to identify failing rows
        # (newer Polars versions deprecate with_row_count() in favor of with_row_index())
        dataframe = dataframe.clone().with_row_count()
        logging.info(f"Validating {dataframe.height} rows")
        logging.debug(f"Amount of available cores: {cpu_count()}")
        if n_jobs != 1:
            if n_jobs < 1:
                n_jobs = cpu_count()
            # One chunk per job; the last chunk may be shorter
            chunk_size = math.ceil(len(dataframe) / n_jobs)
            chunks = list(dataframe.iter_slices(n_rows=chunk_size))
            total_chunks = len(chunks)
            logging.info(f"Split the dataframe into {total_chunks} chunks to process {chunk_size} rows per chunk.")
            processes = []
            q = Queue()
            for chunk in chunks:
                p = Process(target=cls._validate_row, args=(chunk, q, context, verbose), daemon=True)
                p.start()
                processes.append(p)
            # Each worker puts None when it finishes; drain the queue before joining
            num_stops = 0
            while num_stops < total_chunks:
                index = q.get()
                if index is None:
                    num_stops += 1
                else:
                    errors_index.append(index)
            for p in processes:
                p.join()
        else:
            for row in dataframe.iter_rows(named=True):
                try:
                    cls.model_validate(obj=row, context=context)
                except Exception as exc:
                    if verbose:
                        logging.info(f"Validation error found at row {row['row_nr']}\n{exc}")
                    errors_index.append(row["row_nr"])
        logging.info(f"# invalid rows: {len(errors_index)}")
        if len(errors_index) > 0 and errors == "raise":
            raise ValueError(f"{len(errors_index)} validation errors found in dataframe.")
        if len(errors_index) > 0 and errors == "filter":
            return dataframe.filter(~pl.col("row_nr").is_in(errors_index)).drop("row_nr")
        return dataframe.drop("row_nr")

    @classmethod
    def _validate_row(cls, chunk: pl.DataFrame, q: Queue, context=None, verbose=True) -> None:
        for row in chunk.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)
            except Exception as exc:
                if verbose:
                    logging.info(f"Validation error found at row {row['row_nr']}\n{exc}")
                q.put(row["row_nr"])
        # Sentinel: tells the parent this chunk is done
        q.put(None)
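The completion protocol in the multicore path works like this: each worker puts None when it finishes, and the parent counts these sentinels before joining. The sketch below illustrates it using only threading and queue.Queue from the standard library, so it runs anywhere; the sentinel-counting logic is the same with multiprocess.Queue. The validation rule and the names worker and collect_errors are hypothetical.

```python
import queue
import threading


def worker(rows: list[dict], q: queue.Queue) -> None:
    """Put the row number of each invalid row, then a None sentinel."""
    for row in rows:
        # Hypothetical stand-in for model_validate: "value" must be >= 0
        if row["value"] < 0:
            q.put(row["row_nr"])
    q.put(None)  # always signal completion, even if no errors were found


def collect_errors(chunks: list[list[dict]]) -> list[int]:
    q: queue.Queue = queue.Queue()
    workers = [threading.Thread(target=worker, args=(chunk, q)) for chunk in chunks]
    for w in workers:
        w.start()
    errors_index: list[int] = []
    num_stops = 0
    # Drain until one sentinel per chunk has arrived
    while num_stops < len(chunks):
        item = q.get()
        if item is None:
            num_stops += 1
        else:
            errors_index.append(item)
    # Join only after draining; with multiprocessing queues, joining a process
    # before consuming what it put on the queue can deadlock
    for w in workers:
        w.join()
    return sorted(errors_index)


chunks = [[{"row_nr": 0, "value": 1}, {"row_nr": 1, "value": -1}], [{"row_nr": 2, "value": -5}]]
print(collect_errors(chunks))  # [1, 2]
```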
I tested this on a DataFrame that I duplicated repeatedly until it had more than 1 million rows, to check whether n_jobs was working correctly.
[Screenshot: run with n_jobs = 1]
[Screenshot: run with n_jobs = 4, twice as fast]
[Screenshot: example validation error logged with verbose=True]
[Screenshot: example with errors="filter"; the resulting dataframe has the expected rows]
Hi @ldacey,
Thank you so much for trying this out. This is on the "private" roadmap, right after creating benchmarking tests (I would like to have a baseline for the performance of Pandantic versus other packages like Pandera). I think it would be nice to have benchmarking tests with an example DataFrame that covers all the basic cases you would expect from a regular DataFrame (size, columns, and different types). That would also help improve the performance of methods such as parsing Polars DataFrames (by improving the logic), and it would make it possible to compare performance between Pandas and Polars when using Pydantic to validate your DataFrame.
I would actually take a different approach and create private methods within the current pandantic.BaseModel class. In parse_df, logic would determine the DataFrame type and dispatch to the correct private method.
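Here is a minimal sketch of that dispatch step. It assumes the type is detected from the top-level package of the object's class, so neither pandas nor Polars has to be imported. The handler functions and the _HANDLERS mapping are hypothetical placeholders, not pandantic's API; an isinstance check against the imported classes would work equally well.

```python
from typing import Any, Callable


def _validate_pandas(df: Any) -> str:
    return "pandas path"  # placeholder for the pandas-specific private method


def _validate_polars(df: Any) -> str:
    return "polars path"  # placeholder for the Polars-specific private method


# Map the top-level package of the DataFrame's class to its handler
_HANDLERS: dict[str, Callable[[Any], str]] = {
    "pandas": _validate_pandas,
    "polars": _validate_polars,
}


def parse_df(df: Any) -> str:
    # e.g. "pandas.core.frame" -> "pandas"; avoids importing either library
    package = type(df).__module__.split(".")[0]
    try:
        handler = _HANDLERS[package]
    except KeyError:
        raise TypeError(f"Unsupported DataFrame type: {type(df)!r}") from None
    return handler(df)
```

Adding a new library would then mean writing one private method and registering it, without touching parse_df itself.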
Do you want to open a PR? I am obviously open to contributions. If you don't feel like it, that is also fine; I was planning to add support for Polars DataFrames in the future anyway.
Cool, yeah I did not know if you had any interest in other Dataframe libraries or what approach you would take. It seems like a majority of the code would be the same, with some differences in how chunking / row iteration gets done.
I was playing around earlier just to see what was possible.
- Split the single-core and multicore functionality into separate methods.
- I removed per-error verbose logging because it crashed my Jupyter notebook when too many errors were logged at once. Instead, the errors are appended to a list, and the full list is logged if verbose=True.
- The .errors() format that ValidationError uses is pretty clear, so I just added the row number (row_nr) to each error dictionary. The custom ValidationException prints the total number of errors and exposes the list of errors if needed. A convert_errors function reformats some of the output. I didn't check the potential performance impact, but I thought it was neat. https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
try:
    test = ExampleModel.parse_df(df, errors="raise", n_jobs=4, verbose=False)
except ValidationException as e:
    print(e)
    print(e.errors)
[2023-09-08T10:30:35.330+0800] {models.py:83} INFO - Validating 768 rows
49 validation errors found in dataframe.
[[{'type': 'greater_than_equal', 'loc': 'Employee CSAT', 'msg': 'Input should be greater than or equal to 3', 'input': 1, 'ctx': {'ge': 3}, 'row': 5}], [{'type': 'greater_than_equal', 'loc': 'Employee CSAT', 'msg': 'Input should be greater than or equal to 3', 'input': 1, 'ctx': {'ge': 3}, 'row': 10}],
For example, I can read those errors into a DataFrame and then save it to a file, email an alert with the details, etc. (This was mostly just for fun, and I am not suggesting it should be added to your library.)
new = pl.from_dicts(e.flatten_errors())
print(new.filter(pl.col("row").is_duplicated()))
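You can answer the same question (which rows failed more than one check?) from the flatten_errors() output using only the standard library. This sketch assumes only that each flattened error is a dict with a "row" key, as produced above; rows_with_multiple_errors is a hypothetical helper.

```python
from collections import Counter


def rows_with_multiple_errors(flat_errors: list[dict]) -> dict[int, int]:
    """Return {row: error_count} for rows that failed more than one check."""
    counts = Counter(error["row"] for error in flat_errors)
    return {row: n for row, n in counts.items() if n > 1}


errors = [{"row": 5, "loc": "a"}, {"row": 5, "loc": "b"}, {"row": 10, "loc": "a"}]
print(rows_with_multiple_errors(errors))  # {5: 2}
```

Unlike the is_duplicated() filter, which returns every error record for those rows, this returns one count per row.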
Here is the latest version I was using:
import logging
import math
from typing import Any, Literal

import polars as pl
from multiprocess import Process, Queue, cpu_count
from pydantic import BaseModel, ValidationError
from pydantic_core import ErrorDetails


class ValidationException(Exception):
    """Exception raised when validation fails and returns a list of all errors"""

    def __init__(self, errors):
        super().__init__(f"{len(errors)} validation errors found in dataframe.")
        self.errors = errors

    def flatten_errors(self):
        """Flatten the list of error groups into a single list of errors"""
        return [
            {
                "row": error["row"],
                "loc": error["loc"],
                "input": error["input"],
                "type": error["type"],
                # Some error types (e.g. "missing") have no "ctx" entry
                "ctx": error.get("ctx"),
                "msg": error["msg"],
            }
            for error_group in self.errors
            for error in error_group
        ]


def convert_errors(e: ValidationError, row_index: int) -> list[ErrorDetails]:
    """Removes the url field, adds the row field, and converts the loc field to a string
    in the ErrorDetails object

    Args:
        e: The ValidationError object
        row_index: The row number of the DataFrame that failed validation
    """
    new_errors: list[ErrorDetails] = []
    for error in e.errors():
        error.pop("url", None)
        error["row"] = row_index
        if isinstance(error["loc"], (tuple, list)):
            error["loc"] = ".".join(map(str, error["loc"]))
        else:
            error["loc"] = str(error["loc"])
        new_errors.append(error)
    return new_errors


class PolarsBaseModel(BaseModel):
    """Base class for all Polars schema pydantic models"""

    @classmethod
    def parse_df(
        cls,
        dataframe: pl.DataFrame,
        errors: Literal["raise", "filter"] = "raise",
        context: dict[str, object] | None = None,
        n_jobs: int = 1,
        verbose: bool = True,
    ) -> pl.DataFrame:
        """Validate a DataFrame using the schema defined in the Pydantic model.

        Converts each row in the DataFrame to a dictionary prior to validation. If
        n_jobs != 1, the DataFrame is split into chunks and each chunk is validated in a
        separate process. Any errors are appended to a list for logging or filtering.

        Args:
            dataframe (pl.DataFrame): The DataFrame to validate.
            errors (Literal["raise", "filter"], optional): How to handle validation
                errors. Defaults to "raise".
            context (dict[str, object] | None, optional): The context to use for
                validation.
            n_jobs (int, optional): The number of processes to use for validation;
                values below 1 use all available cores. Defaults to 1.
            verbose (bool, optional): Whether to log validation errors.
                Defaults to True.

        Returns:
            pl.DataFrame: The DataFrame, with invalid rows removed if errors="filter".

        Raises:
            ValidationException: If any row fails validation and errors="raise".
        """
        # Polars has no index, so add a "row_nr" column to track failing rows
        dataframe = dataframe.clone().with_row_count()
        errors_index = []
        error_details = []
        logging.info(f"Validating {dataframe.height} rows")
        logging.debug(f"Amount of available cores: {cpu_count()}")
        if n_jobs != 1:
            errors_index = cls._validate_multicore(
                dataframe, n_jobs, context, error_details
            )
        else:
            errors_index = cls._validate_singlecore(dataframe, context, error_details)
        if len(errors_index) > 0:
            if errors == "raise":
                if verbose:
                    logging.info(error_details)
                raise ValidationException(error_details)
            elif errors == "filter":
                if verbose:
                    logging.info(error_details)
                return dataframe.filter(~pl.col("row_nr").is_in(errors_index)).drop(
                    "row_nr"
                )
        return dataframe.drop("row_nr")

    @classmethod
    def _validate_singlecore(
        cls,
        dataframe: pl.DataFrame,
        context: dict[str, Any] | None,
        error_details: list,
    ):
        """Validates each row of a DataFrame in dictionary format in a single process.

        Args:
            dataframe: DataFrame to validate
            context: Context to pass to the model_validate method
            error_details: List to store the error details
        """
        errors_index = []
        for row in dataframe.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)
            except ValidationError as exc:
                exception = convert_errors(exc, row["row_nr"])
                error_details.append(exception)
                errors_index.append(row["row_nr"])
        return errors_index

    @classmethod
    def _validate_multicore(
        cls,
        dataframe: pl.DataFrame,
        n_jobs: int,
        context: dict[str, Any] | None,
        error_details: list,
    ):
        """Split the dataframe into chunks of total rows / n_jobs in size and
        validate each chunk in a separate process. Any errors are appended to a
        list and returned.

        Args:
            dataframe: DataFrame to validate
            n_jobs: Number of processes to use
            context: Context to pass to the model_validate method
            error_details: List to store the error details
        """
        errors_index = []
        if n_jobs < 1:
            n_jobs = cpu_count()
        chunk_size = math.ceil(len(dataframe) / n_jobs)
        chunks = list(dataframe.iter_slices(n_rows=chunk_size))
        total_chunks = len(chunks)
        processes = []
        q = Queue()
        for chunk in chunks:
            p = Process(
                target=cls._validate_row,
                args=(chunk, q, context),
                daemon=True,
            )
            p.start()
            processes.append(p)
        # Drain the queue until every chunk has sent its None sentinel, then join
        num_stops = 0
        while num_stops < total_chunks:
            exceptions = q.get()
            if exceptions is None:
                num_stops += 1
            else:
                errors_index.append(exceptions["row_nr"])
                error_details.append(exceptions["error_detail"])
        for p in processes:
            p.join()
        return errors_index

    @classmethod
    def _validate_row(
        cls,
        chunk: pl.DataFrame,
        q: Queue,
        context: dict[str, Any] | None,
    ) -> None:
        """Validates each row of a DataFrame chunk in dictionary format as a
        separate parallel process.

        Args:
            chunk: DataFrame chunk to validate
            q: Queue that receives {"row_nr", "error_detail"} dicts for failing rows,
                followed by a None sentinel when the chunk is done
            context: Context to pass to the model_validate method
        """
        for row in chunk.iter_rows(named=True):
            try:
                cls.model_validate(obj=row, context=context)
            except ValidationError as exc:
                exception = convert_errors(exc, row["row_nr"])
                q.put({"row_nr": row["row_nr"], "error_detail": exception})
        q.put(None)
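Worker processes put their results on the queue as they go, so error_details from the multicore path arrive in no particular row order. The single-core path, by contrast, appends them in row order. If a stable order matters (for logs or a saved report), a small post-processing step could sort the groups. sort_error_groups is a hypothetical helper that assumes each group is a non-empty list of dicts as returned by convert_errors.

```python
def sort_error_groups(error_details: list[list[dict]]) -> list[list[dict]]:
    """Order error groups by row number.

    Every error in a group belongs to the same row, so the "row" of the
    first error identifies the group.
    """
    return sorted(error_details, key=lambda group: group[0]["row"])


groups = [[{"row": 10, "loc": "x"}], [{"row": 5, "loc": "x"}, {"row": 5, "loc": "y"}]]
print([group[0]["row"] for group in sort_error_groups(groups)])  # [5, 10]
```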
@ldacey, we are working on a refactor that will use the dependency injection design pattern to eventually handle as many combinations as possible of table libraries (Spark/pandas/Polars/Dask, etc.) and "schema" libraries (pydantic/attrs/dataclasses).
It seems like you put some real thought into this and are a Polars user. We are starting with pandas, as it has the widest user base, but if you are interested in this new round of effort, we would love for you to join (perhaps on our next call) and help with the Polars implementation, or at least share the perspective of what Polars users might like.
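One way to picture the dependency-injection idea is to have the validation loop depend only on two things supplied by the caller: a small interface for the table and a row-validation callable. This is a minimal sketch with hypothetical names (TableAdapter, ListOfDictsAdapter, validate_table). It is not the refactor's actual design, which isn't described here.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, Protocol


class TableAdapter(Protocol):
    """What the validator needs from any table library (hypothetical interface)."""

    def iter_rows(self) -> Iterator[dict]: ...

    def keep_rows(self, row_numbers: set[int]) -> "TableAdapter": ...


@dataclass
class ListOfDictsAdapter:
    """Toy adapter; a Polars one could wrap iter_rows(named=True) and filter()."""

    rows: list[dict]

    def iter_rows(self) -> Iterator[dict]:
        return iter(self.rows)

    def keep_rows(self, row_numbers: set[int]) -> "ListOfDictsAdapter":
        return ListOfDictsAdapter([r for i, r in enumerate(self.rows) if i in row_numbers])


def validate_table(table: TableAdapter, is_valid: Callable[[dict], bool]) -> TableAdapter:
    """Drop invalid rows; both the table and the row validator are injected."""
    valid = {i for i, row in enumerate(table.iter_rows()) if is_valid(row)}
    return table.keep_rows(valid)


table = ListOfDictsAdapter([{"x": 1}, {"x": -1}, {"x": 3}])
print(validate_table(table, lambda row: row["x"] > 0).rows)  # [{'x': 1}, {'x': 3}]
```

To plug in a pydantic model, you could pass an is_valid callable that wraps model_validate in a try/except ValidationError.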