SynapseML
Classifier ... doesn't extend from HasRawPredictionCol.
Describe the bug
Not sure if it's a bug or my own error. Using pyspark.ml's OneVsRest with a LightGBM binary classifier, I get the following error:
AssertionError: Classifier <class 'mmlspark.lightgbm.LightGBMClassifier.LightGBMClassifier'> doesn't extend from HasRawPredictionCol.
To Reproduce
from pyspark.ml.classification import OneVsRest
from mmlspark.lightgbm import LightGBMClassifier

# feat_set, lbl, and train (the features column name, label column name, and
# training DataFrame) are defined earlier in the script.
lgb = LightGBMClassifier(featuresCol=feat_set, labelCol=lbl, objective='binary')
ovr_lgb = OneVsRest(classifier=lgb, featuresCol=feat_set, labelCol=lbl, predictionCol='prediction')
mod_ovr_lgb = ovr_lgb.fit(train)
Expected behavior
I expected it to fit multiple binary classifiers using a one-vs-rest strategy.
Info (please complete the following information):
- MMLSpark Version: 2.11:1.0.0-rc3
- Spark Version: 2.4.3
- Spark Platform: AWS EMR
**Stacktrace**
lgb = LightGBMClassifier(featuresCol = feat_set, labelCol=lbl, objective='binary')...
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
~/codebase/pnp_datascience_segmentation_lifestyle/03_modeling/01_initials_rolling_n/03_models_initials_cdt5.py in <module>
439 lgb = LightGBMClassifier(featuresCol = feat_set, labelCol=lbl, objective='binary')
440 ovr_lgb = OneVsRest(classifier=lgb_sid, featuresCol=feat_set, labelCol=lbl, predictionCol='prediction')
----> 441 mod_ovr_lgb = ovr_lgb.fit(train)
/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py in fit(self, dataset, params)
130 return self.copy(params)._fit(dataset)
131 else:
--> 132 return self._fit(dataset)
133 else:
134 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/classification.py in _fit(self, dataset)
1800 classifier = self.getClassifier()
1801 assert isinstance(classifier, HasRawPredictionCol),\
-> 1802 "Classifier %s doesn't extend from HasRawPredictionCol." % type(classifier)
1803
1804 numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
AssertionError: Classifier <class 'mmlspark.lightgbm.LightGBMClassifier.LightGBMClassifier'> doesn't extend from HasRawPredictionCol.
Found a reference to the Spark GBT throwing the same error because it's implemented as an estimator; perhaps that's the same here.
@philmassie it does extend ProbabilisticClassifier in the Scala code: https://github.com/Azure/mmlspark/blob/master/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMClassifier.scala#L27 which extends HasRawPredictionCol: https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/ProbabilisticClassifier.html
I think the problem is that the pyspark wrapper (which calls the Scala code) doesn't extend it. The pyspark wrapper does extend Estimator, but I guess it doesn't extend the pyspark equivalent of ProbabilisticClassifier, which needs to be fixed.
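In the meantime, one possible user-side workaround (an untested sketch, not something we've verified) might be to mix pyspark's missing shared trait into a thin subclass so that OneVsRest's isinstance assertion passes:

```python
from pyspark.ml.classification import OneVsRest
from pyspark.ml.param.shared import HasRawPredictionCol
from mmlspark.lightgbm import LightGBMClassifier

# Untested sketch: advertise rawPredictionCol to pyspark by mixing in the
# shared trait; the underlying Scala estimator already has this param, so
# the subclass adds no new behavior of its own.
class LightGBMClassifierWithRaw(LightGBMClassifier, HasRawPredictionCol):
    pass

lgb = LightGBMClassifierWithRaw(featuresCol=feat_set, labelCol=lbl, objective='binary')
ovr_lgb = OneVsRest(classifier=lgb, featuresCol=feat_set, labelCol=lbl, predictionCol='prediction')
```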
In any case, you can use LightGBMClassifier directly for multiclass data; you just need to change the objective to multiclass or multiclassova:
https://github.com/microsoft/LightGBM/blob/master/docs/Parameters.rst#objective
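For example, a minimal sketch reusing the feat_set, lbl, and train names from the repro above, plus a hypothetical held-out test DataFrame:

```python
from mmlspark.lightgbm import LightGBMClassifier

# Train LightGBM directly on the multiclass label instead of wrapping it in
# OneVsRest; feat_set, lbl, and train are the same names as in the repro above.
lgb_multi = LightGBMClassifier(featuresCol=feat_set, labelCol=lbl, objective='multiclass')
mod_multi = lgb_multi.fit(train)
preds = mod_multi.transform(test)  # 'test' is a hypothetical held-out DataFrame
```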
The pyspark wrapper is auto-generated, so maybe this is something that needs to be fixed in the autogen code... some of the autogen'ed wrapper is overloaded here: https://github.com/Azure/mmlspark/blob/master/src/main/python/mmlspark/lightgbm/LightGBMClassificationModel.py but that's only related to the model.
Thanks again @imatiach-msft. Yes, I was using the multiclass function of LightGBM before and it's amazing. My reason for wondering about the OneVsRest approach was that, moving from rc1 to rc3, I was getting very different models. I wondered if it was some default that had changed, but nevertheless I was scrambling to try different approaches. I still don't understand the difference in results, and I'll try to replicate it sometime on a public data set, since I don't reckon my employer would be happy with me if I shared the training data here :) When I get to that I'll open another issue.
Thanks for the explanation about the extends; to be honest my Scala is pretty weak, so it's hard to understand the implications of the extends bit, but I'll get there eventually. Thanks again to the whole team for a marvelous library.
@imatiach-msft must I close this?
@philmassie no, please keep it open. It seems like this is indeed an issue that needs to be fixed in the auto-generated pyspark wrapper.