`ft_vector_assembler()` crashes when values are missing / `NA`
If a Spark DataFrame passed to ft_vector_assembler() contains any NA / missing values, the Spark job fails with an error when the result is evaluated. Minimal working example and output:
library(sparklyr)
# connect to Spark
sc <- spark_connect(master = "local[*]")
# make a simple data.frame with numeric (DoubleType) values
df <- iris[1:4]
df[2, 2] <- NA # **comment out this line and it will work fine**
# add nice column names
names(df) <- c("a", "b", "c", "d")
# copy R data.frame to Spark
df_sparklyr <- copy_to(sc, df, overwrite=TRUE)
# add new vector column created from scalar columns
vector <- df_sparklyr %>% ft_vector_assembler(c("d", "b"), "vector_col")
# print new Spark DataFrame to see that vector column is DoubleType
print(head(vector))
Output with the line not commented out:
Error: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 23.0 failed 1 times, most recent failure: Lost task 0.0 in stage 23.0 (TID 25, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$4: (struct<d:double,b:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Encountered null while assembling a row with handleInvalid = "keep". Consider
removing nulls from dataset or using handleInvalid = "keep" or "skip".
at org.apache.spark.ml.feature.VectorAssembler$$anonfun$assemble$1.apply(VectorAssembler.scala:287)
at org.apache.spark.ml.feature.VectorAssembler$$anonfun$assemble$1.apply(VectorAssembler.scala:255)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
at org.apache.spark.ml.feature.VectorAssembler$.assemble(VectorAssembler.scala:255)
at org.apache.spark.ml.feature.VectorAssembler$$anonfun$4.apply(VectorAssembler.scala:144)
at org.apache.spark.ml.feature.VectorAssembler$$anonfun$4.apply(VectorAssembler.scala:143)
... 21 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2782)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2782)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
at org.apache.spark.sql.Dataset.collect(Dataset.scala:2782)
at sparklyr.Utils$.collect(utils.scala:204)
at sparklyr.Utils.collect(utils.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at sparklyr.Invoke.invoke(invoke.scala:147)
at sparklyr.StreamHandler.handleMethodCall(stream.scala:123)
at sparklyr.StreamHandler.read(stream.scala:66)
at sparklyr.BackendHandler.channelRead0(handler.scala:51)
at sparklyr.BackendHandler.channelRead0(handler.scala:4)
at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:102)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
at io.netty.handler.codec.ByteToMessageDecoder.fireChannelRead(ByteToMessageDecoder.java:310)
at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:284)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:340)
at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1359)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:362)
at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:348)
at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:935)
at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:138)
at io.netty.channel.nio.NioEventLoop.processS
Output with the line commented out:
# Source: spark<?> [?? x 5]
a b c d vector_col
<dbl> <dbl> <dbl> <dbl> <list>
1 5.1 3.5 1.4 0.2 <dbl [2]>
2 4.9 3 1.4 0.2 <dbl [2]>
3 4.7 3.2 1.3 0.2 <dbl [2]>
4 4.6 3.1 1.5 0.2 <dbl [2]>
5 5 3.6 1.4 0.2 <dbl [2]>
6 5.4 3.9 1.7 0.4 <dbl [2]>
It looks like the vector assembler has a "handleInvalid" parameter, introduced in Spark 2.4.0, which accepts "error" (the default), "skip", or "keep".
However, the current sparklyr implementation of the vector assembler only passes the input/output column names and uid arguments (see line 34), so handleInvalid is stuck on its default of "error" (the Spark 2.4 error message above misprints the active setting as "keep"). Adding this parameter seems like a good feature request for ft_vector_assembler(), depending on how the sparklyr devs want to handle this to stay compatible with older versions of Spark.
In the meantime I guess there's ft_imputer() if that works for handling your actual missing values, or using sparklyr::invoke() to call the original method manually.
Here is an example of using sparklyr::invoke() to set the handleInvalid parameter after creating a vector assembler.
library(sparklyr)
use_modified_vector_assembler <- TRUE
sc <- spark_connect(method = "databricks")
df <- iris[1:4]
df[2, 2] <- NA
names(df) <- c("a", "b", "c", "d")
df_sparklyr <- copy_to(sc, df, overwrite = TRUE)
if (use_modified_vector_assembler) {
# define a vector assembler
my_vector_assembler <- ft_vector_assembler(sc, c("d", "b"), "vector_col")
# set the handleInvalid parameter
my_vector_assembler$.jobj <- invoke(my_vector_assembler$.jobj, "setHandleInvalid", "skip")
# apply the transformer
vector <- ml_transform(my_vector_assembler, df_sparklyr)
} else {
vector <- df_sparklyr %>% ft_vector_assembler(c("d", "b"), "vector_col")
}
print(head(vector))
With use_modified_vector_assembler <- FALSE
Error : org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] User defined
function (`VectorAssembler$$Lambda$12317/1550954274`:
(struct<d:double,b:double>) =>
struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) failed
due to: org.apache.spark.SparkException: Encountered null while assembling a
row with handleInvalid = "error". Consider removing nulls from dataset or using
handleInvalid = "keep" or "skip".. SQLSTATE: 39000
With use_modified_vector_assembler <- TRUE
# Source: spark<?> [?? x 5]
a b c d vector_col
<dbl> <dbl> <dbl> <dbl> <list>
1 5.1 3.5 1.4 0.2 <dbl [2]>
2 4.7 3.2 1.3 0.2 <dbl [2]>
3 4.6 3.1 1.5 0.2 <dbl [2]>
4 5 3.6 1.4 0.2 <dbl [2]>
5 5.4 3.9 1.7 0.4 <dbl [2]>
6 4.6 3.4 1.4 0.3 <dbl [2]>