mlr
Can't predict on randomForest when test set contains NA's in features
I don't know if this is a bug of some sort or if I'm overlooking something, but this baffled @ja-thomas and me a bit this morning. Consider a simple case where you have a missing value somewhere in your test set, like in this example:
lrn.rf = makeLearner("classif.randomForest")
mod = train(lrn.rf, iris.task)
test.df = getTaskData(iris.task)
test.df[1L, 1L] = NA
mlr then throws an error when you try to predict on this set, although randomForest's own predict method does not:
# throws error: row names contain missing values
predict(mod, newdata = test.df)
# if I'm directly using the predict method from randomForest it works
predict(mod$learner.model, test.df)
I tried printing out .newdata in predictLearner.classif.randomForest to see if we do something unwanted with the data.frame before sending it to the learner's predict method, but row names, str() output, etc. look fine.
Any ideas?
This sounds like a bug. Could you make a unit test that reproduces it please?
Before you produce a unit test: please actually show the output produced in such cases.
> This sounds like a bug. Could you make a unit test that reproduces it please?
Why are you asking him that? He already kind of posted that test.
The problem is more: what happens here?
The real problem seems to be that mlr claims the RF does not handle missing values. So if you put the NA into the task data, neither training nor prediction would work, and you would get a meaningful error message.
But here the prediction data frame is handled directly, and the underlying RF just creates an NA for the prediction. This also links to the issue https://github.com/mlr-org/mlr/issues/1499
We should probably create the task description internally and sanity check it.
This problem does not only occur with the random forest. I ran every available learner on this problem, and here is what I found:
test.df = getTaskData(iris.task)
test.df[1L, 1L] = NA
learners = listLearners(obj = "classif", properties = "multiclass")$class
res = lapply(learners, function(x) {
  mod = train(makeLearner(x), iris.task)
  tryCatch(predict(mod, newdata = test.df), error = function(e) conditionMessage(e))
})
So the problem here is that only the test set contains missing values, and some learners support that while others don't. Here is a small summary of the results:
learner | predict() result (predicted class or error message) |
---|---|
classif.bdk | setosa |
classif.boosting | setosa |
classif.C50 | setosa |
classif.cforest | setosa |
classif.ctree | setosa |
classif.cvglmnet | NA |
classif.dbnDNN | NA |
classif.earth | invalid subscript type 'list' |
classif.evtree | setosa |
classif.extraTrees | setosa |
classif.featureless | versicolor |
classif.fnn | no missing values are allowed |
classif.gausspr | arguments imply differing number of rows: 150, 149 |
classif.gbm | setosa |
classif.geoDA | NA |
classif.glmnet | setosa |
classif.h2o.deeplearning | setosa |
classif.h2o.gbm | setosa |
classif.h2o.randomForest | setosa |
classif.IBk | setosa |
classif.J48 | setosa |
classif.JRip | setosa |
classif.kknn | arguments imply differing number of rows: 150, 149 |
classif.knn | no missing values are allowed |
classif.ksvm | arguments imply differing number of rows: 150, 149 |
classif.lda | NA |
classif.LiblineaRL1L2SVC | NA/NaN/Inf in foreign function call (arg 2) |
classif.LiblineaRL1LogReg | NA/NaN/Inf in foreign function call (arg 2) |
classif.LiblineaRL2L1SVC | NA/NaN/Inf in foreign function call (arg 2) |
classif.LiblineaRL2LogReg | NA/NaN/Inf in foreign function call (arg 2) |
classif.LiblineaRL2SVC | NA/NaN/Inf in foreign function call (arg 2) |
classif.LiblineaRMultiClassSVC | NA/NaN/Inf in foreign function call (arg 2) |
classif.linDA | NA |
classif.lssvm | arguments imply differing number of rows: 150, 149 |
classif.lvq1 | no missing values are allowed |
classif.mda | arguments imply differing number of rows: 150, 149 |
classif.mlp | missing values in 'x' |
classif.multinom | NA |
classif.naiveBayes | setosa |
classif.nnet | NA |
classif.nnTrain | NA |
classif.OneR | setosa |
classif.PART | setosa |
classif.qda | NA |
classif.quaDA | NA |
classif.randomForest | row names contain missing values |
classif.randomForestSRC | setosa |
classif.ranger | Missing data in columns: Sepal.Length. |
classif.rda | virginica |
classif.rFerns | NAs in predictors. |
classif.rknn | no missing values are allowed |
classif.rpart | setosa |
classif.RRF | row names contain missing values |
classif.rrlda | invalid subscript type 'list' |
classif.saeDNN | NA |
classif.sda | NA |
classif.sparseLDA | NA |
classif.svm | arguments imply differing number of rows: 150, 149 |
classif.xgboost | setosa |
classif.xyf | setosa |
Well, we already have the missings learner property for that, so we just need to check that in predict().
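A minimal sketch of what such a check could look like, written in base R so it runs without mlr. The helper name checkNewdataMissings and its supports.missings argument are hypothetical and not part of mlr; inside mlr the flag would presumably come from hasLearnerProperties(learner, "missings").

```r
# Hypothetical helper (not part of mlr): stop early with a clear message
# when newdata has NAs but the learner does not declare support for them.
checkNewdataMissings = function(newdata, supports.missings) {
  # names of the columns that contain at least one NA
  na.cols = names(newdata)[vapply(newdata, anyNA, logical(1L))]
  if (length(na.cols) > 0L && !supports.missings) {
    stop(sprintf(
      "newdata has missing values in: %s, but the learner does not support them",
      paste(na.cols, collapse = ", ")))
  }
  invisible(TRUE)
}

test.df = iris
test.df[1L, 1L] = NA

# Assumption: in mlr the flag would be hasLearnerProperties(lrn, "missings")
msg = tryCatch(checkNewdataMissings(test.df, supports.missings = FALSE),
  error = function(e) conditionMessage(e))
print(msg)
```

A check like this would turn the assorted learner-specific errors from the table into one consistent message, at the cost of also rejecting newdata for learners that happen to tolerate NAs at predict time without declaring the property.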
The bug for classif.randomForest in particular seems to be that it puts names on its predictions, and gives the NA prediction a NA name, which cbind in makePrediction trips over. Having a
if (is.matrix(p))
  colnames(p) = NULL
else
  names(p) = NULL
anywhere in between would fix this. There are still other learners, however, that throw errors when the prediction data set contains NAs.
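The failure mode can be reproduced in base R without randomForest or mlr. This sketch assumes the relevant step boils down to data.frame() (which cbind calls when given data frames) receiving a named vector whose names contain NA; the variable names are illustrative only.

```r
# A prediction vector with names, where the first name is NA,
# mimicking what predict() on the randomForest model returns
p = factor(c(NA, "setosa", "setosa"),
  levels = c("setosa", "versicolor", "virginica"))
names(p) = c(NA, "2", "3")

# data.frame() takes the names as row names and rejects the NA
err = tryCatch(data.frame(response = p),
  error = function(e) conditionMessage(e))
print(err)

# Dropping the names first avoids the error; the NA prediction itself is kept
names(p) = NULL
df = data.frame(response = p)
print(nrow(df))
print(is.na(df$response[1L]))
```

Note that this only covers the vector case. For a probability matrix the offending names would presumably be the row names, so which dimnames to drop there is worth double-checking.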
@mb706 Did #2099 solve this particular issue for randomForest?
I am still getting the same error as mentioned in the OP.
Have to inspect.
library(mlr)
#> Loading required package: ParamHelpers
lrn.rf = makeLearner("classif.randomForest")
mod = train(lrn.rf, iris.task)
test.df = getTaskData(iris.task)
test.df[1L, 1L] = NA
# throws error: row names contain missing values
predict(mod, newdata = test.df)
#> Error in (function (..., row.names = NULL, check.rows = FALSE, check.names = TRUE, : row names contain missing values
# if I'm directly using the predict method from randomForest it works
predict(mod$learner.model, test.df)
#> <NA> 2 3 4 5 6 7
#> <NA> setosa setosa setosa setosa setosa setosa
#> 8 9 10 11 12 13 14
#> setosa setosa setosa setosa setosa setosa setosa
#> 15 16 17 18 19 20 21
#> setosa setosa setosa setosa setosa setosa setosa
#> 22 23 24 25 26 27 28
#> setosa setosa setosa setosa setosa setosa setosa
#> 29 30 31 32 33 34 35
#> setosa setosa setosa setosa setosa setosa setosa
#> 36 37 38 39 40 41 42
#> setosa setosa setosa setosa setosa setosa setosa
#> 43 44 45 46 47 48 49
#> setosa setosa setosa setosa setosa setosa setosa
#> 50 51 52 53 54 55 56
#> setosa versicolor versicolor versicolor versicolor versicolor versicolor
#> 57 58 59 60 61 62 63
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 64 65 66 67 68 69 70
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 71 72 73 74 75 76 77
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 78 79 80 81 82 83 84
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 85 86 87 88 89 90 91
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 92 93 94 95 96 97 98
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor
#> 99 100 101 102 103 104 105
#> versicolor versicolor virginica virginica virginica virginica virginica
#> 106 107 108 109 110 111 112
#> virginica virginica virginica virginica virginica virginica virginica
#> 113 114 115 116 117 118 119
#> virginica virginica virginica virginica virginica virginica virginica
#> 120 121 122 123 124 125 126
#> virginica virginica virginica virginica virginica virginica virginica
#> 127 128 129 130 131 132 133
#> virginica virginica virginica virginica virginica virginica virginica
#> 134 135 136 137 138 139 140
#> virginica virginica virginica virginica virginica virginica virginica
#> 141 142 143 144 145 146 147
#> virginica virginica virginica virginica virginica virginica virginica
#> 148 149 150
#> virginica virginica virginica
#> Levels: setosa versicolor virginica
Created on 2019-12-31 by the reprex package (v0.3.0)
Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 3.6.2 Patched (2019-12-12 r77564)
#> os macOS Mojave 10.14.6
#> system x86_64, darwin15.6.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2019-12-31
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib
#> assertthat 0.2.1 2019-03-21 [1]
#> backports 1.1.5 2019-10-02 [1]
#> BBmisc 1.11 2017-03-10 [1]
#> callr 3.4.0 2019-12-09 [1]
#> checkmate 1.9.4 2019-07-04 [1]
#> cli 2.0.0.9000 2019-12-21 [1]
#> colorspace 1.4-1 2019-03-18 [1]
#> crayon 1.3.4 2017-09-16 [1]
#> data.table 1.12.8 2019-12-09 [1]
#> desc 1.2.0 2018-05-01 [1]
#> devtools 2.2.1 2019-09-24 [1]
#> digest 0.6.23 2019-11-23 [1]
#> dplyr 0.8.3 2019-07-04 [1]
#> ellipsis 0.3.0 2019-09-20 [1]
#> evaluate 0.14 2019-05-28 [1]
#> fansi 0.4.0 2018-10-05 [1]
#> fastmatch 1.1-0 2017-01-28 [1]
#> fs 1.3.1 2019-05-06 [1]
#> ggplot2 3.2.1 2019-08-10 [1]
#> glue 1.3.1 2019-03-12 [1]
#> gtable 0.3.0 2019-03-25 [1]
#> highr 0.8 2019-03-20 [1]
#> htmltools 0.4.0 2019-10-04 [1]
#> knitr 1.26 2019-11-12 [1]
#> lattice 0.20-38 2018-11-04 [2]
#> lazyeval 0.2.2 2019-03-15 [1]
#> lifecycle 0.1.0 2019-08-01 [1]
#> magrittr 1.5 2014-11-22 [1]
#> Matrix 1.2-18 2019-11-27 [2]
#> memoise 1.1.0 2017-04-21 [1]
#> mlr * 2.16.0.9000 2019-12-11 [1]
#> munsell 0.5.0 2018-06-12 [1]
#> parallelMap 1.4.0.9000 2019-12-19 [1]
#> ParamHelpers * 1.13.0.9000 2019-12-11 [1]
#> pillar 1.4.3 2019-12-20 [1]
#> pkgbuild 1.0.6 2019-10-09 [1]
#> pkgconfig 2.0.3 2019-09-22 [1]
#> pkgload 1.0.2 2018-10-29 [1]
#> prettyunits 1.0.2 2015-07-13 [1]
#> processx 3.4.1 2019-07-18 [1]
#> ps 1.3.0 2018-12-21 [1]
#> purrr 0.3.3 2019-10-18 [1]
#> R6 2.4.1 2019-11-12 [1]
#> randomForest 4.6-14 2018-03-25 [1]
#> Rcpp 1.0.3 2019-11-08 [1]
#> remotes 2.1.0 2019-06-24 [1]
#> rlang 0.4.2.9000 2019-12-25 [1]
#> rmarkdown 2.0 2019-12-12 [1]
#> rprojroot 1.3-2 2018-01-03 [1]
#> scales 1.1.0 2019-11-18 [1]
#> sessioninfo 1.1.1 2018-11-05 [1]
#> stringi 1.4.3 2019-03-12 [1]
#> stringr 1.4.0 2019-02-10 [1]
#> survival 3.1-8 2019-12-03 [2]
#> testthat 2.3.1 2019-12-01 [1]
#> tibble 2.1.3 2019-06-06 [1]
#> tidyselect 0.2.5 2018-10-11 [1]
#> usethis 1.5.1.9000 2019-12-14 [1]
#> withr 2.1.2 2018-03-15 [1]
#> xfun 0.11 2019-11-12 [1]
#> XML 3.98-1.20 2019-06-06 [1]
#> yaml 2.2.0 2018-07-25 [1]
#> source
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (r-lib/cli@0293ae7)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> local
#> CRAN (R 3.6.1)
#> local
#> Github (berndbischl/ParamHelpers@c2d989c)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (r-lib/rlang@ce4f717)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.2)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> Github (r-lib/usethis@b2e894e)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.1)
#> CRAN (R 3.6.0)
#>
#> [1] /Users/pjs/Library/R/3.6/library
#> [2] /Library/Frameworks/R.framework/Versions/3.6/Resources/library