DeepSuperLearner
Problem getting accuracy predictions
When I add the code to get an accuracy score, I get an error. The code I added was:
y_pred = DSL_learner.predict(X_train, y_train)
y_pred = numpy.argmax(y_pred,axis=1)
I am using numpy 1.14 and Ubuntu 16.04.4.
See below:
===========================
from deepSuperLearner import *
ETC = ExtraTreesClassifier()
GB = GradientBoostingClassifier()
Base_learners = {'ETC':ETC, 'GB':GB}
np.random.seed(100)
DSL_learner = DeepSuperLearner(Base_learners)
DSL_learner.fit(X_train, y_train)
y_pred = DSL_learner.predict(X_train, y_train)
y_pred = numpy.argmax(y_pred,axis=1)
print('Final prediction accuracy score: [%.4f]' % accuracy_score(y_test, y_pred))
DSL_learner.get_precision_recall(X_test, y_test, show_graphs=True)
Iteration: 0 Loss: 0.6936235359636144 Weights: [0.95464725 0.04535275]
Iteration: 1 Loss: 0.6922511148906573 Weights: [0.96612301 0.03387699]
Iteration: 2 Loss: 0.6930091286990414 Weights: [1. 0.]
ValueError Traceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/deepSuperLearner/deepSuperLearnerLib.py in predict(self, X, return_base_learners_probs)
    239     X = np.hstack((X, avg_probs))
    240
--> 241     if return_base_learners_probs:
    242         return avg_probs, base_learners_probs
    243
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
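Your usage of the predict function is wrong. Look at the line "DSL_learner.predict(X_train, y_train)". predict accepts two parameters:
X: numpy.array of shape [n, l], i.e. X_train
return_base_learners_probs: boolean that indicates whether to also return the base learners' predictions.
Passing y_train as the second positional argument binds it to return_base_learners_probs, so the check "if return_base_learners_probs:" is evaluated on an array, which raises the ValueError above.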
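A minimal sketch of why the positional call fails, assuming only the parameter list visible in the traceback, predict(self, X, return_base_learners_probs). The predict function below is a hypothetical stand-in, not the DeepSuperLearner implementation; its dummy probabilities and its default of False for return_base_learners_probs are assumptions made for illustration.

```python
import numpy as np


# Hypothetical stand-in mirroring only the parameters shown in the traceback.
# It is NOT the DeepSuperLearner code; the default value is an assumption.
def predict(X, return_base_learners_probs=False):
    avg_probs = np.full((len(X), 2), 0.5)  # dummy class probabilities
    if return_base_learners_probs:  # an array here raises the ValueError
        return avg_probs, None  # None is a placeholder for base learner probs
    return avg_probs


X_train = np.zeros((4, 3))
y_train = np.array([0, 1, 1, 0])

try:
    predict(X_train, y_train)  # y_train is bound to return_base_learners_probs
except ValueError as err:
    print("positional call fails:", err)

probs = predict(X_train, return_base_learners_probs=False)  # keyword call works
print(probs.shape)  # prints (4, 2)
```

Passing the flag by keyword makes the intent explicit and keeps label arrays from silently landing in a boolean parameter.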
I get a different error when I change it to:
y_pred = DSL_learner.predict(X_train, return_base_learners_probs=False)
But if you add an accuracy score to the library yourself, I won't need to do any of this.
FYI, here's the new error I get:
Iteration: 0 Loss: 0.6686762701198394 Weights: [0.51519279 0.48480721]
Iteration: 1 Loss: 0.6664984982183696 Weights: [0.36564859 0.63435141]
Iteration: 2 Loss: 0.6652801586203181 Weights: [0.38726519 0.61273481]
Iteration: 3 Loss: 0.6674569423227245 Weights: [0.29607315 0.70392685]
ValueError Traceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/sklearn/metrics/classification.py in accuracy_score(y_true, y_pred, normalize, sample_weight)
    174
    175     # Compute accuracy for each possible representation
--> 176     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    177     if y_type.startswith('multilabel'):
    178         differing_labels = count_nonzero(y_true - y_pred, axis=1)

/usr/local/lib/python3.5/dist-packages/sklearn/metrics/classification.py in _check_targets(y_true, y_pred)
     69     y_pred : array or indicator matrix
     70     """
---> 71     check_consistent_length(y_true, y_pred)
     72     type_true = type_of_target(y_true)
     73     type_pred = type_of_target(y_pred)

/usr/local/lib/python3.5/dist-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    202     if len(uniques) > 1:
    203         raise ValueError("Found input variables with inconsistent numbers of"
--> 204                          " samples: %r" % [int(l) for l in lengths])
    205
    206
ValueError: Found input variables with inconsistent numbers of samples: [3655, 14618]
It seems the error is raised by sklearn's accuracy_score function. Can you add some prints of y_test/y_pred before calling "accuracy_score"? Anyway, if you are patient, I'll add more metrics by the end of this week.
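The sample counts in the error, [3655, 14618], are the lengths of y_true and y_pred in that order: 3655 test labels against 14618 predictions made on X_train. A minimal sketch of the shape check and the corrected evaluation follows, using toy numpy arrays in place of real model output. accuracy_from_probs is a hypothetical helper that applies the same argmax step as the code above and mirrors sklearn's length check. With the real model, the probabilities would presumably come from DSL_learner.predict(X_test, return_base_learners_probs=False), assuming that call returns the averaged class probabilities (avg_probs in the traceback).

```python
import numpy as np


def accuracy_from_probs(y_true, probs):
    """Hypothetical helper: convert an [n, n_classes] probability matrix to
    labels with argmax and return the fraction that match y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.argmax(np.asarray(probs), axis=1)
    # Mirror the length check sklearn.metrics.accuracy_score performs.
    if len(y_true) != len(y_pred):
        raise ValueError(
            "inconsistent numbers of samples: {}".format([len(y_true), len(y_pred)])
        )
    return float(np.mean(y_true == y_pred))


# Toy stand-ins: 3 test labels, probabilities for 5 training rows and 3 test rows.
y_test = np.array([0, 1, 1])
probs_train = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7], [0.5, 0.5]])
probs_test = np.array([[0.8, 0.2], [0.4, 0.6], [0.7, 0.3]])

# Printing shapes before scoring exposes the mismatch immediately.
print("y_test:", y_test.shape, "train probs:", probs_train.shape)

try:
    accuracy_from_probs(y_test, probs_train)  # test labels vs training predictions
except ValueError as err:
    print("mismatch:", err)

# Predictions and labels from the same split line up.
print("accuracy on test rows: %.4f" % accuracy_from_probs(y_test, probs_test))  # prints 0.6667
```

The key design point is that y_pred must be computed from the same split as the labels it is scored against; scoring training predictions against test labels fails the length check whatever the model.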
It is fine to wait until you add the accuracy yourself, but here's the info you asked for:
print("X_train:",X_train[:5,:],"y_train: ",y_train)
X_train: [[-3.75 -5. -5.75 -4.5 -0.47192051 0.525 1.52192051 0.51842254 0.3875 0.25657746 1.45192801 -1.95 -5.35192801 -2.76873311 1.23214286 5.23301882 0.59606955 0.27235409 -0.51371614 0.04349229 0.07717927 2.05 0.10416667 -0.875 3.5 -0.700625 1.389072 0.71343152 0.1188172 -1.82142857 -4.09559285 -0.96875 -1.82142857 0.5236976 -4.96428571 4.46428571 -6.53571429 0. 13.5 0.50234375 -0.50211237 0.66192953 0.67459299 0.66192953 0.67459299 0.50088974 0.5 0.64285714 0.73214286 0.91071429 0.64285714 0.67857143 -0.33333333 0.80031216 0.54092473 0.4877858 0.53178132 0.53178132 0.51649346 0.58356465 0.50063437 0.50092399 0.49971038 0.50088974 0.50153103 0.49935872 0.5011725 0.50119902 0.49997348 0.50095488 0.50067908 0.5002758 0.6000396 0.50464713 0.50105 0.49805 0.60686474 0.50276005 0.50044515 0.49902593 0.50077698 0.50499026 0.99902593 0.50500777 1.00077698 0.7438929 0.77029163 0.81019521 0.87837878 0.70175439 0.81019521 0.50012974 0.76001144 0.05844156 0.3238342 0.50184482 0.50187281 0.50092333 0.50412146 0.50251512 0.50485993 0.48099753 0.49966786 0.50251762 0.50031692] [ 1.75 2.25 1. 4. 2.73669028 1.05 -0.63669028 1.88845581 0.675 -0.53845581 5.56184444 1.5 -2.56184444 2.67524086 0.98214286 -0.71095515 0.55230452 0.2658049 1.39755759 0.10292845 1.63869719 0.96666667 0.19583333 2. 1.125 0. 0.34572115 0.83707183 -0.0827957 1.72321429 2.95745078 0.859375 1.72321429 -0.0920126 3.02678571 -0.88392857 3.67857143 4. 0. 0.50263021 1.5393 0.64302892 0.63021376 0.64302892 0.63021376 0.50062308 0.64285714 1.
0.78571429 0.5 0.85714286 0.21428571 1.28927536 1.14185545 0.62751297 0.77416508 0.5548443 0.5548443 0.76816498 0.5523335 0.50047665 0.50032482 0.50015183 0.50062308 0.50067038 0.49995269 0.50022708 0.50007375 0.50015333 0.50003887 0.49987221 0.50016666 0.55597536 0.5023003 0.5021 0.5015 0.68546944 0.50208485 0.50029976 0.50072154 0.50130026 0.50500722 1.00072154 0.505013 1.00130026 0.88708254 0.79110139 0.94911141 0.9329228 0.98958333 0.94911141 0.50002102 0.82137706 0.48837209 0.46261682 0.50080292 0.50102946 0.50038567 0.50301974 0.50301894 0.50482786 0.51461819 0.50025536 0.50109 0.50005941] [ 0.5 2.25 0.5 2.5 1.12870907 0.2 -0.72870907 0.44322261 0.8625 1.28177739 2.64499507 0.4 -1.84499507 -1.5017362 0.98214286 3.46602192 0.54467566 0.34897651 0.67293725 0.35388453 0.20257194 0.91666667 0.2125 1.
-
-1.25 0.65526214 0.71391499 0.05 0.625
1.4546618 0.03125 0.625 0.27828315 1.30357143 -0.73214286 1.64285714 2. 0. 0.50158854 1.56055 0.84424691 0.81289151 0.84424691 0.81289151 0.50002821 0.92857143 1. 0.55357143 1. 0.57142857 0.94642857 1.62992126 0.93733954 0.72436383 0.80463296 0.64332566 0.64332566 0.92862663 0.69162475 0.50036972 0.50041091 0.49995881 0.50002821 0.50036731 0.4996609 0.50060549 0.50059136 0.50001413 0.50058769 0.50046898 0.50011872 0.51550551 0.50043507 0.5004 0.5004 0.70173949 0.50094525 0.5000133 0.50018855 0.5016309 0.50500189 1.00018855 0.50501631 1.0016309 0.90231648 0.84018059 0.78379017 0.86199068 0.96153846 0.78379017 0.50005907 0.82214763 0.46153846 0.49390244 0.50033468 0.5004681 0.5001577 0.50101315 0.5049536 0.5049577 0.50245547 0.50004286 0.50033407 0.50000558] [-1.5 -2. 0. -1.25 -0.59843137 -0.375 -0.15156863 -1.03909656 -0.5625 -0.08590344 0.1305932 0.1 0.0694068 -0.57459525 -0.33928571 -0.10397617 -0.60958734 -0.42280185 -0.25131232 -0.19336561 -0.04734059 -0.55 -0.20416667 0. -0.75 0.4625 -0.77565812 -0.69742713 -0.50537634 0.13392857 -0.13010814 0.203125 0.13392857 -0.60711807 0.97321429 -1.54464286 1.39285714 -2. 0. 
0.5034375 1.5285425 0.77401225 0.80006797 0.77401225 0.80006797 0.49965064 0.78571429 0.5 0.94642857 0.60714286 0.21428571 0.16071429 0.75609756 0.12628253 0.391973 0.43205774 0.40918397 0.40918397 0.6568422 0.65129139 0.49965039 0.49952367 0.50012672 0.49965064 0.49952385 0.50012679 0.4991071 0.49904649 0.5000606 0.49898125 0.49906082 0.49992042 0.60652914 0.50324215 0.49925 0.5001 0.55565462 0.50071512 0.49983037 0.50004857 0.49890962 0.50500049 1.00004857 0.5049891 0.99890962 0.71602887 0.69775816 0.67963636 0.63307921 0.71 0.67963636 0.49992103 0.72651829 0.28378378 0.08677686 0.5009178 0.50088689 0.50044575 0.50315007 0.50190164 0.50431561 0.50705823 0.50012321 0.50127 0.50008064] [-1.75 -1.75 -2.5 -2.5 0.10219551 0.2 0.29780449 0.42574037 0.6625 0.89925963 -0.51515987 0.4 1.31515987 -2.22057645 1.19642857 4.61343359 0.34977189 0.19722558 0.17737844 -0.09244131 0.04125291 0.68333333 -0.4375 0.125 0.75 1.9 0.83615039 0.6113481 -0.25349462 0.02678571 -1.56859446 1.328125 0.02678571 0.12254019 -0.91964286 1.91964286 -1.39285714 0. 0. 0.50322917 1.50613 0.60776296 0.61012824 0.60776296 0.61012824 0.50044744 0.5 0.85714286 0.53571429 0.96428571 0.85714286 0.92857143 0.54912281 0.94891468 0.57484855 0.56852963 0.53551199 0.53551199 0.54609088 0.50125592 0.50051939 0.50057909 0.4999403 0.50044744 0.50067231 0.49977513 0.50065049 0.50059062 0.50005986 0.50032897 0.50007834 0.50025063 0.59213741 0.50298832 0.5004 0.5004 0.61084906 0.50154958 0.5002223 0.50019876 0.50131972 0.50500199 1.00019876 0.5050132 1.00131972 0.78426482 0.77637239 0.89824098 0.90350129 0.78723404 0.89824098 0.50007585 0.79090468 0.25609756 0.41803279 0.50099851 0.50122291 0.50049596 0.50130811 0.50262852 0.50457515 0.5095117 0.50016607 0.50105357 0.5000555 ]] y_train: [0 0 0 ... 1 1 0]