Introducing `data.table` dependency to significantly improve `read_stan_csv` read times

Open SimonCMills opened this issue 3 years ago • 6 comments

read_stan_csv becomes very slow once models have thousands of parameters, with the bottleneck occurring at the read stage (see e.g. https://github.com/paul-buerkner/brms/issues/1331). A simple solution would be to replace the code block at lines 143:161 of stan_csv.R with code that reads the csv via data.table::fread(). I've worked on this a bit with @jsocolar and found that the readLines() approach that rstan currently uses is faster up to around 1000 parameters, beyond which it becomes increasingly slow relative to fread(). Across the range of csv sizes for which the readLines() approach is faster, fread() is still fast (under 1-2 seconds for a single model), which I think is a trivial slowdown for most purposes. Conversely, by the time you are reading a csv with around 5000 parameters you save >10 seconds by using fread() (a ~30% saving relative to readLines()), with the proportional and absolute savings continuing to widen as the number of parameters increases.
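For anyone who wants to try the comparison without compiling a model, the two reading strategies described above can be sketched on a synthetic file. This is a rough illustration under stated assumptions, not the rstan or cmdstanr code: the file is fake Stan-like output with comment lines only at the top, the sizes `n_par` and `n_iter` are arbitrary, `read_base()` and `read_fread()` are hypothetical helper names, and the fread route filters comments in R and passes the rest through the `text` argument rather than shelling out to grep.

```r
# Sketch only: a synthetic, Stan-like CSV (not real Stan output).
library(data.table)

n_par  <- 1000  # number of parameter columns (arbitrary, for illustration)
n_iter <- 200   # number of draws in the file (arbitrary)

set.seed(1)
draws <- as.data.frame(matrix(rnorm(n_iter * n_par), nrow = n_iter))
names(draws) <- paste0("x.", seq_len(n_par))

# Write comment lines first, then the header row and the draws.
csv_file <- tempfile(fileext = ".csv")
con <- file(csv_file, open = "w")
writeLines(c("# model = synthetic", "# method = sample"), con)
write.csv(draws, con, row.names = FALSE, quote = FALSE)
close(con)

# Base R route: drop comment lines, then parse through a text connection.
read_base <- function(f) {
  lines <- readLines(f)
  con <- textConnection(lines[!grepl("^#", lines)])
  on.exit(close(con))
  read.csv(con, colClasses = "numeric")
}

# fread route: same filtering, but the parsing is done by data.table::fread().
read_fread <- function(f) {
  lines <- readLines(f)
  fread(text = lines[!grepl("^#", lines)], data.table = FALSE,
        colClasses = "numeric")
}

time_base  <- system.time(df_base  <- read_base(csv_file))
time_fread <- system.time(df_fread <- read_fread(csv_file))

# Elapsed seconds for each route, and a tolerant equality check.
print(c(base = time_base[["elapsed"]], fread = time_fread[["elapsed"]]))
print(all.equal(df_base, df_fread))
```

Increasing `n_par` is the interesting direction to push this; the timings will depend on the machine, so none are quoted here.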

Would you be willing to consider introducing a data.table dependency in order to achieve speedups?

Code checking the timings and equivalence of the two methods is below (comparing just the initial code block, which is then used downstream in the rest of the function). Checking equivalence of the two methods is complicated by occasional minor floating point differences between data.table and base R (discussed e.g. here). The new code owes heavily to the cmdstanr implementation of this function.

# comparing header section of rstan::read_stan_csv (in read_initial), and 
# updated version that makes use of data.table::fread (in read_initial2). 
# Check for equivalence and compare timings.

library(cmdstanr)

model_code <- "
parameters {
  vector[10] x;
}
model {
  x ~ std_normal();
}
"
stan_file <- cmdstanr::write_stan_file(code = model_code)
mod <- cmdstan_model(stan_file)
fit <- mod$sample(parallel_chains = 4)
fit_warmup <- mod$sample(parallel_chains = 4, save_warmup = TRUE)
# note: sampling with a dense metric becomes slow when the number of parameters
# becomes large (e.g. around 500 is slow on my machine).
fit_dense_warmup <- mod$sample(parallel_chains = 4, metric = "dense_e", save_warmup = TRUE)
fit_optimize <- mod$optimize()
fit_variational <- mod$variational()

test_suite <- list(
  fit$output_files()[[1]], 
  fit$output_files(), 
  fit_warmup$output_files(),
  fit_dense_warmup$output_files(),
  fit_optimize$output_files(),
  fit_variational$output_files()
)

# code from read_stan_csv, taken verbatim from lines 127:161 of stan_csv.R 
read_initial <- function(csvfiles) {
  # Read the csv files saved from Stan (or RStan) to a stanfit object
  # Args:
  #   csvfiles: csv files fitted for the same model; each file contains 
  #     the sample of one chain 
  if (length(csvfiles) < 1) 
    stop("csvfiles does not contain any CSV file name")
  
  g_skip <- 10 # g_skip is never used anywhere.
  
  ss_lst <- vector("list", length(csvfiles))
  cs_lst2 <- vector("list", length(csvfiles))
  
  for (i in seq_along(csvfiles)) {
    f = csvfiles[i]    
    header <- rstan:::read_csv_header(f)
    lineno <- attr(header, 'lineno') 
    vnames <- strsplit(header, ",")[[1]] 
    iter.count <- attr(header,"iter.count") 
    variable.count <- length(vnames) 
    
    lines = readLines(f)
    comment_lines = grep("^#", lines)
    comments = lines[comment_lines]
    con = textConnection(lines[-comment_lines])
    on.exit(close(con))
    df = read.csv(con, colClasses = "numeric")
    cs_lst2[[i]] <- rstan:::parse_stancsv_comments(comments)
    if("output_samples" %in% names(cs_lst2[[i]])) 
      df <- df[-1,] # remove the means 
    ss_lst[[i]] <- df
  } 
  list(csvfiles = csvfiles, ss_lst = ss_lst, cs_lst2 = cs_lst2,
       f = f) 
}

# updated version to use fread, code inherited from cmdstanr

# repair path helper function 
# verbatim from cmdstanr:::repair_path
repair_path <- function(path) {
  if (!length(path) || !is.character(path)) {
    return(path)
  }
  path <- path.expand(path)
  path <- gsub("\\\\", "/", path)
  path <- gsub("//", "/", path)
  if (endsWith(path, "/")) {
    path <- substr(path, 1, nchar(path) - 1)
  }
  path
}

read_initial2 <- function(csvfiles) {
  if (length(csvfiles) < 1) 
    stop("csvfiles does not contain any CSV file name")
  
  ss_lst <- vector("list", length(csvfiles))
  cs_lst2 <- vector("list", length(csvfiles))
  
  for (i in seq_along(csvfiles)) {
    f = csvfiles[i]    
    
    # get non-comment component, "df" (code from cmdstanr:::read_cmdstan_csv)
    if (isTRUE(.Platform$OS.type == "windows")) {
      grep_path <- repair_path(Sys.which("grep.exe"))
      fread_cmd <- paste0(grep_path, " -v '^#' --color=never '", 
                          f, "'")
    } else {
      fread_cmd <- paste0("grep -v '^#' --color=never '", 
                          f, "'")
    }
    
    df <- data.table::fread(cmd = fread_cmd, data.table = FALSE, 
                            colClasses = "numeric")
    
    # get comments (code from cmdstanr:::read_csv_metadata)
    if (isTRUE(.Platform$OS.type == "windows")) {
      grep_path <- repair_path(Sys.which("grep.exe"))
      fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' --color=never '", 
                          f, "'")
    } else {
      fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", 
                          f, "'")
    }
    
    suppressWarnings(metadata <- data.table::fread(cmd = fread_cmd, 
                                                   colClasses = "character", 
                                                   stringsAsFactors = FALSE, 
                                                   fill = TRUE, sep = "", 
                                                   header = FALSE, 
                                                   data.table=FALSE))
    
    # minor reformatting
    metadata2 <- metadata[,1][grepl("#", metadata[,1])]
    metadata3 <- sub("#$", "# ", metadata2)
    
    cs_lst2[[i]] <- rstan:::parse_stancsv_comments(metadata3)
    if("output_samples" %in% names(cs_lst2[[i]])) df <- df[-1,] # remove the means 
    ss_lst[[i]] <- df
  } 
  
  list(csvfiles = csvfiles, ss_lst = ss_lst, cs_lst2 = cs_lst2,
       f = f)
}

time1 <- system.time(ts1 <- lapply(test_suite, read_initial))
time2 <- system.time(ts2 <- lapply(test_suite, read_initial2))

# often identical, but not always, due to floating point differences in the
# non-comment component
identical(ts1, ts2)
all.equal(ts1, ts2, tolerance = 1e-15)

# check everything except the draws (ss_lst is element 2 of each result)
for (i in seq_along(ts1)) {
  print(identical(ts1[[i]][-2], ts2[[i]][-2]))
}

# check the draws from the first csv file of each test case
for (i in seq_along(ts1)) {
  print(paste0("exactly 0 (", i, "): ", all(ts1[[i]]$ss_lst[[1]] -
                                              ts2[[i]]$ss_lst[[1]] == 0)))
}

# use abs() so that negative differences cannot pass trivially
for (i in seq_along(ts1)) {
  print(paste0("almost 0 (", i, "): ", all(abs(ts1[[i]]$ss_lst[[1]] -
                                                 ts2[[i]]$ss_lst[[1]]) < 10e-15)))
}
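
As a side note on the floating point caveat, here is a minimal base R sketch of why identical() can fail where a tolerance check passes, and why a tolerance check on a signed difference needs abs(). The values are made up for illustration and have nothing to do with Stan output.

```r
# Two doubles that are mathematically equal but differ in their last bits.
a <- 0.1 + 0.2
b <- 0.3

identical(a, b)          # FALSE: exact, bitwise comparison
isTRUE(all.equal(a, b))  # TRUE: comparison within a default tolerance
abs(a - b) < 1e-15       # TRUE: explicit absolute tolerance

# A signed difference compared against a small positive bound lets
# arbitrarily large negative differences through.
d <- c(-1, 0)
all(d < 1e-15)           # TRUE, although one difference is -1
all(abs(d) < 1e-15)      # FALSE
```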

SimonCMills, Jun 20 '22 11:06

Just to chime in that by the time you get up to 1e+5 parameters, the readLines approach takes hours at least, while the fread approach remains fast. Tagging @bgoodri and @hsbadr: if the data.table dependency is a problem in rstan, then we will re-implement the fast version of read_stan_csv in brms, where it gets used to generate brmsfit objects when using the cmdstanr backend.

jsocolar, Jun 20 '22 19:06

Is this still under development, or have there been other solutions somewhere that I have missed? This is quite a big issue with large models, where sometimes the actual sampling is faster than returning the object to R.

helske, Apr 04 '24 05:04

No action on the rstan side, but a function that achieves exactly this has been implemented in brms here https://github.com/paul-buerkner/brms/pull/1400

jsocolar, Apr 04 '24 14:04