predict() fails with newdata when input is special data.frame

Open Deleetdk opened this issue 5 years ago • 3 comments

ols() will fit models to data frames with extra classes, but predict() (predictrms()) does not seem to drop these extra classes first, which can cause errors when a subclass changes the semantics of subsetting. The readr package now adds the subclass spec_tbl_df. This special class is explained in this post.

Here's an example.

> #load packages
> suppressPackageStartupMessages(library(rms))
> suppressPackageStartupMessages(library(readr))
> suppressPackageStartupMessages(library(tidyverse))
> #my data
> write_tsv(iris, "iris.tsv")
> iris2 = read_tsv("iris.tsv")
Parsed with column specification:
cols(
  Sepal.Length = col_double(),
  Sepal.Width = col_double(),
  Petal.Length = col_double(),
  Petal.Width = col_double(),
  Species = col_character()
)
> iris2$Species = iris2$Species %>% factor()
> #fit
> fit = ols(Sepal.Length ~ Species, data = iris2)
> fit
Linear Regression Model
 
 ols(formula = Sepal.Length ~ Species, data = iris2)
 
                Model Likelihood     Discrimination    
                   Ratio Test           Indexes        
 Obs     150    LR chi2    144.63    R2       0.619    
 sigma0.5148    d.f.            2    R2 adj   0.614    
 d.f.    147    Pr(> chi2) 0.0000    g        0.708    
 
 Residuals
 
     Min      1Q  Median      3Q     Max 
 -1.6880 -0.3285 -0.0060  0.3120  1.3120 
 
 
                    Coef   S.E.   t     Pr(>|t|)
 Intercept          5.0060 0.0728 68.76 <0.0001 
 Species=versicolor 0.9300 0.1030  9.03 <0.0001 
 Species=virginica  1.5820 0.1030 15.37 <0.0001 
 
> #predictions
> predict(fit) %>% head()
    1     2     3     4     5     6 
5.006 5.006 5.006 5.006 5.006 5.006 
> predict(fit, newdata = iris2) %>% head()
Error in predictrms: Values in Species not in setosa versicolor virginica :
Error: Must use a vector in `[`, not an object of class matrix.
Call `rlang::last_error()` to see a backtrace
> #examine factor
> iris2$Species
  [1] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
 [13] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
 [25] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
 [37] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
 [49] setosa     setosa     versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
 [61] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
 [73] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
 [85] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
 [97] versicolor versicolor versicolor versicolor virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[109] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[121] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[133] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[145] virginica  virginica  virginica  virginica  virginica  virginica 
Levels: setosa versicolor virginica
> iris2$Species %>% str
 Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
> #examine class of data
> class(iris2)
[1] "spec_tbl_df" "tbl_df"      "tbl"         "data.frame" 
> #remove extra classes
> iris3 = iris2 %>% as.data.frame()
> class(iris3)
[1] "data.frame"
> #test predictions
> predict(fit, newdata = iris3) %>% head()
    1     2     3     4     5     6 
5.006 5.006 5.006 5.006 5.006 5.006

So to fix this, add an as.data.frame() call somewhere in predict().

Deleetdk avatar Mar 11 '19 23:03 Deleetdk

Thanks for the thorough description and example. Does inherits(iris2, 'data.frame') return TRUE for iris2? I need logic that will tell me exactly when to run the data through as.data.frame.

harrelfe avatar Mar 12 '19 01:03 harrelfe

A simple option:

if (is.data.frame(x) && length(class(x)) > 1) x = as.data.frame(x)

This converts any special data frame to a regular one using whatever as.data.frame() method is defined for its class (or the default method if none is specified). It works in my limited testing.
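
As a rough illustration of that check, here is a minimal sketch in base R. The helper name to_plain_df and the class name "my_tbl" are made up for this example and are not part of rms; the subclass is faked with structure() so that neither readr nor tibble is needed.

```r
# Hypothetical helper (not part of rms): drop extra classes from a data frame.
to_plain_df <- function(x) {
  # Only touch objects that are data frames carrying additional classes.
  if (is.data.frame(x) && length(class(x)) > 1) {
    x <- as.data.frame(x)
  }
  x
}

# Fake a subclassed data frame; "my_tbl" is an invented class name.
special <- structure(head(iris), class = c("my_tbl", "data.frame"))

class(special)               # "my_tbl" "data.frame"
class(to_plain_df(special))  # "data.frame"
class(to_plain_df(iris))     # "data.frame" (plain data frames pass through)
```

Where exactly such a call would belong inside predictrms() is left open here; the sketch only shows the condition and the conversion.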

Deleetdk avatar Apr 05 '19 03:04 Deleetdk

I ran into this when trying to pass a tbl_df obtained by:

grid = tidyr::crossing(
    covariate1=seq(0, 11),
    covariate2=na.omit(unique(train_data$covariate2))
)

where class(grid) is c("tbl_df", "tbl", "data.frame"). Passing as.data.frame(grid) worked perfectly.

is.data.frame(grid) && length(class(grid)) > 1 gives TRUE, as does inherits(grid, 'data.frame').

krassowski avatar Apr 21 '22 22:04 krassowski