dbarts
dbarts copied to clipboard
Constant values for non-parametric/BART component predictions with rbart_vi
Thank you for the great work on dbarts!
I've seen this behavior on multiple datasets, and it seems counterintuitive, but I might be overlooking something:
```r
## example from function rbart_vi
f <- function(x) {
10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y <- rnorm(n, Ey, sigma)
n.g <- 10
g <- sample(n.g, length(y), replace = TRUE)
sigma.b <- 1.5
b <- rnorm(n.g, 0, sigma.b)
y <- y + b[g]
df <- as.data.frame(x)
colnames(df) <- paste0("x_", seq_len(ncol(x)))
df$y <- y
df$g <- g
## low numbers to reduce run time (works fine)
set.seed(42)
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g,
n.samples = 40L, n.burn = 10L, n.thin = 2L,
n.chains = 1L,
n.trees = 25L, n.threads = 1L
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "ppd"))
## [,1] [,2] [,3] [,4] [,5]
##[1,] 3.314817 17.32898 20.01549 1.172588 19.14767
##[2,] 9.574205 15.93575 14.94165 5.810090 21.45006
##[3,] 9.927755 14.90057 16.09273 3.274698 19.30742
##[4,] 8.075448 13.32316 17.34530 1.973367 18.65846
##[5,] 9.501832 13.13619 19.36455 3.259061 16.43262
##[6,] 6.951968 16.95147 16.73997 5.225885 20.07958
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "bart"))
## [,1] [,2] [,3] [,4] [,5]
##[1,] 6.045431 16.89998 16.67287 4.972953 20.60489
##[2,] 9.460054 15.46783 15.19634 5.326231 20.54520
##[3,] 8.894971 16.68499 14.30837 6.777138 19.86926
##[4,] 9.668461 17.78030 14.42151 6.032235 20.18210
##[5,] 10.771178 16.67549 16.71494 5.109395 18.24331
##[6,] 7.975217 18.83930 15.40911 5.932053 18.58619
## default rbart_vi settings yield constants for BART-component predictions
set.seed(42)
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g)
head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "ppd"))
## [,1] [,2] [,3] [,4] [,5]
##[1,] 17.51067 11.59959 17.29335 14.27244 17.42347
##[2,] 14.94680 12.06890 15.79795 13.94135 16.37092
##[3,] 16.14532 11.71485 17.15503 13.54932 16.84710
##[4,] 16.27910 11.84871 16.80194 13.02856 16.48098
##[5,] 15.98320 12.77454 16.56062 13.61612 16.37043
##[6,] 15.77031 12.38151 16.85415 14.70622 14.87979
> head(predict(rbartFit, newdata = df[1:5,], group.by = df[1:5,]$g, type = "bart"))
## [,1] [,2] [,3] [,4] [,5]
##[1,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[2,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[3,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[4,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[5,] 14.94328 14.94328 14.94328 14.94328 14.94328
##[6,] 14.94328 14.94328 14.94328 14.94328 14.94328
```