
Logistic Discriminator Cross Val Predict Outputs Classes (Not Probabilities)

Open kharrigian opened this issue 4 years ago • 2 comments

The cross_val_predict call used for domain discrimination in the instance weighting class outputs classes (not probabilities) by default. It appears the intended behavior requires passing method="predict_proba" to cross_val_predict. With this argument the output is a 2-D array of class probabilities, one column per class, so it needs column indexing to extract the appropriate probabilities.
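To make the difference concrete, here is a minimal sketch on toy data. The arrays and the 0/1 "source"/"target" labels are made up for illustration and are not taken from libTLDA; only the two cross_val_predict calls matter.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
# Toy data (assumption): 40 "source" rows labelled 0, 40 shifted "target" rows labelled 1.
XZ = np.vstack([rng.normal(0.0, 1.0, (40, 2)), rng.normal(1.0, 1.0, (40, 2))])
y = np.concatenate([np.zeros(40), np.ones(40)])

lr = LogisticRegression()

# Default method="predict": one hard class label per row.
labels = cross_val_predict(lr, XZ, y, cv=5)

# method="predict_proba": one row of class probabilities per sample.
proba = cross_val_predict(lr, XZ, y, cv=5, method="predict_proba")

# Columns follow the sorted class labels, so column 1 is P(label == 1).
p_target = proba[:, 1]

print(labels.shape, proba.shape, p_target.shape)  # (80,) (80, 2) (80,)
```

The default call only ever returns values from the label set, while the predict_proba call returns rows that sum to 1, which is why the extra column index is needed.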

I'm not sure if this was always the case with sklearn or if it was updated after the original domain discrimination implementation. Either way, this additional argument should rectify the issue.

kharrigian, Jul 24 '20 16:07

@wmkouw Are you accepting pull requests? If so, I think this could be a quick change from:

https://github.com/wmkouw/libTLDA/blob/0c66ec2327d191b88fca0803a7c74b0bb05afd42/libtlda/iw.py#L256

To: preds = cross_val_predict(lr, XZ, y[:, 0], cv=5, method="predict_proba")[:,1]
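For context, a rough sketch of how that line could sit inside a discriminator routine. This is not the code from iw.py: the function name target_probabilities, the label construction, the default LogisticRegression settings, and the final slice to the source rows are all assumptions made so the example runs on its own. Only the cross_val_predict line is the change proposed above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict


def target_probabilities(X, Z, cv=5):
    """Hypothetical helper: cross-validated P(target) for each source row."""
    N, M = X.shape[0], Z.shape[0]
    # Assumed labelling: 0 for source rows, 1 for target rows, stored as a
    # column vector so that y[:, 0] matches the indexing in the proposed line.
    y = np.concatenate((np.zeros((N, 1)), np.ones((M, 1))), axis=0)
    XZ = np.concatenate((X, Z), axis=0)
    lr = LogisticRegression()  # assumed settings, not the library's
    # Proposed change: request probabilities and keep the column for label 1.
    preds = cross_val_predict(lr, XZ, y[:, 0], cv=cv, method="predict_proba")[:, 1]
    # Assumption: only the source-row predictions are wanted downstream.
    return preds[:N]


rng = np.random.default_rng(1)
X = rng.normal(0.0, 1.0, (30, 3))  # toy source samples
Z = rng.normal(0.5, 1.0, (50, 3))  # toy target samples
p = target_probabilities(X, Z)
```

Here p holds one continuous value per source sample instead of a 0/1 label.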

rchew, Jul 13 '21 20:07

Hi @rchew, yeah sure. I had big plans for this toolbox, but my research focus has changed over the last 2-3 years and I don't think I have time for further development.

wmkouw, Aug 02 '21 08:08