AxisKeys.jl

Support for PDMats

Status: open. bencottier opened this issue 3 years ago; 0 comments.

I would like KeyedArrays to work as parameters of Distributions, for example as the mean vector and covariance matrix of an MvNormal.

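This already works in most cases, but it breaks wherever PDMats is involved: the PDMat constructor rejects a KeyedArray, and the \ operator runs into method ambiguities.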

PDMats defines a full interface that new subtypes of AbstractPDMat must implement, but I don't know whether implementing all of it is warranted here.

Constructor example

julia> using AxisKeys, Distributions, PDMats

julia> pdmat = PDMat([1.0 0.0; 0.0 1.0])
2×2 PDMat{Float64,Array{Float64,2}}:
 1.0  0.0
 0.0  1.0

julia> PDMat(KeyedArray([1.0 0.0; 0.0 1.0], ([:x1, :x2], [:y1, :y2])))
ERROR: MethodError: no method matching PDMat(::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}})
Closest candidates are:
  PDMat(::AbstractArray{T,2} where T, ::LinearAlgebra.Cholesky{T,S}) where {T, S} at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:12
  PDMat(::LinearAlgebra.Symmetric) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:20
  PDMat(::Array{T,2} where T) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:19
  ...
Stacktrace:
 [1] top-level scope at REPL[10]:1

Because MvNormal wraps its covariance argument in a PDMat, this also prevents passing a KeyedArray as the covariance matrix of a distribution:

julia> d = MvNormal(
           KeyedArray([1.0, 2.0], [:a, :b]),
           KeyedArray([1.0 0.0; 0.0 1.0], ([:a, :b], [:a, :b]))
       )
ERROR: MethodError: no method matching PDMats.PDMat(::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}})
Closest candidates are:
  PDMats.PDMat(::AbstractArray{T,2} where T, ::LinearAlgebra.Cholesky{T,S}) where {T, S} at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:12
  PDMats.PDMat(::LinearAlgebra.Symmetric) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:20
  PDMats.PDMat(::Array{T,2} where T) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:19
  ...
Stacktrace:
 [1] MvNormal(::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}, ::KeyedArray{Float64,2,Array{Float64,2},Tuple{Array{Symbol,1},Array{Symbol,1}}}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:211
 [2] top-level scope at REPL[92]:1
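The constructor failure is a plain dispatch miss: the closest candidates listed take a matrix plus a Cholesky factor, a Symmetric, or an Array{T,2}, and a KeyedArray is none of these. The sketch below reproduces the pattern with hypothetical stand-in types (ToyPDMat and ToyKeyedMatrix are illustrative only, not part of PDMats or AxisKeys) and shows one possible fallback, assumed here rather than taken from either package: convert any AbstractMatrix to a plain Matrix before factorising, at the cost of dropping the keys.

```julia
using LinearAlgebra

# Hypothetical stand-in for PDMat: the only one-argument constructor accepts a plain Matrix.
struct ToyPDMat
    mat::Matrix{Float64}
    chol::Cholesky{Float64,Matrix{Float64}}
end
ToyPDMat(mat::Matrix{Float64}) = ToyPDMat(mat, cholesky(mat))

# Hypothetical stand-in for a 2-D KeyedArray: an AbstractMatrix wrapper carrying row/column keys.
struct ToyKeyedMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
    keys::Tuple{Vector{Symbol},Vector{Symbol}}
end
Base.size(A::ToyKeyedMatrix) = size(A.data)
Base.getindex(A::ToyKeyedMatrix, i::Int, j::Int) = A.data[i, j]

K = ToyKeyedMatrix([1.0 0.0; 0.0 1.0], ([:a, :b], [:a, :b]))

# With no method for a generic AbstractMatrix, construction fails like PDMat(::KeyedArray).
failed_before = try
    ToyPDMat(K)
    false
catch err
    err isa MethodError
end

# Assumed fallback (not the PDMats API): copy any AbstractMatrix into a Matrix{Float64}.
# The keys are lost in the process.
ToyPDMat(A::AbstractMatrix) = ToyPDMat(Matrix{Float64}(A))

P = ToyPDMat(K)
```

Dropping the keys is the simplest option; keeping them is where the fuller AbstractPDMat interface mentioned at the top of this issue would come in.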

\ operator example

Another case is the \ operator, which Distributions uses internally, for example in _logpdf (via sqmahal and invquad):

julia> pdmat = PDMat([1.0 0.0; 0.0 1.0]);

julia> pdmat \ [1.0, 2.0]
2-element Array{Float64,1}:
 1.0
 2.0

julia> pdmat \ KeyedArray([1.0, 2.0], [:x, :y])
ERROR: MethodError: \(::PDMat{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) is ambiguous. Candidates:
  \(a::PDMat, x::Union{AbstractArray{T,1}, AbstractArray{T,2}} where T) in PDMats at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:50
  \(x::AbstractArray{T,2} where T, y::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number) in AxisKeys at /Users/bencottier/.julia/packages/AxisKeys/1jgJz/src/functions.jl:303
Possible fix, define
  \(::PDMat, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number)
Stacktrace:
 [1] top-level scope at REPL[12]:1

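With a KeyedArray mean and a plain covariance matrix, construction succeeds, but logpdf still fails with the same kind of ambiguity, this time between LinearAlgebra and AxisKeys:

julia> d = MvNormal(KeyedArray([1.0, 2.0], [:a, :b]), [1.0 0.0; 0.0 1.0])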
MvNormal{Float64,PDMats.PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}(
dim: 2
μ: [1.0, 2.0]
Σ: [1.0 0.0; 0.0 1.0]
)

julia> logpdf(d, [1.0, 2.0])
ERROR: MethodError: \(::LinearAlgebra.LowerTriangular{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) is ambiguous. Candidates:
  \(A::Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UpperTriangular}, B::AbstractArray{T,1} where T) in LinearAlgebra at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/LinearAlgebra/src/triangular.jl:2050
  \(x::AbstractArray{T,2} where T, y::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number) in AxisKeys at /Users/bencottier/.julia/packages/AxisKeys/1jgJz/src/functions.jl:303
Possible fix, define
  \(::Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UpperTriangular}, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number)
Stacktrace:
 [1] invquad(::PDMat{Float64,Array{Float64,2}}, ::KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}) at /Users/bencottier/.julia/packages/PDMats/Rw2Hf/src/pdmat.jl:79
 [2] sqmahal(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:266
 [3] _logpdf(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariate/mvnormal.jl:127
 [4] logpdf(::MvNormal{Float64,PDMat{Float64,Array{Float64,2}},KeyedArray{Float64,1,Array{Float64,1},Base.RefValue{Array{Symbol,1}}}}, ::Array{Float64,1}) at /Users/bencottier/.julia/packages/Distributions/cNe2C/src/multivariates.jl:201
 [5] top-level scope at REPL[18]:1
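Both \ errors have the same shape. One candidate is specialised on the left argument and generic in the right (PDMat, or LowerTriangular/UpperTriangular, with any vector), while the AxisKeys candidate is generic in the left argument and specialised on the right (any matrix with a keyed vector), so neither is more specific than the other. The sketch below reproduces this with hypothetical stand-in types (ToyPD and ToyKeyed are illustrative only, not the real PDMats or AxisKeys types) and checks that adding the method Julia suggests, one specialised on both arguments, resolves the ambiguity.

```julia
using LinearAlgebra

# Hypothetical stand-in for PDMat: an AbstractMatrix with its own \ method.
struct ToyPD <: AbstractMatrix{Float64}
    mat::Matrix{Float64}
end
Base.size(a::ToyPD) = size(a.mat)
Base.getindex(a::ToyPD, i::Int, j::Int) = a.mat[i, j]

# Hypothetical stand-in for a 1-D KeyedArray: a vector carrying keys.
struct ToyKeyed{T} <: AbstractVector{T}
    data::Vector{T}
    keys::Vector{Symbol}
end
Base.size(x::ToyKeyed) = size(x.data)
Base.getindex(x::ToyKeyed, i::Int) = x.data[i]

# Like the PDMats candidate: specific on the left, generic on the right.
Base.:\(a::ToyPD, x::AbstractVector) = a.mat \ collect(x)
# Like the AxisKeys candidate: generic on the left, specific on the right (keeps the keys).
Base.:\(a::AbstractMatrix, x::ToyKeyed) = ToyKeyed(a \ x.data, x.keys)

a = ToyPD([2.0 0.0; 0.0 4.0])
x = ToyKeyed([2.0, 4.0], [:a, :b])

# Neither method is more specific for (ToyPD, ToyKeyed), so the call is ambiguous.
ambiguous_before = try
    a \ x
    false
catch err
    err isa MethodError
end

# The fix suggested by the error message: a method specialised on both arguments.
Base.:\(a::ToyPD, x::ToyKeyed) = ToyKeyed(a.mat \ x.data, x.keys)

y = a \ x  # a ToyKeyed with keys [:a, :b] and values ≈ [1.0, 1.0]
```

For the real packages, the corresponding methods are the ones named in the "Possible fix, define" lines: \(::PDMat, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number) and \(::Union{LinearAlgebra.LowerTriangular, LinearAlgebra.UpperTriangular}, ::KeyedArray{T,1,AT,RT} where RT where AT where T<:Number).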

bencottier, Mar 12 '21 11:03