fabletools

Add rolling windows and cv

Open · jeffzi opened this issue 4 years ago · 2 comments

Hi.

After reading #180 and #177, I thought I would give implementing cross-validation a shot.

I defined the common cross-validation schemes ExpandingWindow(), SlidingWindow(), and Holdout() as rolling_window objects, plus two S3 methods:

  • roll() is based on the slider package and applies an arbitrary function (identity by default) to successive sub-windows. It also cuts off the specified horizon h. The output is a tibble with the tsibble keys plus an extra list-column containing the untransformed results.
  • cv() fits the models on each fold and returns the forecasts as a fable. Intermediate folds are not kept, because in most cases we are only interested in the forecasts in order to evaluate accuracy. Dropping them is also faster and more memory-efficient.
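To make the three schemes concrete, here is a minimal base-R sketch of how fold boundaries could be computed for a series of n observations when the window parameters are counted in observations (the .period = NULL case). The helper window_folds() and its type argument are hypothetical and not part of the proposed code, which builds on slider. Treating Holdout as a single split that reserves the last h observations is also an assumption.

```r
# Hypothetical helper (not part of the proposal): training/test index ranges
# for n observations, with .init, .size, .step and h counted in observations.
window_folds <- function(n, type = c("expanding", "sliding", "holdout"),
                         .init = NULL, .size = NULL, .step = 1L, h = 1L) {
  type <- match.arg(type)
  if (type == "holdout") {
    # Assumption: a holdout is one split that reserves the last h observations.
    return(list(list(train = seq_len(n - h), test = (n - h + 1L):n)))
  }
  # Length of the first training window.
  first <- if (type == "expanding") .init else .size
  stopifnot(first <= n - h)
  # Last training index of each fold; the final h observations must stay
  # available as a test set (the horizon cut-off).
  ends <- seq(first, n - h, by = .step)
  lapply(ends, function(end) {
    # Expanding windows always start at 1; sliding windows keep a fixed size.
    start <- if (type == "expanding") 1L else end - .size + 1L
    list(train = start:end, test = (end + 1L):(end + h))
  })
}

folds <- window_folds(20, "expanding", .init = 10, h = 5)
length(folds)      # 6 folds
folds[[1]]$train   # 1 to 10
folds[[6]]$test    # 16 to 20
```

With .step = 1 the expanding scheme adds one observation per fold; a sliding window with .size = 10 and .step = 5 on the same 20 observations yields two folds (training on 1:10 and 6:15).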

roll() can also be used to compute features() on folds, for example for time series classification.
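The idea of applying an arbitrary function to each window can be illustrated without the proposal branch. The sketch below uses a hypothetical roll_windows() on a plain numeric vector with expanding windows. It returns one list element per fold, loosely analogous to the list-column roll() produces per key. The mean/sd "features" are only placeholders for what features() would compute.

```r
# Hypothetical sketch: apply .f to each expanding training window of x,
# after cutting off the last h observations, and keep results in a list.
roll_windows <- function(x, .init, .step = 1L, h = 1L, .f = identity) {
  ends <- seq(.init, length(x) - h, by = .step)
  lapply(ends, function(end) .f(x[seq_len(end)]))
}

set.seed(1)
x <- cumsum(rnorm(30))  # toy random-walk series

# Placeholder per-window features, e.g. as input to a classifier.
feats <- roll_windows(x, .init = 10, .step = 5, h = 5,
                      .f = function(w) c(mean = mean(w), sd = sd(w)))
do.call(rbind, feats)   # one row of features per fold (4 folds here)
```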

The window parameters .init, .size, and .step, as well as the cutoff h, can be specified either in calendar periods or, if .period is NULL, in numbers of observations. The implementation relies on the warp package.
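As a rough illustration of what specifying windows in calendar periods means, the base-R sketch below translates .init = 25 years and h = 3 years into observation counts for a monthly series. The proposal delegates this to warp; here the year index is computed with format() instead, and the date range is illustrative only.

```r
# Hypothetical sketch: map calendar-year window parameters to observation
# counts for a monthly series (base R stands in for warp here).
dates <- seq(as.Date("1982-04-01"), as.Date("2018-12-01"), by = "month")

# Years elapsed since the first observation's year (0, 1, 2, ...).
year_index <- as.integer(format(dates, "%Y")) - as.integer(format(dates[1], "%Y"))

init_years <- 25L  # .init, in years
h_years    <- 3L   # h, in years

# First training window: every observation in the first 25 calendar years.
n_init <- sum(year_index < init_years)
# Test window: the following 3 calendar years.
n_test <- sum(year_index >= init_years & year_index < init_years + h_years)

n_init          # 297 (April 1982 is a partial first year)
dates[n_init]   # 2006-12-01
n_test          # 36
```

Counting in calendar periods means a partial first year still counts as one period, so the number of observations in the first window is not simply 25 × 12.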

I also implemented optional parallelism. Microbenchmarks confirmed that parallelizing over folds is more efficient than parallelizing over models.
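A minimal sketch of the cv() idea in base R: each fold fits a seasonal naive model on its training window, only the forecasts are kept, and the work is distributed over folds with parallel::mclapply(). The functions snaive_forecast() and cv_snaive() are hypothetical stand-ins, not the proposed fable-based code. mclapply() forks, so more than one core is only possible on Unix-alikes.

```r
library(parallel)

# Hypothetical seasonal naive forecast: repeat the last observed season.
snaive_forecast <- function(train, h, m = 12L) {
  n <- length(train)
  rep_len(train[(n - m + 1L):n], h)
}

# Hypothetical cv(): expanding windows, one task per fold; fitted models
# are discarded and only the forecasts (with actuals) are returned.
cv_snaive <- function(x, .init, .step = 1L, h = 1L, cores = 1L) {
  ends <- seq(.init, length(x) - h, by = .step)
  fc <- mclapply(seq_along(ends), function(i) {
    end <- ends[i]
    data.frame(.fold = i, index = end + seq_len(h),
               actual = x[end + seq_len(h)],
               .mean = snaive_forecast(x[seq_len(end)], h))
  }, mc.cores = cores)
  do.call(rbind, fc)
}

x <- rep(1:12, 4) + rep(0:3, each = 12)  # toy monthly series with a trend
cores <- if (.Platform$OS.type == "windows") 1L else 2L
res <- cv_snaive(x, .init = 24, .step = 12, h = 6, cores = cores)
mean(abs(res$actual - res$.mean))  # MAE across both folds: 1
```

Each fold is independent, so parallelizing over folds gives one task per fold no matter how many models are fitted inside it.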

An example of usage:

library(dplyr)        # provides filter() and %>% used below
library(tsibbledata)
library(fable)

ExpandingWindow(.init = 10) %>%
  roll(aus_retail, h = 5)
#> # A tibble: 152 x 3
#>    State                 Industry                                      .fold    
#>    <chr>                 <chr>                                         <list>   
#>  1 Australian Capital T… Cafes, restaurants and catering services      <list [4…
#>  2 Australian Capital T… Cafes, restaurants and takeaway food services <list [4…
#>  3 Australian Capital T… Clothing retailing                            <list [4…
#>  4 Australian Capital T… Clothing, footwear and personal accessory re… <list [4…
#>  5 Australian Capital T… Department stores                             <list [4…
#>  6 Australian Capital T… Electrical and electronic goods retailing     <list [4…
#>  7 Australian Capital T… Food retailing                                <list [4…
#>  8 Australian Capital T… Footwear and other personal accessory retail… <list [4…
#>  9 Australian Capital T… Furniture, floor coverings, houseware and te… <list [4…
#> 10 Australian Capital T… Hardware, building and garden supplies retai… <list [4…
#> # … with 142 more rows

ts <- aus_retail %>%
  filter(State %in% c("Queensland", "Victoria"), Industry == "Food retailing")

models <- list(
  snaive = SNAIVE(Turnover),
  tslm = TSLM(log(Turnover) ~ trend() + season())
)

suppressWarnings({
  ExpandingWindow(.init = 25, .step = 1, .period = "year") %>%
    cv(ts, h = 3, !!!models)
})
#> # A fable: 4,896 x 7 [1M]
#> # Key:     .fold, State, Industry, .model [136]
#>    .fold State      Industry       .model    Month      Turnover .mean
#>    <int> <chr>      <chr>          <chr>     <mth>        <dist> <dbl>
#>  1     1 Queensland Food retailing snaive 2007 Jan N(1143, 2777) 1143.
#>  2     1 Queensland Food retailing snaive 2007 Feb N(1057, 2777) 1057.
#>  3     1 Queensland Food retailing snaive 2007 Mar N(1176, 2777) 1176.
#>  4     1 Queensland Food retailing snaive 2007 Apr N(1156, 2777) 1156.
#>  5     1 Queensland Food retailing snaive 2007 May N(1163, 2777) 1163.
#>  6     1 Queensland Food retailing snaive 2007 Jun N(1158, 2777) 1158.
#>  7     1 Queensland Food retailing snaive 2007 Jul N(1220, 2777) 1220.
#>  8     1 Queensland Food retailing snaive 2007 Aug N(1251, 2777) 1251.
#>  9     1 Queensland Food retailing snaive 2007 Sep N(1224, 2777) 1224.
#> 10     1 Queensland Food retailing snaive 2007 Oct N(1261, 2777) 1261.
#> # … with 4,886 more rows

The development version of tibble breaks forecast(), and therefore cv(). The cause is [[<-.tbl_df, which names the new column ...3 instead of z:

library(tibble)
df <- tibble(x = 1:3, y = 3:1)
df[["z"]] <- c("a", "b", "c")
df
#> # A tibble: 3 x 3
#>       x     y ...3 
#>   <int> <int> <chr>
#> 1     1     3 a    
#> 2     2     2 b    
#> 3     3     1 c

# add_column() names the new column correctly
add_column(df, z = c("a", "b", "c"))
#> # A tibble: 3 x 4
#>       x     y ...3  z    
#>   <int> <int> <chr> <chr>
#> 1     1     3 a     a    
#> 2     2     2 b     b    
#> 3     3     1 c     c

Created on 2020-03-25 by the reprex package (v0.3.0)

I have not written tests yet, but I can work on them if you think the implementation is useful.

jeffzi, Mar 25 '20 16:03