libTLDA
Logistic Discriminator Cross Val Predict Outputs Classes (Not Probabilities)
The `cross_val_predict` function used for domain discrimination in the instance weighting class outputs classes (not probabilities) by default. It appears the intended behavior requires passing `method="predict_proba"` to `cross_val_predict`. With this argument the output is a 2D array of class probabilities, so it requires column indexing to extract the appropriate probabilities.
I'm not sure if this was always the case with `sklearn` or if it was updated after the original domain discrimination implementation. Either way, this additional argument should rectify the issue.
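To illustrate the difference, here is a minimal sketch on synthetic data. It is not the libTLDA code: the arrays `X`, `Z`, `XZ`, and `y` are made-up stand-ins for source samples, target samples, their concatenation, and the domain labels, and the sample sizes are arbitrary.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

# Hypothetical data: 50 "source" and 50 "target" samples in 2 dimensions.
rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(50, 2))  # stand-in for source samples
Z = rng.normal(1.0, 1.0, size=(50, 2))  # stand-in for target samples
XZ = np.concatenate((X, Z), axis=0)

# Domain labels: 0 for source, 1 for target.
y = np.concatenate((np.zeros(50, dtype=int), np.ones(50, dtype=int)))

lr = LogisticRegression()

# Default (method="predict"): a 1D array of hard class labels.
labels = cross_val_predict(lr, XZ, y, cv=5)

# With method="predict_proba": a 2D array with one column per class.
probs = cross_val_predict(lr, XZ, y, cv=5, method="predict_proba")

# Column 1 holds the probability of class 1 (the target domain here).
p_target = probs[:, 1]

print(labels.shape, probs.shape, p_target.shape)
```

The columns of the `predict_proba` output follow the sorted class labels, so with labels 0 and 1 the second column is the probability of belonging to class 1.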
@wmkouw Are you accepting pull requests? If so, I think this could be a quick change from:
https://github.com/wmkouw/libTLDA/blob/0c66ec2327d191b88fca0803a7c74b0bb05afd42/libtlda/iw.py#L256
To:
preds = cross_val_predict(lr, XZ, y[:, 0], cv=5, method="predict_proba")[:,1]
Hi @rchew, yeah sure. I had big plans for this toolbox, but I have changed my research focus over the last 2-3 years and I don't think I have time for further development.