SynapseML
[BUG] LightGBMRanker: `groupCol` not recognized - LightGBM sees all records in the DataFrame as part of 1 query/group
SynapseML version
com.microsoft.azure:synapseml_2.12:1.0.5
System information
Language version: Python 3.12.2, Scala 2.12
Spark Version: 3.5.2
Spark Platform: Local (MacBook Pro M2, 12 cores, 18 GB RAM)
Describe the problem
I am encountering an issue with LightGBMRanker: the model does not seem to recognize that the PySpark DataFrame I am using for training is composed of many queries/groups.
In the native version of LightGBM, there is a parameter called `group` that takes an array-like sequence giving the number of samples per query/group, for example [10, 20, 30], where the sum of the array equals the total number of samples. In my case, there are 31,674 records in my dataset.
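To make the `group` convention concrete, here is a minimal sketch in plain Python of how a per-row query id column maps to the array of group sizes described above. The helper name `group_sizes` and the toy ids are hypothetical, and the sketch assumes the rows of each query are already contiguous; it does not show what SynapseML does internally.

```python
from itertools import groupby


def group_sizes(query_ids):
    """Return the length of each run of consecutive, equal query ids.

    Assumes rows belonging to one query are stored next to each other.
    """
    return [sum(1 for _ in run) for _, run in groupby(query_ids)]


# Hypothetical toy data: three queries with 2, 3, and 1 rows.
query_ids = ["q1", "q1", "q2", "q2", "q2", "q3"]
sizes = group_sizes(query_ids)

# The group sizes must add up to the total number of rows.
assert sum(sizes) == len(query_ids)
print(sizes)  # [2, 3, 1]
```

Note that if rows of the same query are not contiguous, this run-length approach yields one group per run, so the data would need to be sorted by query id first.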
I am wondering how SynapseML does this under the hood, given that one only specifies `groupCol` and nothing else.
As shown in the error log, it seems that LightGBM was not given any information about how the records are grouped, and it therefore complains that all observations belong to the same query.
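As a sanity check on the data itself, the following sketch counts rows per query id and flags any query that exceeds the 10,000-row limit quoted in the LightGBM error message in this report. The function name `oversized_queries` and the toy ids are hypothetical; on a real DataFrame the same count would be done with a group-by on `query_id`.

```python
from collections import Counter

# Per-query row limit, taken from the LightGBM error message in this report.
MAX_ROWS_PER_QUERY = 10_000


def oversized_queries(query_ids, limit=MAX_ROWS_PER_QUERY):
    """Map each query id with more than `limit` rows to its row count."""
    counts = Counter(query_ids)
    return {query: n for query, n in counts.items() if n > limit}


# Hypothetical toy data: query "a" is small, query "b" is over the limit.
query_ids = ["a"] * 3 + ["b"] * 12_000
print(oversized_queries(query_ids))  # {'b': 12000}
```

If this returns an empty dict for the real `query_id` column, no single query in the data is over the limit, which would suggest the grouping information is being lost before it reaches LightGBM.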
Code to reproduce issue
from synapse.ml.lightgbm import LightGBMRanker
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import functions as F
# `train` contains `query_id`, which indicates how each record is grouped
train = spark.read.parquet("my_ranking_dataset.parquet")
vec_assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="keep")
train_with_vec = vec_assembler.transform(train)
train_with_vec = train_with_vec.withColumn("labels", (30 * F.col('relevance')).astype('int'))
features_col = "features"
query_col = "query_id"
label_col = "labels"
lgbm_ranker = LightGBMRanker(
    labelCol=label_col,
    featuresCol=features_col,
    groupCol=query_col,  # As shown here, I indicated `query_id` as the groupCol parameter value.
    predictionCol="preds",
    leafPredictionCol="leafPreds",
    featuresShapCol="importances",
    repartitionByGroupingColumn=True,
    numLeaves=32,
    numIterations=200,
    evalAt=[1, 3, 5],
    metric="ndcg",
    useBarrierExecutionMode=True,
    verbosity=4,
)
lgbm_ranker.fit(
    train_with_vec
    .join(
        train_with_vec
        .select('query_id')
        .distinct()
        .sample(0.001),
        on='query_id',
        how='inner',
    )
)
Other info / logs
[LightGBM] [Info] Saving data reference to binary buffer
[Stage 71:> (0 + 8) / 8]
[LightGBM] [Info] Loaded reference dataset: 129 features, 31674 num_data
[LightGBM] [Fatal] Number of rows 31674 exceeds upper limit of 10000 for a query
24/09/20 16:53:52 WARN StreamingPartitionTask: LightGBM reached early termination on one task, stopping training on task. This message should rarely occur. Inner exception: java.lang.Exception: Booster call failed in LightGBM with error: Number of rows 31674 exceeds upper limit of 10000 for a query
[LightGBM] [Warning] Unknown parameter: max_position
[LightGBM] [Warning] Unknown parameter: max_position
[LightGBM] [Fatal] Number of rows 31674 exceeds upper limit of 10000 for a query
[LightGBM] [Warning] Unknown parameter: max_position
[LightGBM] [Fatal] Number of rows 31674 exceeds upper limit of 10000 for a query
24/09/20 16:53:53 ERROR Executor: Exception in task 0.0 in stage 71.0 (TID 2052)
java.lang.Exception: Booster call failed in LightGBM with error: Number of rows 31674 exceeds upper limit of 10000 for a query
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils$.validate(LightGBMUtils.scala:18)
at com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster.boosterHandler$lzycompute(LightGBMBooster.scala:242)
at com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster.boosterHandler(LightGBMBooster.scala:232)
at com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster.freeNativeMemory(LightGBMBooster.scala:493)
at com.microsoft.azure.synapse.ml.lightgbm.BasePartitionTask.finalizeDatasetAndTrain(BasePartitionTask.scala:263)
at com.microsoft.azure.synapse.ml.lightgbm.BasePartitionTask.mapPartitionTask(BasePartitionTask.scala:152)
at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.$anonfun$executePartitionTasks$1(LightGBMBase.scala:615)
at org.apache.spark.rdd.RDDBarrier.$anonfun$mapPartitions$2(RDDBarrier.scala:51)
at org.apache.spark.rdd.RDDBarrier.$anonfun$mapPartitions$2$adapted(RDDBarrier.scala:51)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
at org.apache.spark.scheduler.Task.run(Task.scala:141)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
...
What component(s) does this bug affect?
- [ ] area/cognitive: Cognitive project
- [ ] area/core: Core project
- [ ] area/deep-learning: DeepLearning project
- [X] area/lightgbm: Lightgbm project
- [ ] area/opencv: Opencv project
- [ ] area/vw: VW project
- [ ] area/website: Website
- [ ] area/build: Project build system
- [ ] area/notebooks: Samples under notebooks folder
- [ ] area/docker: Docker usage
- [ ] area/models: models related issue
What language(s) does this bug affect?
- [ ] language/scala: Scala source code
- [X] language/python: Pyspark APIs
- [ ] language/r: R APIs
- [ ] language/csharp: .NET APIs
- [ ] language/new: Proposals for new client languages
What integration(s) does this bug affect?
- [ ] integrations/synapse: Azure Synapse integrations
- [ ] integrations/azureml: Azure ML integrations
- [ ] integrations/databricks: Databricks integrations