VowpalWabbitClassifier does not work with --oaa (One Against All) argument

patkovskyi opened this issue 5 years ago • 13 comments

Describe the bug

Vowpal Wabbit's One Against All classifier does not work via the MMLSpark interface.

To Reproduce

val vwClassifier = new VowpalWabbitClassifier()
        .setFeaturesCol("features")
        .setLabelCol("label")
        .setProbabilityCol("predictedProb")
        .setPredictionCol("predictedLabel")
        .setRawPredictionCol("rawPrediction")
        .setArgs("--oaa=2 --quiet --holdout_off")

features is a column of sparse vectors (constructed via VowpalWabbitFeaturizer in my case), and label is a column of integers with values {1, 2}.
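For reference, an equivalent minimal input could be built like this in PySpark (a hedged sketch; the column names and values are illustrative, and in the rc releases the Python package was mmlspark.vw rather than the later synapse.ml.vw):

from pyspark.sql import SparkSession
from mmlspark.vw import VowpalWabbitFeaturizer  # synapse.ml.vw in later releases

spark = SparkSession.builder.getOrCreate()

# Toy two-class data: numeric feature columns plus an integer label in {1, 2}.
df = spark.createDataFrame(
    [(1.0, 0.5, 1), (0.2, 3.0, 2), (4.0, 1.5, 1)],
    ["f1", "f2", "label"],
)

# Hash the feature columns into a single sparse "features" vector column.
featurizer = VowpalWabbitFeaturizer(inputCols=["f1", "f2"], outputCol="features")
train_df = featurizer.transform(df)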

Expected behavior

val predictions = vwClassifier.fit(trainDF).transform(testDF)
predictions.show

would show my testDF with predictedLabel column containing predictions.

Info (please complete the following information):

  • MMLSpark Version: 1.0.0-rc1
  • Spark Version: 2.4.3
  • Spark Platform: AWS EMR 5.26.0 (Zeppelin 0.8.1)

Stacktrace

org.apache.spark.SparkException: Job aborted due to stage failure: Task 95 in stage 46.0 failed 4 times, most recent failure: Lost task 95.3 in stage 46.0 (TID 2609, ip-10-5-29-73.ec2.internal, executor 7): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
	at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
	at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
	at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
	at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
	... 21 more

Driver stacktrace:
  at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2041)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2029)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2028)
  at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
  at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2028)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
  at scala.Option.foreach(Option.scala:257)
  at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2262)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2211)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2200)
  at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
  at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
  at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:401)
  at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
  at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
  at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
  at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
  at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
  at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
  at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
  at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
  at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
  at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
  at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:745)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:704)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:713)
  ... 51 elided
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:291)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:283)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
  at org.apache.spark.scheduler.Task.run(Task.scala:121)
  at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
  ... 3 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)
  at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:61)
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:49)
  ... 21 more

To me, the root cause looks like this part:

Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
	at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:84)

Could it be that --oaa outputs integers instead of the doubles expected by MMLSpark?
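This hypothesis is easy to check outside Spark with the Vowpal Wabbit Python bindings (a sketch against the 8.x-era pyvw API; exact module layout differs across VW versions): a default model predicts a float, while an --oaa model predicts an integer class label, which would explain the failed cast to ScalarPrediction.

from vowpalwabbit import pyvw  # pip install vowpalwabbit

# Default (scalar) model: predictions come back as floats.
scalar_model = pyvw.vw("--quiet")
scalar_model.learn("1 | a:1 b:2")
print(type(scalar_model.predict("| a:1 b:2")))  # <class 'float'>

# One-against-all model: predictions come back as integer class labels.
oaa_model = pyvw.vw("--oaa 2 --quiet")
oaa_model.learn("1 | a:1 b:2")
oaa_model.learn("2 | a:3 b:4")
print(type(oaa_model.predict("| a:1 b:2")))  # <class 'int'>, the cast fails here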

Additional context

This works fine in my setup on the same dataset with the same VowpalWabbitFeaturizer (although I have to convert the labels to {1, 0}):

val vwClassifier = new VowpalWabbitClassifier()
        .setFeaturesCol("features")
        .setLabelCol("label")
        .setProbabilityCol("predictedProb")
        .setPredictionCol("predictedLabel")
        .setRawPredictionCol("rawPrediction")
        .setArgs("--loss_function=logistic --link=logistic --quiet --holdout_off")

AB#1166568

patkovskyi · Feb 17 '20

Hey, why am I still getting this issue?

Ibtastic · May 11 '21

@Ibtastic what version of mmlspark are you using? Can you open a new issue with the stack trace, or reopen this one? It is really hard to track already-closed issues (I often miss them). Also adding @eisber.

imatiach-msft · May 11 '21

@imatiach-msft I do not want to open another issue because the error is the same one mentioned by this thread's author, and so is the stack trace.

  • mmlspark version - 1.0.0-rc3
  • spark - 2.4.5
  • I am using spark on a Databricks Cluster

Ibtastic · May 11 '21

@Ibtastic are you sure you are seeing the exact same stack trace, with the same line numbers as above? I would expect the line numbers to change, since the code has changed. The issue was supposedly fixed with this PR: https://github.com/Azure/mmlspark/pull/817/files. That PR should have been in the rc3 release: rc3 was released in October 2020 (https://github.com/Azure/mmlspark/releases) and the PR was merged in March 2020. Let me reopen this issue; maybe @eisber has more info, since he wrote the VW wrappers.

imatiach-msft · May 11 '21

I will paste the stack trace in a while. Also, I tried this with Spark 3.1.1 and got a java.lang.NoClassDefFoundError. Is VowpalWabbit not implemented for this Spark version?

Ibtastic · May 11 '21

@Ibtastic Spark 3.1 is only supported on the latest master; you can use any master build.

Please try this walkthrough with pictures on Databricks: https://docs.microsoft.com/en-us/azure/cognitive-services/big-data/getting-started#azure-databricks. For Spark 2.4.5 you can use the rc1 to rc3 releases. For the latest Spark (3.0+) you will need to use a build from master:


For example:

Maven Coordinates: com.microsoft.ml.spark:mmlspark_2.12:1.0.0-rc3-80-b704515f-SNAPSHOT
Maven Resolver: https://mmlspark.azureedge.net/maven
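For instance, one way to pass these coordinates when building a session yourself in PySpark (a sketch; on Databricks you would normally install the library through the cluster UI instead, and the coordinate should match whichever build you picked):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Pull the MMLSpark jar and its dependencies at session startup.
    .config("spark.jars.packages",
            "com.microsoft.ml.spark:mmlspark_2.12:1.0.0-rc3-80-b704515f-SNAPSHOT")
    # The snapshot builds are hosted on the MMLSpark maven resolver, not Maven Central.
    .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
    .getOrCreate()
)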

imatiach-msft · May 11 '21

@imatiach-msft Thanks for pointing that out! Here's the stack trace for Spark 2.4.5:

org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2362)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2350)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2349)
  at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
  at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2349)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
  at scala.Option.foreach(Option.scala:257)
  at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1102)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2582)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2529)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2517)
  at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
  at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:897)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2282)
  at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:270)
  at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:280)
  at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:80)
  at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:86)
  at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:508)
  at org.apache.spark.sql.execution.CollectLimitExec.executeCollectResult(limit.scala:57)
  at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectResult(Dataset.scala:2890)
  at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3508)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
  at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3492)
  at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3487)
  at org.apache.spark.sql.execution.SQLExecution$$anonfun$withCustomExecutionEnv$1.apply(SQLExecution.scala:113)
  at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:243)
  at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:99)
  at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:173)
  at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withAction(Dataset.scala:3487)
  at org.apache.spark.sql.Dataset.head(Dataset.scala:2619)
  at org.apache.spark.sql.Dataset.take(Dataset.scala:2833)
  at org.apache.spark.sql.Dataset.getRows(Dataset.scala:266)
  at org.apache.spark.sql.Dataset.showString(Dataset.scala:303)
  at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.lang.reflect.Method.invoke(Method.java:498)
  at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
  at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
  at py4j.Gateway.invoke(Gateway.java:295)
  at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
  at py4j.commands.CallCommand.execute(CallCommand.java:79)
  at py4j.GatewayConnection.run(GatewayConnection.java:251)
  at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$2: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:640)
  at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
  at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
  at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
  at org.apache.spark.scheduler.Task.run(Task.scala:113)
  at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:537)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:543)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  ... 1 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$class.predictInternal(VowpalWabbitBaseModel.scala:98)
  at com.microsoft.ml.spark.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:62)
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:54)
  at com.microsoft.ml.spark.vw.VowpalWabbitBaseModel$$anonfun$2.apply(VowpalWabbitBaseModel.scala:54)
  ... 15 more

Ibtastic · May 11 '21

Adding @jackgerrits. Any idea?

eisber · May 11 '21

Have there been any updates on this issue? I'm seeing the same error. If this has been resolved, would it be possible for someone to provide a simple working example?

  • PySpark Version: 3.1.2
  • Spark config:
    spark.jars.repositories: "https://mmlspark.azureedge.net/maven"
    spark.jars.packages: "com.microsoft.azure:synapseml_2.12:0.9.5-13-d1b51517-SNAPSHOT"
    spark.jars.excludes: "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalatest:scalatest_2.12"
  • Platform: EMR 6.5.0

Stack Trace:

pred = model.transform(test_df_feat)

only testing
Num weight bits = 19
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1

In [13]: pred.show(10)
22/05/26 21:47:16 WARN TaskSetManager: Lost task 0.0 in stage 4.0 (TID 440) (ip-10-68-16-106.ec2.internal executor 1): org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_subExpr_1$(Unknown Source)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:35)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:907)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:359)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
  at org.apache.spark.scheduler.Task.run(Task.scala:131)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal(VowpalWabbitBaseModel.scala:103)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal$(VowpalWabbitBaseModel.scala:96)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:65)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2(VowpalWabbitBaseModel.scala:56)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2$adapted(VowpalWabbitBaseModel.scala:56)
  ... 18 more

22/05/26 21:47:19 ERROR TaskSetManager: Task 0 in stage 4.0 failed 4 times; aborting job

Py4JJavaError                             Traceback (most recent call last)
in
----> 1 pred.show(10)

~/.venv/default/lib64/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    482         """
    483         if isinstance(truncate, bool) and truncate:
--> 484             print(self._jdf.showString(n, 20, vertical))
    485         else:
    486             print(self._jdf.showString(n, int(truncate), vertical))

~/.venv/default/lib64/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306
   1307         for temp_arg in temp_args:

~/.venv/default/lib64/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
    109     def deco(*a, **kw):
    110         try:
--> 111             return f(*a, **kw)
    112         except py4j.protocol.Py4JJavaError as e:
    113             converted = convert_exception(e.java_exception)

~/.venv/default/lib64/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o398.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 4 times, most recent failure: Lost task 0.3 in stage 4.0 (TID 443) (ip-10-68-16-106.ec2.internal executor 2): org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_subExpr_1$(Unknown Source)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:35)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:907)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:359)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
  at org.apache.spark.scheduler.Task.run(Task.scala:131)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal(VowpalWabbitBaseModel.scala:103)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal$(VowpalWabbitBaseModel.scala:96)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:65)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2(VowpalWabbitBaseModel.scala:56)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2$adapted(VowpalWabbitBaseModel.scala:56)
  ... 18 more

Driver stacktrace:
  at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2470)
  at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2419)
  at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2418)
  at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
  at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
  at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2418)
  at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1125)
  at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1125)
  at scala.Option.foreach(Option.scala:407)
  at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1125)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2684)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2626)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2615)
  at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
  at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:914)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2241)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2262)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2281)
  at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:494)
  at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:447)
  at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
  at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3760)
  at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2763)
  at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3751)
  at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:107)
  at org.apache.spark.sql.execution.SQLExecution$.withTracker(SQLExecution.scala:232)
  at org.apache.spark.sql.execution.SQLExecution$.executeQuery$1(SQLExecution.scala:110)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:135)
  at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:107)
  at org.apache.spark.sql.execution.SQLExecution$.withTracker(SQLExecution.scala:232)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:135)
  at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:253)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:134)
  at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
  at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:68)
  at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3749)
  at org.apache.spark.sql.Dataset.head(Dataset.scala:2763)
  at org.apache.spark.sql.Dataset.take(Dataset.scala:2970)
  at org.apache.spark.sql.Dataset.getRows(Dataset.scala:303)
  at org.apache.spark.sql.Dataset.showString(Dataset.scala:340)
  at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.lang.reflect.Method.invoke(Method.java:498)
  at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
  at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
  at py4j.Gateway.invoke(Gateway.java:282)
  at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
  at py4j.commands.CallCommand.execute(CallCommand.java:79)
  at py4j.GatewayConnection.run(GatewayConnection.java:238)
  at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(VowpalWabbitBaseModel$$Lambda$3291/1845615700: (struct<features:struct<type:tinyint,size:int,indices:array<int>,values:array<double>>>) => double)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_subExpr_1$(Unknown Source)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:35)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:907)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:359)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
  at org.apache.spark.scheduler.Task.run(Task.scala:131)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  ... 1 more
Caused by: java.lang.ClassCastException: java.lang.Integer cannot be cast to org.vowpalwabbit.spark.prediction.ScalarPrediction
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal(VowpalWabbitBaseModel.scala:103)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.predictInternal$(VowpalWabbitBaseModel.scala:96)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitClassificationModel.predictInternal(VowpalWabbitClassifier.scala:65)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2(VowpalWabbitBaseModel.scala:56)
  at com.microsoft.azure.synapse.ml.vw.VowpalWabbitBaseModel.$anonfun$transformImplInternal$2$adapted(VowpalWabbitBaseModel.scala:56)
  ... 18 more

bgereke · May 26 '22

It's not implemented yet: https://github.com/microsoft/SynapseML/blob/9d16166314ef42767e1c50b9c831c163050c15d3/vw/src/main/scala/com/microsoft/azure/synapse/ml/vw/VowpalWabbitClassifier.scala#L64

Let me follow up with Jack. Is there a minimal dataset we can use to repro?

eisber · May 27 '22

Modified from the adult census example:

import pyspark.sql.functions as f
from synapse.ml.vw import VowpalWabbitFeaturizer, VowpalWabbitClassifier

df = spark.read.format("csv")\
               .option("header", True)\
               .option("inferSchema", True)\
               .load("path/to/adult.csv")  # download from kaggle

data = df.toDF(*(c.replace('.', '-') for c in df.columns))\
         .select(["education", "marital-status", "hours-per-week", "income"])

# create binary label
data = data.withColumn("label", f.when(f.col("income").contains("<"), 0.0).otherwise(1.0)).repartition(1)

# create random multiclass label
num_classes = 5
data = data.withColumn("random_label", f.round(f.rand()*(num_classes-1), 0))

vw_featurizer = VowpalWabbitFeaturizer(inputCols=["education", "marital-status", "hours-per-week"],
                                       outputCol="features")
data = vw_featurizer.transform(data)

# fit binary classifier
binary_args = "--loss_function=logistic --quiet --holdout_off"
binary_model = VowpalWabbitClassifier(featuresCol="features",
                                      labelCol="label",
                                      args=binary_args,
                                      numPasses=10)
binary_model.fit(data).transform(data).show(10, False)  # works like a charm

# fit multiclass classifier
multi_args = f"--loss_function=logistic --quiet --holdout_off --oaa={num_classes}"
multi_model = VowpalWabbitClassifier(featuresCol="features",
                                     labelCol="random_label",
                                     args=multi_args,
                                     numPasses=10)
multi_model.fit(data).transform(data).show(10, False)  # gives the stack trace
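Until --oaa is supported by the wrapper, one possible interim workaround is to emulate it with one binary VowpalWabbitClassifier per class and take the argmax of the per-class probabilities. A sketch building on the repro above; it assumes Spark 3.0+ for pyspark.ml.functions.vector_to_array, and the ovr_* names are illustrative, not part of the SynapseML API:

from pyspark.ml.functions import vector_to_array

# One binary model per class: class k vs. the rest.
ovr_models = []
for k in range(num_classes):
    binary_data = data.withColumn("ovr_label",
                                  (f.col("random_label") == k).cast("double"))
    clf = VowpalWabbitClassifier(featuresCol="features",
                                 labelCol="ovr_label",
                                 probabilityCol=f"prob_{k}",
                                 rawPredictionCol=f"raw_{k}",
                                 predictionCol=f"pred_{k}",
                                 args="--loss_function=logistic --quiet --holdout_off",
                                 numPasses=10)
    ovr_models.append(clf.fit(binary_data))

# Score: apply every model, then pick the class with the highest P(class k).
scored = data
for m in ovr_models:
    scored = m.transform(scored)

# array_max over (probability, class) structs returns the struct with the
# largest probability; its "c" field is the predicted class.
pairs = f.array(*[f.struct(vector_to_array(f.col(f"prob_{k}"))[1].alias("p"),
                           f.lit(k).alias("c"))
                  for k in range(num_classes)])
scored = scored.withColumn("ovr_prediction", f.array_max(pairs)["c"])
scored.select("random_label", "ovr_prediction").show(10, False)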

bgereke · May 27 '22

Waiting for the VW PR to be approved; then we can look into merging this for SynapseML.

eisber · May 30 '22

Thanks for the effort on the PR! I look forward to test-driving it once it gets merged :)

bgereke · Jun 01 '22