rms
predict() fails with newdata when input is special data.frame
ols() will fit to data.frames with extra classes, but predict() (predictrms()) does not seem to drop these extra classes first, which can in some cases result in errors when the semantics of subsetting have been changed. The readr package now adds the subclass spec_tbl_df. This special class is explained in this post.
Here's an example.
> #load packages
> suppressPackageStartupMessages(library(rms))
> suppressPackageStartupMessages(library(readr))
> suppressPackageStartupMessages(library(tidyverse))
> #my data
> write_tsv(iris, "iris.tsv")
> iris2 = read_tsv("iris.tsv")
Parsed with column specification:
cols(
Sepal.Length = col_double(),
Sepal.Width = col_double(),
Petal.Length = col_double(),
Petal.Width = col_double(),
Species = col_character()
)
> iris2$Species = iris2$Species %>% factor()
> #fit
> fit = ols(Sepal.Length ~ Species, data = iris2)
> fit
Linear Regression Model
ols(formula = Sepal.Length ~ Species, data = iris2)
Model Likelihood Discrimination
Ratio Test Indexes
Obs 150 LR chi2 144.63 R2 0.619
sigma0.5148 d.f. 2 R2 adj 0.614
d.f. 147 Pr(> chi2) 0.0000 g 0.708
Residuals
Min 1Q Median 3Q Max
-1.6880 -0.3285 -0.0060 0.3120 1.3120
Coef S.E. t Pr(>|t|)
Intercept 5.0060 0.0728 68.76 <0.0001
Species=versicolor 0.9300 0.1030 9.03 <0.0001
Species=virginica 1.5820 0.1030 15.37 <0.0001
> #predictions
> predict(fit) %>% head()
1 2 3 4 5 6
5.006 5.006 5.006 5.006 5.006 5.006
> predict(fit, newdata = iris2) %>% head()
Error in predictrms: Values in Species not in setosa versicolor virginica :
Error: Must use a vector in `[`, not an object of class matrix.
Call `rlang::last_error()` to see a backtrace
> #examine factor
> iris2$Species
[1] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[13] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[25] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[37] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[49] setosa setosa versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
[61] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
[73] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
[85] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
[97] versicolor versicolor versicolor versicolor virginica virginica virginica virginica virginica virginica virginica virginica
[109] virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica
[121] virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica
[133] virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica virginica
[145] virginica virginica virginica virginica virginica virginica
Levels: setosa versicolor virginica
> iris2$Species %>% str
Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
> #examine class of data
> class(iris2)
[1] "spec_tbl_df" "tbl_df" "tbl" "data.frame"
> #remove extra classes
> iris3 = iris2 %>% as.data.frame()
> class(iris3)
[1] "data.frame"
> #test predictions
> predict(fit, newdata = iris3) %>% head()
1 2 3 4 5 6
5.006 5.006 5.006 5.006 5.006 5.006
So to fix this, add an as.data.frame() call somewhere in predict().
Thanks for the thorough description and example. Is inherits(iris2, 'data.frame') TRUE for iris2? I need logic that will tell me exactly when to run the data through as.data.frame.
A simple option:
if (is.data.frame(x) && length(class(x)) > 1) x = as.data.frame(x)
This converts any special data frame to a regular one using whatever as.data.frame method is defined for its class (or the default method if none is specified). It works in my limited testing.
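To make the check above concrete, here is a minimal sketch wrapped in a helper. The name as_plain_df and the class "my_special_df" are made up for illustration and are not part of rms; the example fakes a special data frame in base R so it runs without readr or tibble installed.

```r
# Hypothetical helper (not part of rms): convert a data frame that carries
# extra classes to a plain data.frame, and leave everything else untouched.
as_plain_df <- function(x) {
  if (is.data.frame(x) && length(class(x)) > 1) {
    x <- as.data.frame(x)
  }
  x
}

# Simulate a "special" data frame; "my_special_df" is an invented class
# that stands in for spec_tbl_df or tbl_df.
d <- iris
class(d) <- c("my_special_df", "data.frame")

class(d)               # two classes, so the condition is TRUE
class(as_plain_df(d))  # extra class dropped by as.data.frame()
```

Since no as.data.frame method exists for the invented class, dispatch falls through to the data.frame method, which drops the classes listed before "data.frame". Whether predictrms() should apply such a conversion to newdata internally is a design decision for the package; callers can also apply it themselves before calling predict().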
I ran into this when trying to pass a tbl_df obtained by:
grid = tidyr::crossing(
covariate1=seq(0, 11),
covariate2=na.omit(unique(train_data$covariate2))
)
where class(grid) is c("tbl_df", "tbl", "data.frame"). Passing as.data.frame(grid) worked perfectly.
is.data.frame(grid) && length(class(grid)) > 1 gives TRUE, as does inherits(grid, 'data.frame').