MCMCChains.jl
MCMCChains.jl copied to clipboard
Prior/posterior predictive check plots
"ppcplot" function was added for plotting prior/posterior predictive checks for one or more dependent variables. As args
this function receives `yobs_data`, the observed data for the dependent variables (a vector or matrix), and `ypred_data`, the posterior/prior predictive results (a `Chains` object). It plots the observed data, a sample of predictions, and the mean of the predictions.
As kwargs, this function receives:
- `yvar_name` (vector of `Symbol`), which contains the names of the dependent variables to be plotted,
- `plot_type`, which can take `:density`, `:cumulative`, and `:histogram` as values,
- `predictive_check`, which is used for plot titles and can be `:prior` or `:posterior` (default value is `:posterior`),
- `n_samples`, which sets the number of samples to be plotted (default value is 50, but when plotting it is redefined as the minimum between 50 and the sample size in `ypred_data`).
For more than one dependent variable in a single model, `yvar_name` must be provided, and the order in which the variable names appear must be the same as in the observed data matrix. This was done in order to separate predictions for every dependent variable, because `predict` does not return predictions ordered by variable.
The following is a working example for a model with one dependent variable
using Turing, StatsBase, Statistics, MCMCChains, StatsPlots
# Linear regression through the origin with a single response `y`.
# The slope β gets a Normal(1, 0.5) prior; each y[k] is Normal(β * x[k], σ).
@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(1, 0.5)

    for k in eachindex(y)
        y[k] ~ Normal(β * x[k], σ)
    end
end;
# Synthetic data: f is a line with slope 2 plus Gaussian noise scaled by 0.1.
σ = 0.1; f(x) = 2 * x + 0.1 * randn();
# Training inputs cover 0:Δ:10; the 100 test inputs start just above 10.
Δ = 0.01; xs_train = 0:Δ:10; ys_train = f.(xs_train);
xs_test = [10 + i*Δ for i in 1:100]; ys_test = f.(xs_test);
m_train = linear_reg(xs_train, ys_train, σ);
# Prior predictive check: request 200 draws using the Prior() sampler.
chain_lin_reg = sample(m_train, Prior(), 200);
# Test model: the response is a placeholder vector with element type Union{Missing, Float64}.
# NOTE(review): presumably the undef-initialised entries are read as missing, so that
# predict generates y instead of conditioning on it — confirm.
m_test_prior = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
predictions_prior = predict(m_test_prior, chain_lin_reg)
# NOTE(review): yvar_name is [:y_var] although the model variable is named y — confirm this is intended.
ppcplot(ys_test, predictions_prior, yvar_name = [:y_var], predictive_check = :prior, plot_type = :density )
And for the posterior predictive check:
# Posterior predictive check: fit the training model with NUTS and request 200 draws.
chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
# Test model: same test inputs, response given as a Union{Missing, Float64} placeholder vector.
m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
predictions_posterior = predict(m_test, chain_lin_reg)
# Called with the observed and predicted data only, so every keyword argument takes its default.
ppcplot(ys_test, predictions_posterior)
Plot_type = :density
Plot_type = :cumulative
# Same observed data and posterior predictions, with plot_type = :cumulative and 20 sampled predictions.
ppcplot(ys_test, predictions_posterior, n_samples = 20, predictive_check = :posterior, plot_type = :cumulative, size = (900, 600))
Plot_type = :histogram
Additionally, this is a working example for a model with two dependent variables:
# Two-response linear regression through the origin, sharing the inputs `x`.
# Slopes β and γ each get a Normal(0, 1) prior; y[k] is Normal(β * x[k], σ)
# and z[k] is Normal(γ * x[k], σ).
@model function linear_reg(x, y, z, σ = 0.1)
    β ~ Normal(0, 1)
    γ ~ Normal(0, 1)

    for k in eachindex(y)
        y[k] ~ Normal(β * x[k], σ)
        z[k] ~ Normal(γ * x[k], σ)
    end
end;
# Synthetic data for two responses: f has slope 2 (noise scaled by 0.1), g has slope 4 (noise scaled by 0.4).
σ = 0.1; f(x) = 2 * x + 0.1 * randn(); g(x) = 4 * x + 0.4 * randn();
# Training inputs cover 0:Δ:10; the 100 test inputs start just above 10.
Δ = 0.01; xs_train = 0:Δ:10; ys_train = f.(xs_train); zs_train = g.(xs_train);
xs_test = [10 + i*Δ for i in 1:100]; ys_test = f.(xs_test); zs_test = g.(xs_test);
m_train = linear_reg(xs_train, ys_train, zs_train, σ);
# Fit the training model with NUTS and request 200 draws.
chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
# Test model: both responses are placeholder vectors with element type Union{Missing, Float64}.
m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), Vector{Union{Missing, Float64}}(undef, length(zs_test)), σ);
predictions = predict(m_test, chain_lin_reg)
# Observed test data as a matrix: column 1 is y, column 2 is z, matching the order used in yvar_name below.
var_test = hcat(ys_test, zs_test)
# One call per plot_type; each passes yvar_name so predictions can be separated per variable.
ppcplot(var_test, predictions, n_samples = 100, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :density, size = (900, 400))
ppcplot(var_test, predictions, n_samples = 30, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :cumulative, size = (900, 400))
# The variable names can also be bound once and passed by name.
var_name = [:y, :z]
ppcplot(var_test, predictions, yvar_name = var_name, n_samples = 10, predictive_check = :posterior, plot_type = :histogram, size = (900, 600))
For this PR, should the version be 4.16.0 (after #316 ) or 5.1.0 (after #310 )?
Probably 4.16.0 since #310 is a bigger thing and probably won't have too much effect here.