[SPARK-40264][ML] add batch_infer_udf function to pyspark.ml.functions
What changes were proposed in this pull request?
Add a batch_infer_udf function to pyspark.ml.functions to help users construct a pandas_udf for ML/DL model inference in Spark. This UDF adds standardized behavior for:
- conversion of the Spark/Pandas DataFrame to numpy arrays.
- batching of the inputs sent to the model predict() function.
- caching of the model and predict() function on the executors.
The user just needs to provide a predict_batch_fn() function that runs inference on a batch.
For example, here is some sample code for inference using an MNIST model:
```python
from typing import Dict, Union

import numpy as np
from pyspark.ml.functions import batch_infer_udf
from pyspark.sql.functions import struct
from pyspark.sql.types import ArrayType, FloatType

def predict_batch_fn():
    import tensorflow as tf
    model = tf.keras.models.load_model('/path/to/mnist_model')

    def predict(inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        return model.predict(inputs)

    return predict

mnist = batch_infer_udf(predict_batch_fn,
                        return_type=ArrayType(FloatType()),
                        batch_size=100,
                        input_tensor_shapes=[[-1, 784]])

df = spark.read.parquet("mnist_data")
preds = df.withColumn("preds", mnist(struct(df.columns))).toPandas()
```
Why are the changes needed?
To simplify batch inference of ML/DL models in Spark.
Does this PR introduce any user-facing change?
Yes, this adds a new function to pyspark.ml.functions.
How was this patch tested?
Unit tests added.
Can one of the admins verify this patch?
But I think we should design and discuss the API first. @mengxr Do you have any suggestions?
@WeichenXu123 Could you make a pass on the implementation?
Please also fix the linter failure: https://github.com/leewyang/spark/actions/runs/3397174449/jobs/5649073867#step:16:71
Updated to the latest master, which resolved the linter error, but it added a new "appveyor" check that is failing in some SparkR tests. Not sure what to do about that one.
The CI failure has been fixed by https://github.com/apache/spark/commit/9cd55052ccefc1421d30bfc751e2a013973d3ac6
You can merge master to address the CI failure, @leewyang.
Some more comments, mainly about docs and adding more data checking. The PR is nearly ready!
@mengxr Could you make a final pass? The PR is LGTM once all my comments are addressed.
BTW, I'm seeing a change in behavior in the pandas_udf when used with limit in the latest master branch of Spark (vs. 3.3.1), per this example code:
```python
import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

data = np.arange(0, 1000, dtype=np.float64)
pdf = pd.DataFrame(data, columns=['x'])
df = spark.createDataFrame(pdf)

@pandas_udf(returnType=DoubleType())
def times_two(x):
    print(x.shape)
    return x * 2

# 3.3.1:  x.shape = (10,)
# master: x.shape = (500,)
df.limit(10).withColumn("x2", times_two("x")).collect()
```
Not sure if this is a regression or an intentional change, but it does impact performance for this PR, since a given model will be run against 500 rows instead of 10 (even though the final results show only 10 rows). Basically, it looks like the limit is being applied after running the pandas_udf on a full partition, whereas it used to be applied before running the pandas_udf.
@leewyang Does df.limit(10).cache().withColumn address the issue?
@HyukjinKwon It looks like https://github.com/apache/spark/pull/37734#issuecomment-1315678614 is a regression. Did the Spark optimizer change?
@mengxr Could you make a final review?
@WeichenXu123 Yes, using df.limit(10).cache().withColumn makes it only process 10 rows inside the pandas_udf, which addresses the performance issue, thanks!
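For anyone following along, a minimal sketch of that workaround applied to the repro above (presumably the cache() materializes the 10 limited rows before the pandas_udf runs):

```python
# Workaround from the discussion above: cache the limited DataFrame so the
# pandas_udf only sees the 10 surviving rows instead of the full partition.
df.limit(10).cache().withColumn("x2", times_two("x")).collect()
# with cache(): x.shape = (10,) inside times_two, per the confirmation above
```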
Nit:
We can add two new examples (rough sketches below):
- Return with a struct type containing fields of array type, where PredictBatchFunction returns a dict with keys matching the struct fields.
- Return with a struct type, where PredictBatchFunction returns a list of dicts.
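To make the suggestion concrete, here is a rough, untested sketch of what those two examples might look like. The model loader (load_my_model) and its accessors (predict_labels, predict_probs) are hypothetical placeholders, not part of this PR; the dict/list-of-dict handling is whatever this PR implements.

```python
# Hypothetical sketch only -- load_my_model, predict_labels, and predict_probs
# are placeholder names for illustration.
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType

# Example 1: struct return type with an array-type field; the predict function
# returns a dict whose keys match the struct field names.
def predict_batch_fn():
    model = load_my_model()  # placeholder model loader

    def predict(inputs):
        return {
            "label": model.predict_labels(inputs),  # shape: (batch,)
            "probs": model.predict_probs(inputs),   # shape: (batch, num_classes)
        }

    return predict

classify = batch_infer_udf(
    predict_batch_fn,
    return_type=StructType([
        StructField("label", FloatType()),
        StructField("probs", ArrayType(FloatType())),
    ]),
    batch_size=100,
)

# Example 2: same struct return type, but the predict function instead
# returns a list of dicts, one dict per input row.
def predict_batch_fn_rowwise():
    model = load_my_model()  # placeholder model loader

    def predict(inputs):
        return [{"label": float(label), "probs": probs.tolist()}
                for label, probs in zip(model.predict_labels(inputs),
                                        model.predict_probs(inputs))]

    return predict
```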
Since the Spark 3.4 code freeze is coming soon, I will merge this PR now. @leewyang Could you file a follow-up PR for my suggested doc changes?
Merged to apache/spark master.