AxisKeys.jl
Support for PDMats
I would like KeyedArrays to play nice as parameters of Distributions.
This already works in most cases, but there are problems with PDMats, which Distributions uses to store covariance matrices (e.g. Σ in MvNormal).
There is a full interface to define for new subtypes of AbstractPDMat, but I don't know if all of that is warranted.
Constructor example
julia> using AxisKeys, Distributions, PDMats
julia> pdmat = PDMat([1.0 0.0; 0.0 1.0])
2×2 PDMat{Float64,Array{Float64,2}}:
1.0 0.0
0.0 1.0
julia> PDMat(KeyedArray([1.0 0.0; 0.0 1.0], ([:x1, :x2], [:y1, :y2])))
ERROR: MethodError: no method matching PDMat(::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}})
Closest candidates are:
PDMat(::AbstractArray{T,2} where T, ::LinearAlgebra.Cholesky{T,S}) where {T, S} at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:12
PDMat(::LinearAlgebra.Symmetric) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:20
PDMat(::Array{T,2} where T) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:19
...
Stacktrace:
[1] top-level scope at REPL[10]:1
This also prevents passing a KeyedArray as the covariance matrix of a Distribution, since MvNormal converts it to a PDMat internally:
julia> d = MvNormal(
KeyedArray([1.0, 2.0], [:a, :b]),
KeyedArray([1.0 0.0; 0.0 1.0], ([:a, :b], [:a, :b]))
)
ERROR: MethodError: no method matching PDMats.PDMat(::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}})
Closest candidates are:
PDMats.PDMat(::AbstractArray{T,2} where T, ::LinearAlgebra.Cholesky{T,S}) where {T, S} at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:12
PDMats.PDMat(::LinearAlgebra.Symmetric) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:20
PDMats.PDMat(::Array{T,2} where T) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:19
...
Stacktrace:
[1] MvNormal(::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}, ::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:211
[2] top-level scope at REPL[92]:1
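The constructor failure comes down to the candidate list above: PDMat has methods for Array{T,2} and Symmetric (plus a two-argument Cholesky form), but not for a general AbstractMatrix, so a KeyedArray matches none of them. The sketch below reproduces that pattern with hypothetical stand-in types (StrictPD for PDMat, KeyedMat for a 2-D KeyedArray; neither is part of PDMats or AxisKeys) and shows the lightweight fallback of copying into a plain Matrix. It is only an illustration of the trade-off, not a proposed patch: construction succeeds, but the keys are lost.

```julia
# Hypothetical stand-in for PDMat. Like the PDMat(::Array{T,2} where T)
# candidate in the error above, its only constructor takes a plain Matrix.
struct StrictPD
    mat::Matrix{Float64}
    StrictPD(mat::Matrix{Float64}) = new(mat)
end

# Hypothetical stand-in for a 2-D KeyedArray: an AbstractMatrix that wraps
# a Matrix and carries one key vector per dimension.
struct KeyedMat <: AbstractMatrix{Float64}
    data::Matrix{Float64}
    keys::Tuple{Vector{Symbol},Vector{Symbol}}
end
Base.size(A::KeyedMat) = size(A.data)
Base.getindex(A::KeyedMat, i::Int, j::Int) = A.data[i, j]

K = KeyedMat([1.0 0.0; 0.0 1.0], ([:x1, :x2], [:y1, :y2]))

# Without a fallback, this fails the same way PDMat(KeyedArray(...)) does.
before = try
    StrictPD(K)
catch e
    e
end
println(typeof(before))  # MethodError

# Lightweight fallback: accept any AbstractMatrix by copying it into a plain
# Matrix. Construction now works, but the keys of K are discarded.
StrictPD(A::AbstractMatrix) = StrictPD(Matrix{Float64}(A))

after = StrictPD(K)
println(after.mat == [1.0 0.0; 0.0 1.0])  # true
```

Keeping the keys would presumably need a wrapper type that implements the AbstractPDMat interface instead, which is the heavier option mentioned at the top.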
\ operator example
Another case is the \ operator, which is used by, e.g., Distributions._logpdf:
julia> pdmat = PDMat([1.0 0.0; 0.0 1.0]);
julia> pdmat \ [1.0, 2.0]
2-element Array{Float64,1}:
1.0
2.0
julia> pdmat \ KeyedArray([1.0, 2.0], [:x, :y])
ERROR: MethodError: \(::PDMat{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) is ambiguous. Candidates:
\(a::PDMat, x::Union{AbstractArray{T,1}, AbstractArray{T,2}} where T) in PDMats at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:50
\(x::AbstractArray{T,2} where T, y::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number) in AxisKeys at /Users/bencottier/.julia/packages/AxisKeys/1jgJz/src/functions.jl:303
Possible fix, define
\(::PDMat, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number)
Stacktrace:
[1] top-level scope at REPL[12]:1
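This ambiguity is ordinary multiple dispatch: PDMats specializes the first argument of \ and accepts any vector as the second, while AxisKeys accepts any matrix as the first argument and specializes the second, so neither method is more specific for a (PDMat, KeyedArray) pair. A minimal standalone sketch of the pattern follows, using hypothetical toy types ToyPD and KeyedVec and a function solve in place of \ (none of these exist in PDMats or AxisKeys). It also adds the kind of method the error message suggests; reusing the vector's keys for the result is an assumption of this sketch.

```julia
using LinearAlgebra

# Hypothetical stand-in for PDMat: an AbstractMatrix wrapping a Matrix.
struct ToyPD <: AbstractMatrix{Float64}
    mat::Matrix{Float64}
end
Base.size(a::ToyPD) = size(a.mat)
Base.getindex(a::ToyPD, i::Int, j::Int) = a.mat[i, j]

# Hypothetical stand-in for a 1-D KeyedArray: a vector that carries keys.
struct KeyedVec <: AbstractVector{Float64}
    data::Vector{Float64}
    keys::Vector{Symbol}
end
Base.size(x::KeyedVec) = size(x.data)
Base.getindex(x::KeyedVec, i::Int) = x.data[i]

# `solve` stands in for `\`. "Package A" (like PDMats) specializes the
# first argument and accepts any vector as the second.
solve(a::ToyPD, x::AbstractVector) = a.mat \ Vector{Float64}(x)

# "Package B" (like AxisKeys) accepts any matrix as the first argument,
# specializes the second, and re-attaches the keys to the result.
solve(a::AbstractMatrix, x::KeyedVec) = KeyedVec(a \ x.data, x.keys)

a = ToyPD([1.0 0.0; 0.0 1.0])
x = KeyedVec([1.0, 2.0], [:x, :y])

# Both methods apply to (ToyPD, KeyedVec) and neither is more specific,
# so Julia throws a MethodError instead of picking one.
ambiguous = try
    solve(a, x)
catch e
    e
end

# Resolve it with a method that is more specific in both arguments.
solve(a::ToyPD, x::KeyedVec) = KeyedVec(a.mat \ x.data, x.keys)

y = solve(a, x)
println(y.data, " ", y.keys)  # [1.0, 2.0] [:x, :y]
```

In the real packages the analogous method would be \(::PDMat, ::KeyedArray{T,1} where T<:Number), and it would have to live in PDMats, AxisKeys, or a package that owns one of the types, since defining it anywhere else would be type piracy.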
julia> d = MvNormal(KeyedArray([1.0, 2.0], [:a, :b]), [1.0 0.0; 0.0 1.0])
MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}(
dim: 2
μ: [1.0, 2.0]
Σ: [1.0 0.0; 0.0 1.0]
)
julia> logpdf(d, [1.0, 2.0])
ERROR: MethodError: \(::LinearAlgebra.LowerTriangular{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) is ambiguous. Candidates:
\(A::Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UpperTriangular}, B::AbstractArray{T,1} where T) in LinearAlgebra at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/LinearAlgebra/src/triangular.jl:2050
\(x::AbstractArray{T,2} where T, y::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number) in AxisKeys at /Users/bencottier/.julia/packages/AxisKeys/1jgJz/src/functions.jl:303
Possible fix, define
\(::Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UpperTriangular}, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number)
Stacktrace:
[1] invquad(::PDMat{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:79
[2] sqmahal(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:266
[3] _logpdf(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:127
[4] logpdf(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariates.jl:201
[5] top-level scope at REPL[18]:1
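The stack trace goes through sqmahal and invquad before reaching the ambiguous call. A short derivation shows why a triangular solve appears at all, assuming invquad uses the standard Cholesky-based quadratic form (which the LowerTriangular argument in the trace suggests). For a d-dimensional MvNormal with mean μ and covariance Σ:

$$
\log p(x) = -\frac{1}{2}\Big( d \log(2\pi) + \log\det\Sigma + (x-\mu)^\top \Sigma^{-1} (x-\mu) \Big)
$$

With the Cholesky factorization Σ = L Lᵀ, where L is lower triangular:

$$
(x-\mu)^\top \Sigma^{-1} (x-\mu) = (x-\mu)^\top L^{-\top} L^{-1} (x-\mu) = \big\lVert L^{-1}(x-\mu) \big\rVert_2^2
$$

The vector L⁻¹(x − μ) is computed by a triangular solve, L \ (x - μ) in Julia. Because μ is a KeyedArray, x - μ is one too (the trace shows invquad receiving a KeyedArray), so the solve lands on the ambiguous \(::LowerTriangular, ::KeyedArray) pair rather than on a PDMat method.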