StatsPlots.jl

Alternative Corner Plot?

Status: open · farr opened this issue 3 years ago · 4 comments

I'm afraid I really don't like the format of `cornerplot` when plotting the outputs of MCMC simulations. I usually don't care much about the correlation coefficient of my samples (and I certainly don't want the scatter plots colored by correlation coefficient). Ideally, I would like to show both an estimate of the 2D density in the off-diagonal panels and the samples themselves (as a sanity check on the density estimate).

The code below implements something like what I prefer. Is there a way to slip this recipe into StatsPlots.jl, perhaps with a better name than `alternativecornerplot` ;)? If others are interested, I'd be happy to file a formal pull request. My personal preference would be simply to replace the existing `cornerplot` recipe, but of course I understand that there may be reluctance to make such a dramatic change. Any suggestions for good names?

The code:

using KernelDensity
using Plots
using StatsPlots

# Return `nl` contour thresholds for a 2D KDE density grid, in ascending order.
#
# Grid cells are ranked from densest to least dense. For each l in 1:nl, the threshold
# is the density of the first ranked cell at which the cumulative mass reaches the
# fraction l/(nl+1) of the total. The region above that threshold therefore encloses
# roughly that fraction of the probability mass.
function _hpd_levels(density, nl)
    dv = vec(density)
    order = sortperm(dv; rev = true)    # cell indices, densest first
    cmass = cumsum(dv[order])
    total = cmass[end]
    # Clamp defensively in case `cmass` is not monotone (e.g. if any density value is
    # negative), so the lookup never indexes past the end.
    lvls = [dv[order[min(searchsortedfirst(cmass, (l / (nl + 1)) * total), length(cmass))]]
            for l in 1:nl]
    return reverse!(lvls)               # ascending order for the contour series
end

# Alternative corner plot, intended for MCMC output.
#
# `cornerplot(m)` takes `m`, an indexable collection of N sample vectors (one per
# parameter; `m[i]` is the i-th dataset). It draws an N×N grid:
#   * diagonal panels:       a 1D `:density` plot of each dataset;
#   * lower-triangle panels: contours of a 2D kernel density estimate (`kde`) of
#                            (m[j], m[i]), drawn at mass-enclosing levels;
#   * upper-triangle panels: a scatter plot of the raw samples (m[j], m[i]), as a
#                            sanity check on the density estimate.
#
# Attributes consumed by the recipe:
#   * `levels` (default 10): number of contour levels per 2D panel.
#   * `label` (default ["x1", …, "xN"]): one axis label per dataset, or [""] for none.
#     Any other length raises an error.
#
# NOTE(review): this defines `CornerPlot`/`cornerplot` in the caller's module. It
# presumably shadows the identically named recipe from StatsPlots; confirm.
@userplot CornerPlot
@recipe function f(cp::CornerPlot)
    m = cp.args[1]
    N = length(m)

    # `levels` is a *count* here. Pop it so the integer is not forwarded to the density
    # and scatter series; the contour series receive explicit level values below.
    nl = pop!(plotattributes, :levels, 10)

    labs = pop!(plotattributes, :label, ["x$i" for i in 1:N])
    if labs == [""]
        # A single empty label means "no labels": expand it so every panel can index it.
        labs = fill("", N)
    elseif length(labs) != N
        error("Number of labels not identical to number of datasets")
    end

    # (min, max) of each dataset, shared by every panel that uses it as an axis.
    lims = map(extrema, m)

    legend := false
    layout := (N, N)

    # Diagonal panels: 1D density of each dataset.
    for i in 1:N
        @series begin
            subplot := i + (i - 1) * N
            seriestype := :density
            xlims := lims[i]
            ylims := (0, Inf)
            xguide := labs[i]
            m[i]
        end
    end

    # Lower-triangle panels (row i, column j < i): 2D KDE contours of (m[j], m[i]).
    for i in 1:N
        for j in 1:(i - 1)
            k = kde((m[j], m[i]))
            lvls = _hpd_levels(k.density, nl)
            @series begin
                seriestype := :contour
                subplot := (i - 1) * N + j
                seriescolor --> :viridis
                levels := lvls
                xlims := lims[j]
                ylims := lims[i]
                xguide := labs[j]
                yguide := labs[i]
                # Assumes k.density is indexed [ix, iy] while contour wants z[iy, ix],
                # hence permutedims. TODO confirm against the KernelDensity docs.
                k.x, k.y, permutedims(k.density)
            end
        end
    end

    # Upper-triangle panels (row i, column j > i): raw samples of (m[j], m[i]).
    for i in 1:N
        for j in (i + 1):N
            @series begin
                seriestype := :scatter
                subplot := (i - 1) * N + j
                markersize --> 0.1
                xlims := lims[j]
                ylims := lims[i]
                xguide := labs[j]
                yguide := labs[i]
                m[j], m[i]
            end
        end
    end
end

Usage: obtain an MCMC chain (here a table `trace` with columns `a`, `b`, and `c`), and then

@df trace cornerplot([:a, :b, :c], label=[L"a", L"b", L"c"], size=(1000, 1000))

will produce

[image: the resulting corner plot]

— farr, Oct 31 '20 03:10