VowpalWabbitClassifier does not work with --oaa (One Against All) argument
**Describe the bug**
Vowpal Wabbit's One Against All classifier does not work via the MMLSpark interface.
**To Reproduce**
val vwClassifier = new VowpalWabbitClassifier()
.setFeaturesCol("features")
.setLabelCol("label")
.setProbabilityCol("predictedProb")
.setPredictionCol("predictedLabel")
.setRawPredictionCol("rawPrediction")
.setArgs("--oaa=2 --quiet --holdout_off")
`features` is a column of sparse vectors (constructed via VowpalWabbitFeaturizer in my case); `label` is a column of integers with values {1, 2}.
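For reference, the featurizer step looks roughly like this (a minimal sketch; the input column names and `rawTrainDF` are placeholders, not my actual pipeline):

```scala
import com.microsoft.ml.spark.vw.VowpalWabbitFeaturizer

// Hash the raw input columns into the sparse "features" vector the classifier consumes.
val featurizer = new VowpalWabbitFeaturizer()
  .setInputCols(Array("colA", "colB")) // placeholder column names
  .setOutputCol("features")

// rawTrainDF: a DataFrame holding the raw columns plus the integer "label" column.
val trainDF = featurizer.transform(rawTrainDF)
```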
**Expected behavior**
val predictions = vwClassifier.fit(trainDF).transform(testDF)
predictions.show
would show my testDF with predictedLabel column containing predictions.
**Info (please complete the following information):**
- MMLSpark Version: 1.0.0-rc1
- Spark Version: 2.4.3
- Spark Platform: AWS EMR 5.26.0 (Zeppelin 0.8.1)
**Stacktrace**
org.apache.spark.SparkException: Job aborted due to stage failure: Task 95 in stage 46.0 failed 4 times, most recent failure: Lost task 95.3 in stage 46.0 (TID 2609, ip-10-5-29-73.ec2.internal, executor 7): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
... 21 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2041)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2029)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2028)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2028)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2262)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2211)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2200)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:401)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
at org.apache.spark.sql.Dataset.show(Dataset.scala:745)
at org.apache.spark.sql.Dataset.show(Dataset.scala:704)
at org.apache.spark.sql.Dataset.show(Dataset.scala:713)
... 51 elided
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:121)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
... 3 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
... 21 more
To me, it looks like `Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)` is the root cause. Could it be that --oaa outputs integers instead of the doubles expected by MMLSpark?
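To illustrate, here is a minimal sketch of the failure mode I suspect (illustrative only, not the actual library source; `nativePrediction` stands in for whatever the JNI layer hands back):

```scala
import org.vowpalwabbit.spark.prediction.ScalarPrediction

// predictInternal effectively performs this cast on the native prediction object.
def predictInternalSketch(nativePrediction: AnyRef): ScalarPrediction =
  nativePrediction.asInstanceOf[ScalarPrediction]

// With --loss_function=logistic the native layer returns a ScalarPrediction, so the
// cast succeeds. With --oaa it apparently returns the winning class as a boxed
// java.lang.Integer, e.g.:
//   predictInternalSketch(Integer.valueOf(2))  // throws the ClassCastException above
```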
**Additional context**
For context, this works fine in my setup on the same dataset with the same VowpalWabbitFeaturizer (although I have to convert the labels to {1, 0}):
val vwClassifier = new VowpalWabbitClassifier()
.setFeaturesCol("features")
.setLabelCol("label")
.setProbabilityCol("predictedProb")
.setPredictionCol("predictedLabel")
.setRawPredictionCol("rawPrediction")
.setArgs("--loss_function=logistic --link=logistic --quiet --holdout_off")
Hey, why am I still getting this issue?
@Ibtastic what version of mmlspark are you using? Can you open a new issue with the stack trace, or reopen this issue? It is really hard to track already-closed issues (I often miss them). Also adding @eisber.
@imatiach-msft I do not want to open another issue because the error is the same as the one mentioned by this thread's author, and so is the stack trace.
- MMLSpark version: 1.0.0-rc3
- Spark version: 2.4.5
- Platform: Spark on a Databricks cluster
@Ibtastic are you sure you are seeing the exact same stack trace, with the same line numbers as above? The line numbers should have changed, since the code has changed. The issue was supposedly fixed by this PR: https://github.com/Azure/mmlspark/pull/817/files
That PR should have been in the rc3 release: rc3 was released in October 2020 (https://github.com/Azure/mmlspark/releases) and the PR was merged in March 2020. Let me reopen this issue; maybe @eisber has more info, since he wrote the VW wrappers.
I will paste the stack trace in a while. Also, I tried this with Spark 3.1.1 and got java.lang.NoClassDefFoundError. Is VowpalWabbit not implemented for this Spark version?
@Ibtastic Spark 3.1 is only supported on the latest master; you can use any master build.
Please try this walkthrough with pictures on Databricks: https://docs.microsoft.com/en-us/azure/cognitive-services/big-data/getting-started#azure-databricks
For Spark 2.4.5 you can use the rc1 to rc3 releases. For the latest Spark (>3.0) you will need to use a build from master:

For example:
- Maven Coordinates: `com.microsoft.ml.spark:mmlspark_2.12:1.0.0-rc3-80-b704515f-SNAPSHOT`
- Maven Resolver: `https://mmlspark.azureedge.net/maven`
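Wiring those coordinates into a session programmatically would look roughly like this (a sketch using Spark's standard `spark.jars.*` settings; adjust the coordinates to the build you need):

```scala
import org.apache.spark.sql.SparkSession

// Resolve the snapshot build from the custom Maven resolver before the session starts.
val spark = SparkSession.builder()
  .config("spark.jars.packages",
    "com.microsoft.ml.spark:mmlspark_2.12:1.0.0-rc3-80-b704515f-SNAPSHOT")
  .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
  .getOrCreate()
```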
@imatiach-msft Thanks for pointing that out! Here's the stack trace for Spark 2.4.5:
org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2362)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2350)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2349)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2349)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1102)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2582)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2529)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2517)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:897)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2282)
at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:270)
at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:280)
at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:80)
at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:86)
at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:508)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollectResult(limit.scala:57)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectResult(Dataset.scala:2890)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3508)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3492)
at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3487)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withCustomExecutionEnv$1.apply(SQLExecution.scala:113)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:243)
at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:99)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:173)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withAction(Dataset.scala:3487)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2619)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2833)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:266)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:303)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:295)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:251)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:640)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
at org.apache.spark.scheduler.Task.run(Task.scala:113)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:537)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:543)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:98)
at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:62)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:54)
at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:54)
... 15 more
adding @jackgerrits any idea?
Have there been any updates on this issue? I'm seeing the same error. If this has been resolved, would it be possible for someone to provide a simple working example?
- PySpark Version: 3.1.2
- Spark config:
  - spark.jars.repositories: "https://mmlspark.azureedge.net/maven"
  - spark.jars.packages: "com.microsoft.azure:synapseml_2.12:0.9.5-13-d1b51517-SNAPSHOT"
  - spark.jars.excludes: "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalatest:scalatest_2.12"
- Platform: EMR 6.5.0

Stack trace:

pred = model.transform(test_df_feat)

only testing
Num weight bits = 19
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1
In [13]: pred.show(10)
22/05/26 21:47:16 WARN TaskSetManager: Lost task 0.0 in stage 4.0 (TID 440) (ip-10-68-16-106.ec2.internal executor 1): org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
22/05/26 21:47:19 ERROR TaskSetManager: Task 0 in stage 4.0 failed 4 times; aborting job
Py4JJavaError Traceback (most recent call last)
~/.venv/default/lib64/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    482         """
    483         if isinstance(truncate, bool) and truncate:
--> 484             print(self._jdf.showString(n, 20, vertical))
    485         else:
    486             print(self._jdf.showString(n, int(truncate), vertical))

~/.venv/default/lib64/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306
   1307         for temp_arg in temp_args:

~/.venv/default/lib64/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
    109     def deco(*a, **kw):
    110         try:
--> 111             return f(*a, **kw)
    112         except py4j.protocol.Py4JJavaError as e:
    113             converted = convert_exception(e.java_exception)

~/.venv/default/lib64/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(
Py4JJavaError: An error occurred while calling o398.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 4 times, most recent failure: Lost task 0.3 in stage 4.0 (TID 443) (ip-10-68-16-106.ec2.internal executor 2): org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2470)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2419)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2418)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2418)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1125)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1125)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1125)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2684)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2626)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2615)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:914)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2241)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2262)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2281)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:494)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:447)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3760)
at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2763)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3751)
at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:107)
at org.apache.spark.sql.execution.SQLExecution$.withTracker(SQLExecution.scala:232)
at org.apache.spark.sql.execution.SQLExecution$.executeQuery$1(SQLExecution.scala:110)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:135)
at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:107)
at org.apache.spark.sql.execution.SQLExecution$.withTracker(SQLExecution.scala:232)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:135)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:253)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:134)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:68)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3749)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2763)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2970)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:303)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:340)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
It's not implemented yet: https://github.com/microsoft/SynapseML/blob/9d16166314ef42767e1c50b9c831c163050c15d3/vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitClassifier.scala#L64
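A fix would presumably need the multiclass case handled alongside the scalar one, conceptually something like this (a rough sketch, not the actual patch; it assumes ScalarPrediction exposes its value via getValue):

```scala
import org.vowpalwabbit.spark.prediction.ScalarPrediction

// Dispatch on the runtime type of the native prediction instead of casting blindly.
def predictionToDouble(nativePrediction: AnyRef): Double = nativePrediction match {
  case s: ScalarPrediction => s.getValue.toDouble // scalar losses
  case i: java.lang.Integer => i.toDouble         // --oaa: winning class index
  case other =>
    throw new IllegalArgumentException(s"unsupported prediction type: ${other.getClass}")
}
```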
Let me follow up w/ Jack. Is there a minimal dataset we can use to repro?
Modified from the adult census example:
import pyspark.sql.functions as f
from synapse.ml.vw import VowpalWabbitFeaturizer, VowpalWabbitClassifier

df = spark.read.format("csv") \
    .option("header", True) \
    .option("inferSchema", True) \
    .load("path/to/adult.csv")  # download from Kaggle

data = df.toDF(*(c.replace(".", "-") for c in df.columns)) \
    .select(["education", "marital-status", "hours-per-week", "income"])

# create binary label
data = data.withColumn("label", f.when(f.col("income").contains("<"), 0.0).otherwise(1.0)).repartition(1)

# create random multiclass label
num_classes = 5
data = data.withColumn("random_label", f.round(f.rand() * (num_classes - 1), 0))

vw_featurizer = VowpalWabbitFeaturizer(
    inputCols=["education", "marital-status", "hours-per-week"],
    outputCol="features")
data = vw_featurizer.transform(data)

# fit binary classifier
binary_args = "--loss_function=logistic --quiet --holdout_off"
binary_model = VowpalWabbitClassifier(
    featuresCol="features",
    labelCol="label",
    args=binary_args,
    numPasses=10)
binary_model.fit(data).transform(data).show(10, False)  # works like a charm

# fit multiclass classifier
multi_args = f"--loss_function=logistic --quiet --holdout_off --oaa={num_classes}"
multi_model = VowpalWabbitClassifier(
    featuresCol="features",
    labelCol="random_label",
    args=multi_args,
    numPasses=10)
multi_model.fit(data).transform(data).show(10, False)  # gives the stack trace
Thanks for the effort on the PR! I look forward to test-driving it once it gets merged :)