MCMCChains.jl

Prior/posterior predictive check plots

Open PaulinaMartin96 opened this issue 2 years ago • 2 comments

The `ppcplot` function was added for plotting prior/posterior predictive checks for one or more dependent variables. Its positional arguments are `yobs_data`, the observed data for the dependent variables (a vector or matrix), and `ypred_data`, the prior/posterior predictive results (a `Chains` object). It plots the observed data, a sample of the predictions, and the mean of the predictions.
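To make "a sample of predictions and the predictions mean" concrete, here is a minimal Base-only sketch. It assumes the predictions have been laid out as a draws × observations matrix (one row per draw); `ppc_summary` is a hypothetical helper, not part of MCMCChains or of this PR.

```julia
using Random, Statistics

# Hypothetical helper (not part of MCMCChains or the PR): given predictions laid
# out as a draws × observations matrix, pick `k` distinct draws to overlay and
# compute the mean prediction for every observation.
function ppc_summary(ypred::AbstractMatrix{<:Real}, k::Integer;
                     rng::AbstractRNG = Random.default_rng())
    ndraws = size(ypred, 1)
    1 <= k <= ndraws || throw(ArgumentError("k must be between 1 and the number of draws"))
    idx = randperm(rng, ndraws)[1:k]          # random subset of draws, without replacement
    pred_mean = vec(mean(ypred; dims = 1))    # average over draws (rows)
    return idx, pred_mean
end

ypred = [1.0 2.0 3.0;
         3.0 4.0 5.0]                         # 2 draws × 3 observations
idx, pred_mean = ppc_summary(ypred, 2; rng = MersenneTwister(1))
```

Here `pred_mean` is `[2.0, 3.0, 4.0]`; the overlaid curves would come from `ypred[idx, :]` and the mean would be drawn as one extra curve.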

It also accepts the following keyword arguments:

  • `yvar_name` (a vector of `Symbol`s): the names of the dependent variables to be plotted;
  • `plot_type`: one of `:density`, `:cumulative`, or `:histogram`;
  • `predictive_check`: used in the plot titles; either `:prior` or `:posterior` (default `:posterior`);
  • `n_samples`: the number of predictive samples to plot (default 50; when plotting, it is reset to the minimum of 50 and the number of samples in `ypred_data`).
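The option handling described in the list above can be sketched as follows. This is an illustration only: `resolve_ppc_options`, the error messages, the title strings, and the default `plot_type = :density` are assumptions, not the PR's code. The cap is applied to the requested `n_samples`, which with the default of 50 gives the minimum of 50 and the number of draws.

```julia
# Hypothetical option handling for the keyword arguments listed above. The
# default plot_type = :density and the title strings are assumptions.
const PLOT_TYPES = (:density, :cumulative, :histogram)
const CHECK_TYPES = (:prior, :posterior)

function resolve_ppc_options(ndraws::Integer; plot_type::Symbol = :density,
                             predictive_check::Symbol = :posterior,
                             n_samples::Integer = 50)
    plot_type in PLOT_TYPES ||
        throw(ArgumentError("plot_type must be one of $(PLOT_TYPES)"))
    predictive_check in CHECK_TYPES ||
        throw(ArgumentError("predictive_check must be :prior or :posterior"))
    k = min(n_samples, ndraws)   # never overlay more draws than the chain holds
    title = predictive_check === :prior ? "Prior predictive check" : "Posterior predictive check"
    return (plot_type = plot_type, k = k, title = title)
end

opts = resolve_ppc_options(20)   # only 20 draws available, so k is capped at 20
```

With 500 draws and the defaults, `k` would be 50; an unsupported `plot_type` such as `:violin` raises an `ArgumentError`.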

For a model with more than one dependent variable, `yvar_name` must be provided, and the names must appear in the same order as the corresponding columns of the observed-data matrix. This is needed to separate the predictions for each dependent variable, because `predict` does not return the predictions ordered by variable.
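One way to separate and order the predictions is sketched below, assuming the prediction chain's parameter names are strings of the form `"y[1]"`, `"z[2]"`, and so on (assumed to come from something like `string.(names(predictions))`). `columns_by_variable` is a hypothetical helper; it groups the names by variable and sorts each group by observation index, which is not necessarily how the PR does it.

```julia
# Hypothetical helper (not the PR's code): map each requested variable to the
# positions of its predictions, sorted by observation index. Assumes names of
# the form "var[i]" and plain identifiers (no regex metacharacters) in yvar_name.
function columns_by_variable(param_names::AbstractVector{<:AbstractString},
                             yvar_name::AbstractVector{Symbol})
    cols = Dict{Symbol,Vector{Int}}()
    for var in yvar_name
        # Anchored pattern, so :y matches "y[3]" but not "yy[3]"
        pat = Regex("^" * string(var) * raw"\[(\d+)\]$")
        hits = Tuple{Int,Int}[]                  # (observation index, position)
        for (pos, nm) in enumerate(param_names)
            m = match(pat, nm)
            m === nothing || push!(hits, (parse(Int, m.captures[1]), pos))
        end
        sort!(hits)                              # order by observation index
        cols[var] = last.(hits)
    end
    return cols
end

pnames = ["y[10]", "z[2]", "y[1]", "z[1]", "y[2]"]   # deliberately unordered
cols = columns_by_variable(pnames, [:y, :z])
```

Here `cols[:y]` is `[3, 5, 1]` (the positions of `y[1]`, `y[2]`, `y[10]`) and `cols[:z]` is `[4, 2]`, so indexing the draws × parameters array with these positions yields each variable's predictions in observation order.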

The following is a working example for a model with one dependent variable:

```julia
using Turing, StatsBase, Statistics, MCMCChains, StatsPlots

# Linear regression through the origin with known noise σ
@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(1, 0.5)

    for i ∈ eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end;

σ = 0.1
f(x) = 2 * x + 0.1 * randn()
Δ = 0.01
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + i * Δ for i in 1:100]
ys_test = f.(xs_test)
m_train = linear_reg(xs_train, ys_train, σ);

# Prior predictive check
chain_lin_reg = sample(m_train, Prior(), 200);
# Passing a vector of `missing` turns y into a quantity to be predicted
m_test_prior = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
predictions_prior = predict(m_test_prior, chain_lin_reg)
ppcplot(ys_test, predictions_prior, yvar_name = [:y_var], predictive_check = :prior, plot_type = :density)
```

[Figure: prior predictive check, plot_type = :density]

And for the posterior predictive check:

```julia
# Posterior predictive check
# NUTS with 100 adaptation steps and a target acceptance rate of 0.65
chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
predictions_posterior = predict(m_test, chain_lin_reg)
ppcplot(ys_test, predictions_posterior)
```

[Figure: posterior predictive check, plot_type = :density]
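For this particular model, both `predict` calls above amount to simulating the likelihood once per sampled β: the prior check uses β drawn from Normal(1, 0.5), the posterior check uses the β draws from the NUTS chain. The sketch below reproduces that by hand with Base and Random only; `simulate_predictive` is a hypothetical name, and this is not how `predict` is implemented internally.

```julia
using Random

# Hypothetical, model-specific sketch (not how `predict` is implemented): for
# each draw of β, simulate yᵢ = β * xᵢ + σ * εᵢ with εᵢ ~ Normal(0, 1), i.e.
# yᵢ ~ Normal(β * xᵢ, σ). The result has one row per draw and one column per xᵢ.
function simulate_predictive(βs::AbstractVector{<:Real}, xs::AbstractVector{<:Real},
                             σ::Real; rng::AbstractRNG = Random.default_rng())
    return [β * x + σ * randn(rng) for β in βs, x in xs]
end

rng = MersenneTwister(42)
Δ, σ = 0.01, 0.1
xs_test = [10 + i * Δ for i in 1:100]
# Prior check: β drawn from the model's prior Normal(1, 0.5)
βs_prior = 1 .+ 0.5 .* randn(rng, 200)
ypred_prior = simulate_predictive(βs_prior, xs_test, σ; rng = rng)
```

`ypred_prior` is a 200 × 100 matrix whose rows play the role of the predictive draws compared against `ys_test`; swapping in posterior β draws gives the posterior version.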

With `plot_type = :cumulative`:

```julia
ppcplot(ys_test, predictions_posterior, n_samples = 20, predictive_check = :posterior, plot_type = :cumulative, size = (900, 600))
```

[Figure: posterior predictive check, plot_type = :cumulative]
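The `:cumulative` plot compares cumulative distributions of the observed data and of the predictive draws. A minimal Base-only empirical CDF is sketched below; `empirical_cdf` is a hypothetical name, and whether the PR computes its curves this way is an assumption (StatsBase's `ecdf` is the library equivalent).

```julia
# Hypothetical sketch of the curve behind plot_type = :cumulative: the
# empirical CDF of a sample, evaluated on a grid of points.
function empirical_cdf(sample::AbstractVector{<:Real}, grid::AbstractVector{<:Real})
    s = sort(sample)
    n = length(s)
    # searchsortedlast returns how many sorted values are ≤ t
    return [searchsortedlast(s, t) / n for t in grid]
end

yobs = [0.3, 0.1, 0.4, 0.2]
F = empirical_cdf(yobs, [0.0, 0.2, 0.35, 1.0])
```

Here `F` is `[0.0, 0.5, 0.75, 1.0]`. Applying the same function to each selected predictive draw (`ypred[d, :]`) on a shared grid gives one curve per draw to overlay on the observed curve.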

[Figure: posterior predictive check, plot_type = :histogram]
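For a histogram comparison to be meaningful, observed and predicted values should be counted on the same bin edges. The sketch below shows that idea with Base only; `shared_edges` and `bin_counts` are hypothetical helpers, and the PR's actual binning is not shown here (StatsBase's `fit(Histogram, data, edges)` is the library route).

```julia
# Hypothetical sketch of shared binning for plot_type = :histogram: observed
# and predicted values are counted on the same edges so bars are comparable.
function shared_edges(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}, nbins::Integer)
    lo = min(minimum(x), minimum(y))
    hi = max(maximum(x), maximum(y))
    return range(lo, hi; length = nbins + 1)
end

function bin_counts(v::AbstractVector{<:Real}, edges::AbstractVector{<:Real})
    counts = zeros(Int, length(edges) - 1)
    for val in v
        # Bins are [e_k, e_{k+1}); the top edge is folded into the last bin
        k = min(searchsortedlast(edges, val), length(counts))
        k >= 1 && (counts[k] += 1)
    end
    return counts
end

yobs  = [0.0, 0.5, 1.0, 1.5]
ypred = [0.2, 0.4, 1.9, 2.0]
edges = shared_edges(yobs, ypred, 4)      # 0.0, 0.5, 1.0, 1.5, 2.0
```

With these edges, `bin_counts(yobs, edges)` is `[1, 1, 1, 1]` and `bin_counts(ypred, edges)` is `[2, 0, 0, 2]`.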

Additionally, here is a working example for a model with two dependent variables:

```julia
# Two dependent variables, y and z, sharing the same input x
@model function linear_reg(x, y, z, σ = 0.1)
    β ~ Normal(0, 1)
    γ ~ Normal(0, 1)

    for i ∈ eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
        z[i] ~ Normal(γ * x[i], σ)
    end
end;

σ = 0.1
f(x) = 2 * x + 0.1 * randn()
g(x) = 4 * x + 0.4 * randn()
Δ = 0.01
xs_train = 0:Δ:10
ys_train = f.(xs_train)
zs_train = g.(xs_train)
xs_test = [10 + i * Δ for i in 1:100]
ys_test = f.(xs_test)
zs_test = g.(xs_test)
m_train = linear_reg(xs_train, ys_train, zs_train, σ);

chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);

m_test = linear_reg(xs_test,
                    Vector{Union{Missing, Float64}}(undef, length(ys_test)),
                    Vector{Union{Missing, Float64}}(undef, length(zs_test)), σ);

predictions = predict(m_test, chain_lin_reg)

# Columns must follow the order of yvar_name: y first, then z
var_test = hcat(ys_test, zs_test)
ppcplot(var_test, predictions, n_samples = 100, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :density, size = (900, 400))
```

[Figure: posterior predictive check for y and z, plot_type = :density]
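In the two-variable example, `hcat(ys_test, zs_test)` places y in column 1 and z in column 2, and `yvar_name = [:y, :z]` has to list them in that order. The positional pairing this relies on can be sketched as below; `observed_by_variable` is a hypothetical helper, not part of the PR.

```julia
# Hypothetical helper: pair each column of the observed matrix with its name,
# relying on yvar_name listing the variables in column order.
function observed_by_variable(yobs::AbstractMatrix{<:Real},
                              yvar_name::AbstractVector{Symbol})
    size(yobs, 2) == length(yvar_name) ||
        throw(DimensionMismatch("expected one name per column of the observed data"))
    return Dict(var => yobs[:, j] for (j, var) in enumerate(yvar_name))
end

ys = [1.0, 2.0, 3.0]
zs = [4.0, 8.0, 12.0]
obs = observed_by_variable(hcat(ys, zs), [:y, :z])   # column 1 → :y, column 2 → :z
```

Passing the names in the wrong order would silently pair y's observations with z's predictions, which is why the order matters; a missing name is caught by the size check.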

```julia
ppcplot(var_test, predictions, n_samples = 30, yvar_name = [:y, :z], predictive_check = :posterior, plot_type = :cumulative, size = (900, 400))
```

[Figure: posterior predictive check for y and z, plot_type = :cumulative]

```julia
var_name = [:y, :z]
ppcplot(var_test, predictions, yvar_name = var_name, n_samples = 10, predictive_check = :posterior, plot_type = :histogram, size = (900, 600))
```

[Figure: posterior predictive check for y and z, plot_type = :histogram]

PaulinaMartin96, Jul 19 '21 06:07

For this PR, should the version be 4.16.0 (after #316) or 5.1.0 (after #310)?

PaulinaMartin96, Jul 23 '21 17:07

Probably 4.16.0 since #310 is a bigger thing and probably won't have too much effect here.

cpfiffer, Jul 24 '21 17:07