[SPARK-40264][ML] add batch_infer_udf function to pyspark.ml.functions

leewyang opened this pull request

What changes were proposed in this pull request?

Add a batch_infer_udf function to pyspark.ml.functions to help users construct a pandas_udf for ML/DL model inference in Spark. This UDF adds standardized behavior for:

  • conversion of the Spark/Pandas DataFrame to numpy arrays.
  • batching of the inputs sent to the model predict() function.
  • caching of the model and predict() function on the executors.

The user only needs to provide a predict_batch_fn() function that loads the model and returns a predict function for inference on a batch.

For example, here is some sample code for inference using an MNIST model:

from typing import Dict, Union

import numpy as np
from pyspark.ml.functions import batch_infer_udf
from pyspark.sql.functions import struct
from pyspark.sql.types import ArrayType, FloatType

def predict_batch_fn():
    import tensorflow as tf
    # Load the model once; the returned closure is cached on the executors.
    model = tf.keras.models.load_model('/path/to/mnist_model')
    def predict(inputs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        return model.predict(inputs)
    return predict

mnist = batch_infer_udf(predict_batch_fn,
                        return_type=ArrayType(FloatType()),
                        batch_size=100,
                        input_tensor_shapes=[[-1, 784]])

df = spark.read.parquet("mnist_data")
# Pass all input columns as a single struct; each row is reshaped to the
# declared input_tensor_shapes before being sent to the model.
preds = df.withColumn("preds", mnist(struct(df.columns))).toPandas()
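
For illustration, the standardized behavior above (numpy conversion, re-batching, and executor-side caching) can be approximated with a wrapper like the minimal sketch below. The helper make_batch_udf and its internals are hypothetical, not the actual PR implementation, and assume a single input column.

from typing import Callable, Iterator

import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf

def make_batch_udf(predict_batch_fn: Callable, return_type, batch_size: int):
    @pandas_udf(return_type)
    def predict(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
        # The iterator form runs this body once per task, so the (potentially
        # expensive) model load is amortized over all Arrow batches in the
        # task; the PR additionally caches the model across tasks.
        predict_fn = predict_batch_fn()
        for series in batches:
            # Convert the pandas data to a numpy array (one row per record).
            data = np.stack(series.to_numpy())
            # Re-batch the rows to batch_size before calling the model.
            for start in range(0, len(data), batch_size):
                preds = predict_fn(data[start:start + batch_size])
                yield pd.Series(list(preds))
    return predict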

Why are the changes needed?

To simplify batch inference of ML/DL models in Spark.

Does this PR introduce any user-facing change?

Yes, adds a new function in pyspark.ml.functions.

How was this patch tested?

Unit tests added.

leewyang · Aug 30 '22

Can one of the admins verify this patch?

AmplabJenkins · Aug 31 '22

I think we'd better design and discuss the API first. @mengxr Do you have any suggestions?

WeichenXu123 · Sep 01 '22

@WeichenXu123 Could you make a pass on the implementation?

mengxr · Oct 03 '22

Please also fix the linter failure: https://github.com/leewyang/spark/actions/runs/3397174449/jobs/5649073867#step:16:71

WeichenXu123 · Nov 07 '22

Updated to latest master, which got rid of the linter error, but it added a new "appveyor" check, which seems to be failing in some SparkR tests. Not sure what to do with that one.

leewyang · Nov 08 '22

The CI failure has been fixed by https://github.com/apache/spark/commit/9cd55052ccefc1421d30bfc751e2a013973d3ac6

You can merge master to address the CI failure, @leewyang.

WeichenXu123 · Nov 08 '22

Left some more comments, mainly about docs and adding more data checking. The PR is nearly ready!

WeichenXu123 · Nov 08 '22

@mengxr Could you make a final pass? The PR is LGTM once all my comments are addressed.

WeichenXu123 · Nov 11 '22

BTW, I'm seeing a change in behavior in pandas_udf when used with limit in the latest master branch of Spark (vs. 3.3.1), per this example code:

import numpy as np
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

data = np.arange(0, 1000, dtype=np.float64)
pdf = pd.DataFrame(data, columns=['x'])
df = spark.createDataFrame(pdf)

@pandas_udf(returnType=DoubleType())
def times_two(x: pd.Series) -> pd.Series:
    print(x.shape)
    return x * 2

# 3.3.1:  x.shape = (10,)
# master: x.shape = (500,)
df.limit(10).withColumn("x2", times_two("x")).collect()

Not sure if this is a regression or an intentional change, but it does impact performance for this PR, since a given model will be run against 500 rows instead of 10 (even though the final results show only 10 rows). Basically, it looks like the limit function is being applied after running the pandas_udf on a full partition, whereas it used to be applied before running the pandas_udf.

leewyang · Nov 15 '22

@leewyang Does df.limit(10).cache().withColumn address the issue?

@HyukjinKwon Looks like https://github.com/apache/spark/pull/37734#issuecomment-1315678614 is a regression; did the Spark optimizer change?

@mengxr Could you make a final review?

WeichenXu123 · Dec 04 '22

@WeichenXu123 Yes, using df.limit(10).cache().withColumn makes it only process 10 rows inside the pandas_udf, which addresses the performance issue, thanks!

leewyang · Dec 04 '22
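
For reference, the confirmed workaround applied to the times_two example above looks like this:

# cache() materializes the 10-row result of limit() first, so the
# downstream pandas_udf only sees those 10 rows instead of a full
# 500-row partition.
df.limit(10).cache().withColumn("x2", times_two("x")).collect()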

Nit:

We can add 2 new examples for:

  • Returning a struct type containing array-type fields, where the PredictBatchFunction returns a dict with keys matching the struct fields.
  • Returning a struct type, where the PredictBatchFunction returns a list of dicts.

Since the Spark 3.4 code freeze is coming soon, I will merge this PR now. @leewyang Could you file a follow-up PR for my suggested doc changes?

WeichenXu123 · Jan 16 '23
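
For illustration, those two suggested examples might look roughly like the sketch below; the field names, shapes, and dummy outputs are hypothetical and should be checked against the merged API.

import numpy as np
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType

# A struct return type containing an array-type field:
output_type = StructType([
    StructField("embedding", ArrayType(FloatType())),
    StructField("score", FloatType()),
])

# 1) PredictBatchFunction returns a dict of numpy arrays whose keys match
#    the struct fields (one batch-aligned array per field):
def predict_dict(inputs: np.ndarray):
    n = len(inputs)
    return {
        "embedding": np.zeros((n, 8), dtype=np.float32),  # dummy values
        "score": np.zeros(n, dtype=np.float32),           # dummy values
    }

# 2) PredictBatchFunction returns a list of dicts, one per input row:
def predict_list_of_dicts(inputs: np.ndarray):
    return [{"embedding": [0.0] * 8, "score": 0.0} for _ in inputs]  # dummy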

Merged to apache/spark master.

WeichenXu123 · Jan 16 '23