XBART
XBART copied to clipboard
Predictions are all zero, which is incorrect
Hi Jingyu:
With the program below, all posterior samples are zero, including sigma.
library(XBART)
# True regression surface used to simulate the data:
#   f(x) = 10 * sin(pi * x1 * x2) + 5 * x3 * x4^2 + 20 * x5
# `x` is a numeric matrix; only its first five columns are read, and one
# value is returned per row.
f <- function(x) {
  sin_term <- 10 * sin(pi * x[, 1] * x[, 2])
  interaction_term <- 5 * x[, 3] * x[, 4]^2
  linear_term <- 20 * x[, 5]
  sin_term + interaction_term + linear_term
}
# Simulation settings ----
N <- 10000  # number of training observations
sigma <- 1.0  ## y = f(x) + sigma*z where z ~ N(0, 1)
P <- 25  ## number of covariates
B <- 8  # NOTE(review): B is not used anywhere in this script

# Covariate covariance: identity except corr(x5, x6) = 0.8 ----
V <- diag(P)
V[5, 6] <- 0.8
V[6, 5] <- 0.8
L <- chol(V)  # upper-triangular factor, V = t(L) %*% L

set.seed(12)
# Rows of (iid normal matrix) %*% L have covariance V
x.train <- matrix(rnorm(N * P), N, P) %*% L
colnames(x.train) <- paste0("x", 1:P)
y.train <- (f(x.train) + sigma * rnorm(N))

# Test design: all covariates fixed at 0 except x5, which moves ----
# along H equally spaced points in [-3, 3) (right endpoint dropped)
H <- 20
x <- head(seq(-3, 3, length.out = H + 1), -1)
x.test <- matrix(0, nrow = H, ncol = P)
x.test[, 5] <- x
## (L=0.25*(log(N)^(log(log(N)))))

# Fit ----
post <- XBART.CLT(cbind(y.train), x.train, x.test,
                  num_trees = 50, num_sweeps = 40, burnin = 15)

# Posterior summaries of the test-set predictions ----
# NOTE(review): assumes rows of yhats_test index test points and columns
# index sweeps/draws -- confirm against the XBART.CLT documentation.
post$yhat.test <- post$yhats_test
## post$yhats_test=NULL
post$yhat.test.mean <- apply(post$yhat.test, 1, mean)
post$yhat.test.025 <- apply(post$yhat.test, 1, quantile, probs = 0.025)
post$yhat.test.975 <- apply(post$yhat.test, 1, quantile, probs = 0.975)

# Truth (blue) vs. posterior mean (default colour) along the x5 grid ----
plot(x, f(x.test), col = "blue", type = "l", ylab = "f(x)")
lines(x, post$yhat.test.mean)
dev.copy2pdf(file = "bigdata.pdf")
The R output is as follows.
R version 3.5.2 (2018-12-20) -- "Eggshell Igloo"
Copyright (C) 2018 The R Foundation for Statistical Computing
Platform: x86_64-pc-linux-gnu (64-bit)
R is free software and comes with ABSOLUTELY NO WARRANTY.
You are welcome to redistribute it under certain conditions.
Type 'license()' or 'licence()' for distribution details.
Natural language support but running in an English locale
R is a collaborative project with many contributors.
Type 'contributors()' for more information and
'citation()' on how to cite R or R packages in publications.
Type 'demo()' for some demos, 'help()' for on-line help, or
'help.start()' for an HTML browser interface to help.
Type 'q()' to quit R.
> setwd('/home/rsparapa/git/XBART/demo')
options(width=78, length=99999)
> library(XBART)
> f = function(x)
+ 10*sin(pi*x[ , 1]*x[ , 2]) + 5*x[ , 3]*x[ , 4]^2 + 20*x[ , 5]
> N = 10000
> sigma = 1.0 ##y = f(x) + sigma*z where z~N(0, 1)
> P = 25 ##number of covariates
> B=8
> V = diag(P)
> V[5, 6] = 0.8
> V[6, 5] = 0.8
> L <- chol(V)
> set.seed(12)
> x.train=matrix(rnorm(N*P), N, P) %*% L
> dimnames(x.train)[[2]] <- paste0('x', 1:P)
> y.train=(f(x.train)+sigma*rnorm(N))
> H=20
> x=seq(-3, 3, length.out=H+1)[-(H+1)]
> x.test=matrix(0, nrow=H, ncol=P)
> x.test[ , 5]=x
> ##(L=0.25*(log(N)^(log(log(N)))))
>
> post = XBART.CLT(cbind(y.train), x.train, x.test,
+ num_trees=50, num_sweeps=40,
+ burnin=15)
tau = 1/num_trees, default value.
mtry = p, use all variables.
> post$yhat.test=post$yhats_test
> ##post$yhats_test=NULL
> post$yhat.test.mean=apply(post$yhat.test, 1, mean)
> post$yhat.test.025=apply(post$yhat.test, 1, quantile, probs=0.025)
> post$yhat.test.975=apply(post$yhat.test, 1, quantile, probs=0.975)
> plot(x, f(x.test), col='blue', type='l', ylab='f(x)')
> lines(x, post$yhat.test.mean)
> dev.copy2pdf(file='bigdata.pdf')
X11cairo
2
> library(help=XBART)
Information on package ‘XBART’
Description:
Package: XBART
Type: Package
Title: XBART: Accelerated Bayesian Additive Regression
Trees
Version: 0.2
Date: 2019-09-5
Author: Jingyu He, Saar Yalov, P. Richard Hahn, Lee
Reeves
Maintainer: Jingyu He <[email protected]>
Description: A highly efficient prediction algorithm based on
trees.
License: Apache License (== 2.0)
Imports: Rcpp (>= 0.12.13)
LinkingTo: Rcpp, RcppArmadillo
NeedsCompilation: yes
Packaged: 2020-05-20 21:08:35 UTC; rsparapa
Built: R 3.5.2; x86_64-pc-linux-gnu; 2020-05-20 21:08:48
UTC; unix
Index:
XBART XBART: Accelerated Bayesian Additive Regression
Trees
XBART-package XBART: Accelerated Bayesian Additive Regression
Trees
XBART.CLT XBART: Accelerated Bayesian Additive Regression
Trees
XBART.Probit XBART: Accelerated Bayesian Additive Regression
Trees
> sessionInfo()
R version 3.5.2 (2018-12-20)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)
Matrix products: default
BLAS: /usr/lib64/libblas.so.3.4.2
LAPACK: /usr/lib64/liblapack.so.3.4.2
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] XBART_0.2
loaded via a namespace (and not attached):
[1] compiler_3.5.2 tools_3.5.2 Rcpp_1.0.4
>
Are you planning to respond to these at some point? It has been almost 6 months.