StatsPlots.jl

Alternative Corner Plot?

Open farr opened this issue 4 years ago • 4 comments

I'm afraid I really don't like the format of cornerplot when plotting the outputs of MCMC simulations. Normally I don't care so much about the correlation coefficient of my samples (and I certainly don't want to color the scatterplots by correlation coefficient); ideally, I would like to show both some estimate of the 2D density in the off-diagonal grid squares and also show the samples (as a sanity check of the density estimate).

The code below implements something like what I prefer---is there a way to slip this recipe into StatsPlots.jl, perhaps with a name that is better than alternativecornerplot ;)? If others are interested, I'd be happy to file a formal pull request---my personal preference would be just to replace the existing cornerplot recipe, but of course I understand that there may be reluctance to make such a dramatic change. Any suggestions for good names?

The code:

using KernelDensity
using Plots
using StatsPlots
using LaTeXStrings  # only needed for the L"..." labels in the usage example

@userplot CornerPlot
@recipe function f(cp::CornerPlot)
    m = cp.args[1]
        
    nl = get(plotattributes, :levels, 10)
    N = size(m, 1)
    
    labs = pop!(plotattributes, :label, ["x$i" for i=1:N])
    if labs!=[""] && length(labs)!=N
        error("Number of labels not identical to number of datasets")
    end
        
    legend := false
    layout := (N,N)
    
    for i in 1:N
        # Do the diagonals
        @series begin
            subplot := i + (i-1)*N
            seriestype := :density
            xlims := (minimum(m[i]), maximum(m[i]))
            ylims := (0, Inf)
            xguide := labs[i]
            x := m[i]            
        end
    end
    
    for i in 1:N
        for j in 1:(i-1)
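            # Do the kdeplots
            k = kde((m[j], m[i]))
            dv = vec(k.density)
            # Order the grid cells from highest to lowest density
            inds = sortperm(dv, rev=true)
            cd = cumsum(dv[inds])
            C = cd[end]
            
            # Contour levels: the density at which the enclosed mass
            # first reaches each fraction l/(nl+1) of the total
            levels = Float64[]
            for l in 1:nl
                f = l/(nl+1)
                cf = f*C
                ind = searchsortedfirst(cd, cf)
                push!(levels, dv[inds[ind]])
            end
            # Levels were collected from high to low density; make them ascending
            levels = reverse(levels)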

            @series begin
                seriestype := :contour
                subplot := (i-1)*N + j
                seriescolor --> :viridis
                x := k.x
                y := k.y
                z := permutedims(k.density)
                levels := levels
                xlims := (minimum(m[j]), maximum(m[j]))
                ylims := (minimum(m[i]), maximum(m[i]))
                xguide := labs[j]
                yguide := labs[i]
                k.x, k.y, permutedims(k.density)
            end
        end
    end
    
    for i in 1:N
        for j in (i+1):N
            # Do the scatterplots
            @series begin
                seriestype := :scatter
                subplot := (i-1)*N + j
                x := m[j]
                y := m[i]
                markersize --> 0.1
                xlims := (minimum(m[j]), maximum(m[j]))
                ylims := (minimum(m[i]), maximum(m[i]))
                xguide := labs[j]
                yguide := labs[i]
                m[j], m[i]
            end
        end
    end
end
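
The contour levels in the recipe are chosen so that each contour encloses a fixed fraction of the estimated probability mass. As a rough, dependency-free sketch of that one step (the helper name `mass_levels` is hypothetical and is not part of StatsPlots.jl or KernelDensity.jl), the same logic can be tried on a toy grid:

```julia
# Hypothetical helper, for illustration only: given a grid of non-negative
# density values, return for each requested fraction the density threshold
# whose super-level set is the smallest one holding at least that fraction
# of the total mass.
function mass_levels(density::AbstractArray{<:Real}, fractions)
    dv = sort(vec(density), rev=true)   # highest density first
    cd = cumsum(dv)                     # mass enclosed as the threshold drops
    total = cd[end]
    levels = [dv[searchsortedfirst(cd, f * total)] for f in fractions]
    return sort(levels)                 # ascending, as in the recipe
end

# Toy 2x2 grid whose cells hold 40%, 30%, 20% and 10% of the mass.
grid = [4 3; 2 1]
levels = mass_levels(grid, [0.5, 0.85])
```

In the recipe the fractions are `l/(nl+1)` for `l in 1:nl`, so passing `(1:nl) ./ (nl + 1)` as `fractions` should reproduce the same spacing.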

Usage: obtain an MCMC chain, and then

@df trace cornerplot([:a, :b, :c], label=[L"a", L"b", L"c"], size=(1000, 1000))

will produce

[Figure: the resulting corner plot]

farr, Oct 31 '20 03:10

I am coming from seaborn and am also accustomed to using these sorts of plots (via PairGrid).

PythonNut, Nov 21 '20 00:11

That's very neat, thanks! With the existing corner plot I'm also irritated by the line and the coloring.

Maybe a single corner recipe with more options, instead of multiple alternative recipes, is the better way to go.

A 2D histogram could be an alternative to the density estimate, but that's only a personal preference.

scheidan, Dec 04 '20 14:12

One of the nice things about seaborn's PairGrid is that you can customize the upper triangle, diagonal, and lower triangle plots to be whatever you like (e.g. histogram, scatter, kde, etc.).

PythonNut, Dec 05 '20 02:12

Both ArviZPlots.jl and CornerPlot.jl do something similar to this. It probably would be good to offer this alternative layout using keyword arguments.

sethaxen, Jul 08 '21 00:07
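
The keyword-argument idea raised in the comments (one recipe, with the diagonal, lower-triangle, and upper-triangle panels each configurable) could look roughly like the sketch below. It only covers the bookkeeping: which panel gets which series type. All names here (`panel_role`, `subplot_index`, `panel_seriestypes`, and the `diag`/`lower`/`upper` keywords) are hypothetical and do not exist in StatsPlots.jl; the defaults mirror the recipe in the original post.

```julia
# Hypothetical sketch: none of these names exist in StatsPlots.jl.

# Classify panel (row i, column j) of an N×N corner grid.
panel_role(i, j) = i == j ? :diag : (i > j ? :lower : :upper)

# Row-major subplot index, matching `subplot := (i-1)*N + j` in the recipe.
subplot_index(i, j, N) = (i - 1) * N + j

# Map every subplot index to a series type chosen through keyword arguments.
function panel_seriestypes(N; diag=:density, lower=:contour, upper=:scatter)
    choice = (diag=diag, lower=lower, upper=upper)
    return Dict(subplot_index(i, j, N) => choice[panel_role(i, j)]
                for i in 1:N, j in 1:N)
end

# Example: swap the lower-triangle contours for 2D histograms.
types = panel_seriestypes(3; lower=:histogram2d)
```

Inside a recipe, a table like `types` could drive the `seriestype` of each `@series` block, so a 2D histogram or any other series type could be swapped in without writing a separate recipe.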