Bijectors.jl

Bijector for MatrixNormal

Open RodrigoZepeda opened this issue 3 years ago • 2 comments

Hi! I'm trying to sample from the following Turing model, which uses a MatrixNormal distribution:

using Random, Turing, Bijectors
Random.seed!(123)

# Simulate data: draw 2×2 correlation matrices from LKJ, then one 2×2 MatrixNormal sample
U  = rand(LKJ(2, 0.5))
V  = rand(LKJ(2, 0.5))
Uₐ = rand(LKJ(2, 0.5))
Vₐ = rand(LKJ(2, 0.5))
Asample = rand(MatrixNormal(zeros(Float64, 2, 2), U, V))

# Define the model: MatrixNormal prior on the mean `mu`, MatrixNormal likelihood for `A`
@model function estimateA(A, U, V, Uₐ, Vₐ)
    mu ~ MatrixNormal(zeros(Float64, size(A, 1), size(A, 2)), Uₐ, Vₐ)
    A  ~ MatrixNormal(mu, U, V)
end

# Estimate with NUTS (100 samples)
model  = estimateA(Asample, U, V, Uₐ, Vₐ);
chains = sample(model, NUTS(), 100);

However, I get the following error:

ERROR: MethodError: no method matching bijector(::MatrixNormal{Float64, Matrix{Float64}, PDMats.PDMat{Float64, Matrix{Float64}}, PDMats.PDMat{Float64, Matrix{Float64}}})
Closest candidates are:
  bijector(::Union{Kolmogorov, BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:58
  bijector(::Union{Arcsine, Beta, Biweight, Cosine, Epanechnikov, NoncentralBeta}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:69
  bijector(::Union{Levy, Pareto}) at ~/.julia/packages/Bijectors/vUc4m/src/transformed_distribution.jl:72

The same error happens with the following code:

using Distributions, Bijectors

dist = MatrixNormal(zeros(2, 2), rand(LKJ(2, 0.5)), rand(LKJ(2, 0.5)))
b = bijector(dist)  # throws the same MethodError

I'm relatively new to Turing, so my diagnosis may be wrong, but it looks like Bijectors is missing a `bijector` definition for MatrixNormal.
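For context: a `bijector` maps a distribution's support to unconstrained space so that gradient-based samplers such as NUTS can work there. MatrixNormal has no constrained support to map away. A matrix X ~ MatrixNormal(M, U, V) (U the n×n row covariance, V the p×p column covariance) is equivalent to vec(X) ~ MvNormal(vec(M), V ⊗ U), a Gaussian on all of R^{n·p}. The sketch below checks that equivalence numerically using only Distributions.jl. The matrices `M`, `U`, `V` are arbitrary illustrative values, not taken from the issue.

```julia
using Distributions, Random

Random.seed!(1)

# Arbitrary illustrative parameters (not from the issue): mean M (n×p),
# row covariance U (n×n) and column covariance V (p×p), both positive definite.
M = [1.0 -2.0 0.5; 0.0 3.0 -1.0]
U = [2.0 0.5; 0.5 1.0]
V = [1.0 0.3 0.0; 0.3 1.5 0.2; 0.0 0.2 0.8]

dmat = MatrixNormal(M, U, V)
# Equivalent Gaussian on the column-stacked matrix: vec(X) ~ N(vec(M), V ⊗ U).
dvec = MvNormal(vec(M), kron(V, U))

# A real 2×3 matrix far from the mean still has a finite log-density,
# so there is no boundary for a bijector to transform away.
X = 100 .* randn(2, 3)
lp_mat = logpdf(dmat, X)
lp_vec = logpdf(dvec, vec(X))

println(isfinite(lp_mat))
println(isapprox(lp_mat, lp_vec))
```

Because the support is already unconstrained, the identity map is the natural bijector here.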

RodrigoZepeda, Nov 29 '22 16:11

@torfjelde

ParadaCarleton, Dec 06 '22 04:12

Ah, yes, we're missing a definition of

Bijectors.bijector(d::MatrixNormal) = Bijectors.Identity{2}()

I'll make a PR, but in the meantime you can add this overload yourself.
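A minimal sketch of applying the workaround to the smaller example from the issue. It assumes a Bijectors.jl release from around the time of this thread, in which `Bijectors.Identity{N}` is the identity bijector for `N`-dimensional inputs; the exact type may differ in later releases. Qualifying the method as `Bijectors.bijector` makes it extend the package's function rather than define an unrelated `bijector` in `Main`.

```julia
using Distributions, Bijectors

# Overload from the thread: MatrixNormal is supported on all real n×p matrices,
# so the identity map is a valid bijector. `Identity{2}` is assumed to be the
# matrix-input identity bijector of the Bijectors.jl version in use.
Bijectors.bijector(::MatrixNormal) = Bijectors.Identity{2}()

dist = MatrixNormal(zeros(2, 2), rand(LKJ(2, 0.5)), rand(LKJ(2, 0.5)))
b = bijector(dist)   # dispatches to the new method instead of throwing MethodError

X = rand(dist)
Y = b(X)             # identity transform: values are unchanged
println(Y == X)
```

With this method defined before calling `sample`, the original `estimateA` model should run under NUTS, since `mu` is its only latent MatrixNormal variable.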

torfjelde, Dec 06 '22 14:12