Turing.jl

Wishart priors resulting in `PosDefException: matrix is not positive definite; Cholesky factorization failed`

Open · ipozdeev opened this issue · 7 comments

Julia v1.10.2, Turing v0.30.7, Distributions v0.25.107

Using a Wishart prior makes several functions throw the above error, which does not make sense to me, since the Wishart scale matrix is positive definite. For example, computing the MAP estimate in this MWE fails:

using Turing, MCMCChains
using Statistics, LinearAlgebra, PDMats
using Optim

# parameter of the Wishart prior
A = Matrix{Float64}(I, 3, 3);
isposdef(A)  # true
ishermitian(A)  # true

@model function demo(x)
    _A ~ Wishart(5, A);
    _x_mu = sum(_A);
    return x ~ Normal(_x_mu, 1);
end

# condition model on single obs
demo_model = demo(1.0);

map_estimate = optimize(demo_model, MAP());  # error
chain = sample(demo_model, HMC(0.05, 10), 1000);  # error

chain = sample(demo_model, MH(), 1000);  # no error

MAP() throws the error, as does sampling with HMC or NUTS; sampling with MH does not.

ipozdeev, Apr 10 '24

Likely related to https://github.com/TuringLang/Bijectors.jl/pull/313

yebai, May 30 '24

cc @sethaxen @torfjelde

yebai, May 30 '24

The above example now works with the most recent Bijectors.jl release. A small step size is useful for stable optimisation or HMC sampling:

julia> chain = sample(demo_model, HMC(0.01, 10), 10000); 
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00

I am not sure how to specify step size for optimisation. @mhauru is that possible?

yebai, Jun 05 '24

These numerical errors can be caught explicitly and used to tell the inference backend to reject the proposal:

julia> @model function demo(x)
           try
               _A ~ Wishart(5, A)
               _x_mu = sum(_A)
               x ~ Normal(_x_mu, 1)
           catch e
               # Only numerical failures reject the proposal; re-raise anything else.
               e isa PosDefException || rethrow()
               Turing.@addlogprob! -Inf
           end
       end


julia> chain = sample(demo(1), NUTS(), 1000)
┌ Info: Found initial step size
└   ϵ = 0.0125
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×21×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.19 seconds
Compute duration  = 0.19 seconds
parameters        = _A[1, 1], _A[2, 1], _A[3, 1], _A[1, 2], _A[2, 2], _A[3, 2], _A[1, 3], _A[2, 3], _A[3, 3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

    _A[1, 1]    3.6504    2.3821    0.0935   546.6195   372.3050    1.0003     2907.5507
    _A[2, 1]   -1.4789    1.7435    0.0766   565.4885   588.1340    1.0066     3007.9174
    _A[3, 1]   -1.5504    1.7307    0.1053   254.4582   458.9476    1.0133     1353.5010
    _A[1, 2]   -1.4789    1.7435    0.0766   565.4885   588.1340    1.0066     3007.9174
    _A[2, 2]    3.5547    2.2049    0.0896   549.1068   560.1294    1.0012     2920.7810
    _A[3, 2]   -1.4334    1.6267    0.0781   466.9718   495.1435    0.9991     2483.8925
    _A[1, 3]   -1.5504    1.7307    0.1053   254.4582   458.9476    1.0133     1353.5010
    _A[2, 3]   -1.4334    1.6267    0.0781   466.9718   495.1435    0.9991     2483.8925
    _A[3, 3]    3.5786    2.2005    0.1018   556.3600   493.4034    1.0026     2959.3615

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    _A[1, 1]    0.5643    1.8723    3.1353    4.8230    9.5504
    _A[2, 1]   -5.6961   -2.3031   -1.2180   -0.3144    1.1457
    _A[3, 1]   -5.8425   -2.3965   -1.2686   -0.4025    1.0603
    _A[1, 2]   -5.6961   -2.3031   -1.2180   -0.3144    1.1457
    _A[2, 2]    0.5728    1.9093    3.1736    4.6654    8.8305
    _A[3, 2]   -5.1991   -2.3551   -1.2366   -0.3020    1.2692
    _A[1, 3]   -5.8425   -2.3965   -1.2686   -0.4025    1.0603
    _A[2, 3]   -5.1991   -2.3551   -1.2366   -0.3020    1.2692
    _A[3, 3]    0.7432    1.9730    3.1545    4.6346    9.2545
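
The same catch-and-reject pattern can be sketched without Turing, using only the LinearAlgebra standard library. `safe_mvnormal_logpdf` below is a hypothetical helper, not a Turing or Distributions function: it catches only `PosDefException`, maps it to a log-density of `-Inf` (the effect `Turing.@addlogprob! -Inf` has inside a model), and rethrows every other error so that genuine bugs are not silently hidden.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Turing): log-density of a zero-mean
# multivariate normal with covariance Σ. A Σ that fails the Cholesky check
# gets log-density -Inf, i.e. the proposal is rejected; other errors propagate.
function safe_mvnormal_logpdf(x::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real})
    C = try
        cholesky(Symmetric(Σ))
    catch e
        e isa PosDefException || rethrow()
        return -Inf
    end
    z = C.L \ x  # whitened residual, so dot(z, z) == x' * inv(Σ) * x
    return -0.5 * (length(x) * log(2π) + logdet(C) + dot(z, z))
end

safe_mvnormal_logpdf([0.0, 0.0], [1.0 0.0; 0.0 1.0])  # finite: equals -log(2π)
safe_mvnormal_logpdf([0.0, 0.0], [1.0 2.0; 2.0 1.0])  # -Inf: Σ is indefinite
```

Catching the exception type narrowly matters: a bare `catch` that always adds `-Inf` would also turn typos and dimension errors into silently rejected proposals.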

yebai, Jun 05 '24

> I am not sure how to specify step size for optimisation. @mhauru is that possible?

It depends on the optimisation algorithm, but if the algorithm has a notion of a step size, then usually yes. The default algorithm is LBFGS, which first picks a search direction and then runs a line search along it to decide how far to go. There is therefore no fixed step size, but you can set an initial guess for it like this:

optimize(demo_model, MAP(), Optim.LBFGS(;alphaguess=0.01));

That seems to help avoid the loss-of-positive-definiteness errors in this case.
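
For reference, the initial step guess can also be written as an explicit line-search initializer. This is a sketch under assumptions: it needs Optim.jl and LineSearches.jl installed, it assumes that a bare number passed as `alphaguess` is shorthand for `LineSearches.InitialStatic(alpha = ...)`, and the quadratic `f` is a made-up stand-in for the MAP objective.

```julia
using Optim, LineSearches

# Made-up smooth objective standing in for the (negative) log posterior.
f(x) = sum(abs2, x .- [1.0, 2.0])

# Begin every line search by trying a step of 0.01 along the search direction,
# instead of the default initial trial step of 1.0.
solver = Optim.LBFGS(; alphaguess = LineSearches.InitialStatic(alpha = 0.01))

# No gradient is supplied, so Optim falls back to finite differences.
res = Optim.optimize(f, [0.0, 0.0], solver)
Optim.minimizer(res)
```

Under the same assumption, passing this `solver` to `optimize(demo_model, MAP(), solver)` should behave like the `alphaguess=0.01` call above; the explicit form just makes the chosen initialization strategy visible.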

mhauru, Jun 05 '24

@yebai I don't think anything I did to fix the correlation bijectors would have fixed this. I'm not at a computer right now, but I imagine the problem is similar: the Jacobian computation goes through a Cholesky decomposition, which is wasteful and can fail randomly due to floating-point errors. The solution is to make the same fix for the covariance bijector.
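
A toy illustration of how a matrix that is positive definite in exact arithmetic can still fail the Cholesky check in `Float64`. The matrix is made up for illustration and is not taken from the Bijectors.jl code path; it only shows the kind of rounding that can make a factorization-based computation fail at seemingly random points.

```julia
using LinearAlgebra

# In exact arithmetic this matrix is positive definite: its determinant is 1e-17 > 0.
# In Float64, however, 1.0 + 1e-17 rounds to exactly 1.0, so the stored matrix is singular.
M = [1.0 1.0; 1.0 (1.0 + 1e-17)]

M[2, 2] == 1.0  # true: the tiny perturbation was lost to rounding
isposdef(M)     # false

# The factorization's positivity check fails, raising the same exception as in this issue.
err = try
    cholesky(M)
    nothing
catch e
    e
end
err isa PosDefException  # true
```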

The other place a PosDefException would be raised randomly is when the Wishart matrix is used as the covariance of an MvNormal. The solution there is the same as for LKJ: add a WishartCholesky distribution to Distributions.jl and a VecCovBijector to Bijectors.jl. The same goes for InverseWishart.
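
A stdlib-only sketch of why carrying a Cholesky factor sidesteps the problem. It does not use the proposed `WishartCholesky` or `VecCovBijector`, whose interfaces are not specified in this thread; instead it assumes the sampled quantity is a lower-triangular factor `L` with positive diagonal, wraps it in a `Cholesky` object directly, and computes the log-determinant and quadratic form that an MvNormal log-density needs without ever calling `cholesky`.

```julia
using LinearAlgebra

# Assumed sampled quantity: a lower-triangular factor with positive diagonal.
# (Values are arbitrary, chosen only for illustration.)
L = LowerTriangular([2.0 0.0 0.0; 0.5 1.0 0.0; -0.3 0.2 0.7])

# Wrap the existing factor as a Cholesky object; info = 0 marks it as successful.
# No call to `cholesky` is made, so there is no factorization that could fail.
C = Cholesky(Matrix(L), 'L', 0)

Σ = Matrix(C)  # the covariance L * L'
x = [0.1, -0.2, 0.3]

# The two quantities an MvNormal log-density needs, taken from the factor alone.
logdet_Σ = 2 * sum(log, diag(L))  # log det(Σ) = 2 * sum of log(L[i, i])
z = L \ x                          # forward substitution, so dot(z, z) == x' * inv(Σ) * x
quad = dot(z, z)
```

The design point is that positive definiteness is guaranteed by construction (positive diagonal of `L`) rather than verified after the fact by a factorization that rounding can break.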

sethaxen, Jun 05 '24

Thanks for the clarification, @sethaxen. Until numerically more stable alternatives are implemented, users can be referred to https://github.com/TuringLang/Turing.jl/issues/2188#issuecomment-2149755180. We should also update the docs with a guide on using try-catch blocks to handle numerical exceptions.

yebai, Jun 05 '24