# [BUG] Distributed training of LightGBM breaks when the number of workers decreases during training
## SynapseML version

0.11.1
## System information

- Language version: Python 3.8.10
- Spark version: 12.2 LTS ML (includes Apache Spark 3.3.2, Scala 2.12)
- Spark platform: Databricks
- Cluster configuration: Driver Standard_E4ds_v4 · Workers Standard_E4ds_v4 · 1-3 workers
## Describe the problem

The training pipeline broke with the following error while training a LightGBMClassifier on the compute cluster.

This bug only occurs when the number of workers decreases during training. I am using the Databricks platform, and the problem only appears when autoscaling or spot-instance workers are enabled (either of which can remove workers mid-training). With autoscaling and spot instances disabled, it has never happened.
## Code to reproduce issue

```python
# copied from https://microsoft.github.io/SynapseML/docs/Explore%20Algorithms/LightGBM/Quickstart%20-%20Classification,%20Ranking,%20and%20Regression/#bankruptcy-prediction-with-lightgbm-classifier
from synapse.ml.core.platform import *

df = (
    spark.read.format("csv")
    .option("header", True)
    .option("inferSchema", True)
    .load(
        "wasbs://[email protected]/company_bankruptcy_prediction_data.csv"
    )
)
train, test = df.randomSplit([0.85, 0.15], seed=1)

from pyspark.ml.feature import VectorAssembler

feature_cols = df.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
train_data = featurizer.transform(train)["Bankrupt?", "features"]
test_data = featurizer.transform(test)["Bankrupt?", "features"]

from synapse.ml.lightgbm import LightGBMClassifier

model = LightGBMClassifier(
    objective="binary", featuresCol="features", labelCol="Bankrupt?", isUnbalance=True
)

# somehow force workers to be evicted during training
model = model.fit(train_data)
```
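To reproduce without waiting for the autoscaler, resizing the cluster (via the UI or REST API) while `fit()` is running triggers the same failure. As a blunt mitigation, the fit can be retried when the barrier stage is aborted by a lost worker; a minimal sketch (the retry count and error matching are my own assumptions, not SynapseML API, and each retry restarts training from scratch):

```python
from py4j.protocol import Py4JJavaError

def fit_with_retry(estimator, data, max_attempts=3):
    """Retry fit() when the barrier stage is aborted by a lost worker."""
    for attempt in range(1, max_attempts + 1):
        try:
            return estimator.fit(data)
        except Py4JJavaError as e:
            # Only retry the specific barrier-stage failure seen in the logs.
            if attempt == max_attempts or "barrier" not in str(e.java_exception):
                raise
            print(f"Attempt {attempt} failed on a barrier stage; retrying...")

model = fit_with_retry(
    LightGBMClassifier(
        objective="binary", featuresCol="features", labelCol="Bankrupt?", isUnbalance=True
    ),
    train_data,
)
```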
## Other info / logs

```
org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(644, 3) finished unsuccessfully.
ExecutorLostFailure (executor 2 exited unrelated to the running tasks) Reason: Executor decommission: worker decommissioned because of kill request from HTTP endpoint (data migration disabled)
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3386)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3318)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3309)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3309)
at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2799)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3592)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3536)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3524)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1182)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1170)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2750)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1070)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:445)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1068)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.executePartitionTasks(LightGBMBase.scala:621)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.executeTraining(LightGBMBase.scala:598)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.trainOneDataBatch(LightGBMBase.scala:446)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.$anonfun$train$2(LightGBMBase.scala:62)
at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb(SynapseMLLogging.scala:93)
at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb$(SynapseMLLogging.scala:90)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier.logVerb(LightGBMClassifier.scala:27)
at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logTrain(SynapseMLLogging.scala:84)
at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logTrain$(SynapseMLLogging.scala:83)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier.logTrain(LightGBMClassifier.scala:27)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.train(LightGBMBase.scala:64)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.train$(LightGBMBase.scala:36)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier.train(LightGBMClassifier.scala:27)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier.train(LightGBMClassifier.scala:27)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:306)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)
at py4j.ClientServerConnection.run(ClientServerConnection.java:115)
at java.lang.Thread.run(Thread.java:750)
```
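The trace shows the job dying inside a barrier ResultStage ("Could not recover from a failed barrier ResultStage"). Barrier stages require all tasks to be scheduled and to succeed together, so losing an executor mid-stage aborts the whole training job instead of re-running just the lost partition. LightGBMClassifier exposes knobs that affect task placement; a hedged sketch of pinning the task count to the autoscaler's minimum worker count (`numTasks` and `useBarrierExecutionMode` are real SynapseML parameters, but whether this combination actually survives a decommission is untested here):

```python
from synapse.ml.lightgbm import LightGBMClassifier

# Pin the number of training tasks to the minimum of the 1-3 autoscale
# range, so a scale-down still leaves enough executors to host every
# task, and opt out of barrier execution mode. Effectiveness against
# mid-training decommission is an assumption, not a verified fix.
model = LightGBMClassifier(
    objective="binary",
    featuresCol="features",
    labelCol="Bankrupt?",
    isUnbalance=True,
    numTasks=1,
    useBarrierExecutionMode=False,
)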
## What component(s) does this bug affect?

- [ ] area/cognitive: Cognitive project
- [ ] area/core: Core project
- [ ] area/deep-learning: DeepLearning project
- [X] area/lightgbm: Lightgbm project
- [ ] area/opencv: Opencv project
- [ ] area/vw: VW project
- [ ] area/website: Website
- [ ] area/build: Project build system
- [ ] area/notebooks: Samples under notebooks folder
- [ ] area/docker: Docker usage
- [ ] area/models: models related issue
## What language(s) does this bug affect?

- [ ] language/scala: Scala source code
- [X] language/python: Pyspark APIs
- [ ] language/r: R APIs
- [ ] language/csharp: .NET APIs
- [ ] language/new: Proposals for new client languages
## What integration(s) does this bug affect?

- [ ] integrations/synapse: Azure Synapse integrations
- [ ] integrations/azureml: Azure ML integrations
- [X] integrations/databricks: Databricks integrations
Hey @surfii3z :wave:! Thank you so much for reporting the issue/feature request :rotating_light:. Someone from the SynapseML team will triage this issue soon. We appreciate your patience.