tfcausalimpact
Question: Is it possible to extract inclusion probabilities of predictors from a trained model?
Hi, thank you for all your efforts in developing this library.
Previously, we used Google's original R library, CausalImpact, for causal inference. Now we're looking for a good substitute in Python, and this library seems like an excellent option. The R library lets you extract the posterior inclusion probabilities of the predictors from the trained model. We use these probabilities as additional descriptive statistics to fine-tune the model and exclude certain predictors from the control group.
Is there a way to extract these posterior inclusion probabilities in tfcausalimpact? Your guidance on this would be immensely helpful.
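The workflow described above (dropping weakly supported predictors from the control group before refitting) can be sketched in Python with pandas. This is a minimal sketch: the column names, the `inclusion_probs` values, the helper `drop_weak_predictors`, and the 0.1 cutoff are all hypothetical placeholders, not output or API of either library.

```python
import pandas as pd

# Hypothetical data: the response `y` first, followed by candidate control series.
data = pd.DataFrame({
    "y": [10.0, 11.0, 12.0, 13.0],
    "x1": [1.0, 2.0, 3.0, 4.0],
    "x2": [0.5, 0.4, 0.6, 0.5],
    "x3": [7.0, 7.5, 8.0, 8.5],
})

# Hypothetical per-predictor inclusion probabilities (e.g. taken from a fitted model).
inclusion_probs = pd.Series({"x1": 0.95, "x2": 0.04, "x3": 0.60})


def drop_weak_predictors(df: pd.DataFrame, probs: pd.Series, cutoff: float = 0.1) -> pd.DataFrame:
    """Return a copy of df without the control columns whose inclusion probability is below cutoff."""
    weak = probs.index[probs < cutoff]
    return df.drop(columns=list(weak))


reduced = drop_weak_predictors(data, inclusion_probs)
print(list(reduced.columns))  # ['y', 'x1', 'x3']
```

In tfcausalimpact the first column is treated as the response, so a reduced frame like this could then be refit with `CausalImpact(reduced, pre_period, post_period)`.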
Here is some R code to show exactly what I mean:
library(CausalImpact)
set.seed(1)
x1 <- 100 + arima.sim(model = list(ar = 0.88), n = 100)
x2 <- arima.sim(model = list(ar = 0.22), n = 100) + rnorm(100)
x3 <- rnorm(100) + rnorm(100, mean = 80, sd = 10)
y <- 1.2 * (x1+x2+x3)/3 + rnorm(100)
y[71:100] <- y[71:100] + 60
data <- cbind(y, x1, x2, x3)
pre.period <- c(1, 70)
post.period <- c(71, 100)
impact <- CausalImpact(data, pre.period, post.period)
# Bar plot of each predictor's posterior inclusion probability from the underlying bsts model
plot(impact$model$bsts.model, "coefficients")
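For reference, the R CausalImpact package fits its model with bsts, which places a spike-and-slab prior on the regression coefficients: in each MCMC draw a predictor is either included (nonzero coefficient) or its coefficient is exactly zero. The marginal inclusion probability plotted above is the share of draws in which the coefficient is nonzero. The sketch below reproduces that calculation with NumPy on a small, made-up matrix of coefficient draws; `coef_draws` is hypothetical (shape: draws × predictors) and no library model is called.

```python
import numpy as np

# Hypothetical posterior draws of regression coefficients:
# one row per MCMC draw, one column per predictor (x1, x2, x3).
coef_draws = np.array([
    [1.1, 0.0, 0.3],
    [0.9, 0.0, 0.0],
    [1.0, 0.2, 0.4],
    [1.2, 0.0, 0.0],
])

# Marginal inclusion probability: share of draws where the coefficient is nonzero.
inclusion_probs = (coef_draws != 0).mean(axis=0)
print(inclusion_probs.tolist())  # [1.0, 0.25, 0.5]
```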
Hi @hnasko,
It is possible to extract the posterior of the weights. Please refer to section 2.5 of the getting_started notebook, where this is discussed.
If you used the default Hamiltonian Monte Carlo (HMC) fitting method, this code should retrieve the posterior samples of the weights, averaged over all draws:
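import tensorflow as tf

# Sparse regression component of the fitted structural time-series model
sparse_reg = ci.model.components_by_name['SparseLinearRegression/']
# Convert the sampled (non-centered) parameters back to regression weights,
# then average over the posterior samples (axis 0): one mean weight per predictor
tf.reduce_mean(sparse_reg.params_to_weights(
    ci.model_samples['SparseLinearRegression/_global_scale_variance'],
    ci.model_samples['SparseLinearRegression/_global_scale_noncentered'],
    ci.model_samples['SparseLinearRegression/_local_scale_variances'],
    ci.model_samples['SparseLinearRegression/_local_scales_noncentered'],
    ci.model_samples['SparseLinearRegression/_weights_noncentered'],
), axis=0)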
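Note that this yields posterior mean weights, not inclusion probabilities. `tfp.sts.SparseLinearRegression` uses a horseshoe prior, which shrinks weak predictors toward zero without setting any draw exactly to zero, so the bsts-style share of nonzero draws would be 1 for every predictor. If a comparable descriptive statistic is wanted, one heuristic (an assumption, not a tfcausalimpact feature and not equivalent to a spike-and-slab inclusion probability) is the share of posterior draws whose weight magnitude exceeds a chosen threshold. In the sketch below, `weight_draws` is a synthetic stand-in for the un-averaged output of `params_to_weights(...)` (shape: samples × predictors), `magnitude_share` is a hypothetical helper, and the 0.1 threshold is arbitrary; a sensible value depends on the scale of the (possibly standardized) data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for posterior weight draws, shape (num_samples, num_predictors).
# Here x1 has a clear effect while x2 and x3 are shrunk toward zero.
weight_draws = np.column_stack([
    rng.normal(1.0, 0.1, size=1000),   # x1: strong, well-identified weight
    rng.normal(0.0, 0.02, size=1000),  # x2: tightly shrunk close to zero
    rng.normal(0.0, 0.05, size=1000),  # x3: shrunk, but with a wider spread
])


def magnitude_share(draws: np.ndarray, threshold: float = 0.1) -> np.ndarray:
    """Heuristic: fraction of draws per predictor with |weight| > threshold."""
    return (np.abs(draws) > threshold).mean(axis=0)


print(magnitude_share(weight_draws))
```

With real samples, `weight_draws` could be obtained by calling `params_to_weights(...)` with the same arguments as in the snippet above, skipping `tf.reduce_mean`, and converting the resulting tensor with `.numpy()`.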
Let me know if this solves it for you.