
Extend `gradlogpdf` to MixtureModels

Open · rmsrosa opened this pull request 1 year ago · 7 comments

This extends `gradlogpdf` to `MixtureModel`s, both univariate and multivariate, at least for those whose components have `gradlogpdf` implemented.

I haven't implemented the in-place and the component-wise methods yet, but this should be a good start.

I should say I am having trouble with the docs, though. I have added docstrings and added the methods to the docs src, but they are not showing up. I am not sure what to do.
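(In case it helps with the docs issue: with Documenter.jl, a docstring only shows up if the function is listed in an `@docs` block on a page under `docs/src`, and that page is included in the `pages` argument of `makedocs`; a docstring attached to a new method of an already-documented function is merged into that function's existing entry rather than getting its own. A sketch of the `@docs` block, with the exact placement left as an assumption:)

```@docs
gradlogpdf
```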

rmsrosa avatar Jan 21 '24 01:01 rmsrosa

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (c1705a3) 85.94% compared to head (53e0977) 86.16%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1827      +/-   ##
==========================================
+ Coverage   85.94%   86.16%   +0.22%     
==========================================
  Files         144      144              
  Lines        8658     8704      +46     
==========================================
+ Hits         7441     7500      +59     
+ Misses       1217     1204      -13     

:umbrella: View full report in Codecov by Sentry.

codecov-commenter avatar Jan 21 '24 01:01 codecov-commenter

Relates #1788

sethaxen avatar Jan 21 '24 09:01 sethaxen

> Relates #1788

Oh, I missed that. Sorry.

I thought about using pweights, but I was afraid it would allocate. I should benchmark both to make sure.

rmsrosa avatar Jan 21 '24 10:01 rmsrosa

But I feel this new version is not optimized. It essentially computes the PDF twice. I should unroll the loop.
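For reference, what is being computed is the score of the mixture density: with $p(x) = \sum_i w_i\,p_i(x)$,

$$
\nabla \log p(x) = \frac{\nabla p(x)}{p(x)} = \frac{\sum_i w_i\,p_i(x)\,\nabla \log p_i(x)}{\sum_i w_i\,p_i(x)},
$$

so a single pass over the components can accumulate the numerator and the denominator at the same time, with each $\nabla \log p_i(x)$ supplied by the component's `gradlogpdf`.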

rmsrosa avatar Jan 23 '24 15:01 rmsrosa

But I am not sure about type stability. Even pdf itself is not always type stable, e.g.

julia> @code_warntype pdf(MixtureModel([Normal(1, 2), Beta(2, 3)], [6/10, 4/10]), 1.0)
MethodInstance for Distributions.pdf(::MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, ::Float64)
  from pdf(d::UnivariateMixture, x::Real) @ Distributions ~/.julia/packages/Distributions/UaWBm/src/mixtures/mixturemodel.jl:362
Arguments
  #self#::Core.Const(Distributions.pdf)
  d::MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}
  x::Float64
Body::Any
1 ─ %1 = Distributions._mixpdf1(d, x)::Any
└──      return %1

with

@code_warntype Distributions._mixpdf1(MixtureModel([Normal(1, 2), Beta(2, 3)], [6/10, 4/10]), 1.0)
MethodInstance for Distributions._mixpdf1(::MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, ::Float64)
  from _mixpdf1(d::AbstractMixtureModel, x) @ Distributions ~/.julia/packages/Distributions/UaWBm/src/mixtures/mixturemodel.jl:286
Arguments
  #self#::Core.Const(Distributions._mixpdf1)
  d::MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}
  x::Float64
Locals
  #607::Distributions.var"#607#609"
  #606::Distributions.var"#606#608"{MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, Float64}
  p::Vector{Float64}
Body::Any
1 ─       (p = Distributions.probs(d))
│   %2  = Distributions.:(var"#606#608")::Core.Const(Distributions.var"#606#608")
│   %3  = Core.typeof(d)::Core.Const(MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}})
│   %4  = Core.typeof(x)::Core.Const(Float64)
│   %5  = Core.apply_type(%2, %3, %4)::Core.Const(Distributions.var"#606#608"{MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, Float64})
│         (#606 = %new(%5, d, x))
│   %7  = #606::Distributions.var"#606#608"{MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, Float64}
│         (#607 = %new(Distributions.:(var"#607#609")))
│   %9  = #607::Core.Const(Distributions.var"#607#609"())
│   %10 = Distributions.enumerate(p)::Base.Iterators.Enumerate{Vector{Float64}}
│   %11 = Base.Filter(%9, %10)::Base.Iterators.Filter{Distributions.var"#607#609", Base.Iterators.Enumerate{Vector{Float64}}}
│   %12 = Base.Generator(%7, %11)::Base.Generator{Base.Iterators.Filter{Distributions.var"#607#609", Base.Iterators.Enumerate{Vector{Float64}}}, Distributions.var"#606#608"{MixtureModel{Univariate, Continuous, Distribution{Univariate, Continuous}, Categorical{Float64, Vector{Float64}}}, Float64}}
│   %13 = Distributions.sum(%12)::Any
└──       return %13

rmsrosa avatar Jan 27 '24 12:01 rmsrosa

Hmm, I thought I had added another comment with the benchmarks, but maybe the connection broke and it didn't go through. Here it goes again.

I also implemented another version, with a for loop that zips over the prior weights and the components instead of indexing into them:

function gradlogpdf(d::UnivariateMixture, x::Real)
    ps = probs(d)
    cs = components(d)
    # handle the first component separately so that pdfx and glp start
    # out with the right types
    ps1 = first(ps)
    cs1 = first(cs)
    pdfx1 = pdf(cs1, x)
    pdfx = ps1 * pdfx1
    glp = pdfx * gradlogpdf(cs1, x)
    if iszero(ps1) || iszero(pdfx)
        glp = zero(glp)
    end
    # accumulate the remaining components, skipping zero weights and
    # components with zero density at x
    @inbounds for (psi, csi) in Iterators.drop(zip(ps, cs), 1)
        if !iszero(psi)
            pdfxi = pdf(csi, x)
            if !iszero(pdfxi)
                pipdfxi = psi * pdfxi
                pdfx += pipdfxi
                glp += pipdfxi * gradlogpdf(csi, x)
            end
        end
    end
    if !iszero(pdfx) # otherwise glp is already zero
        glp /= pdfx
    end
    return glp
end
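As a sanity check on the single-pass formula (this is a self-contained toy version, not the PR code; `normpdf`, `normgradlogpdf`, and `mix_gradlogpdf` are made-up names), here is a two-component Gaussian mixture compared against a central finite difference of `log ∘ pdf`:

```julia
# Standard normal-family formulas, written out so no packages are needed.
normpdf(x, mu, sigma) = exp(-((x - mu) / sigma)^2 / 2) / (sigma * sqrt(2pi))
normgradlogpdf(x, mu, sigma) = -(x - mu) / sigma^2

function mix_gradlogpdf(x, mus, sigmas, ws)
    pdfx = 0.0  # denominator: mixture density at x
    glp = 0.0   # numerator: density-weighted sum of component scores
    for (mu, sigma, w) in zip(mus, sigmas, ws)
        iszero(w) && continue
        wpi = w * normpdf(x, mu, sigma)
        pdfx += wpi
        glp += wpi * normgradlogpdf(x, mu, sigma)
    end
    return iszero(pdfx) ? glp : glp / pdfx
end

# Compare against a central finite difference of log(pdf):
mixpdf(x) = 0.3 * normpdf(x, 1.0, 2.0) + 0.7 * normpdf(x, 2.0, 3.0)
x, h = 1.5, 1e-6
fd = (log(mixpdf(x + h)) - log(mixpdf(x - h))) / (2h)
analytic = mix_gradlogpdf(x, (1.0, 2.0), (2.0, 3.0), (0.3, 0.7))
# fd and analytic agree up to finite-difference error
```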

When mixing distributions of the same type, gradlogpdf with the while loop is a tad faster. But when mixing distributions of different types, pdf and logpdf are type unstable, which makes gradlogpdf type unstable as well, and in that case the version with the for loop is faster. (The same holds for the multivariate versions of the functions.) I used @btime and @benchmark to make sure. Here are some examples.

d1 = MixtureModel([Normal(1, 2), Normal(2, 3), Normal(1.5), Normal(1.0, 0.2)], [0.3, 0.2, 0.3, 0.2])

d2 = MixtureModel([Beta(1, 2), Beta(2, 3), Beta(4, 2)], [3/10, 4/10, 3/10])

d3 = MixtureModel([Normal(1, 2), Beta(2, 3), Exponential(3/2)], [3/10, 4/10, 3/10]) # type unstable

n = 10
m = 20
d4 = MixtureModel([MvNormal(rand(n), rand(n, n) |> A -> (A + A')^2) for _ in 1:m], rand(m) |> w -> w / sum(w))

d5 = MixtureModel([MvNormal([1.0, 2.0], [0.4 0.2; 0.2 0.5]), MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.])], [0.4, 0.6]) # type unstable

We get

[ Info: d1: gradlogpdf with for loop
  75.841 ns (0 allocations: 0 bytes)
[ Info: d1: gradlogpdf with while loop
  73.626 ns (0 allocations: 0 bytes)

[ Info: d2: gradlogpdf with for loop
  295.552 ns (0 allocations: 0 bytes)
[ Info: d2: gradlogpdf with while loop
  296.826 ns (0 allocations: 0 bytes)

[ Info: d3: gradlogpdf with for loop
  686.398 ns (19 allocations: 304 bytes)
[ Info: d3: gradlogpdf with while loop
  974.368 ns (21 allocations: 368 bytes)

[ Info: d4: gradlogpdf with for loop
  33.258 μs (136 allocations: 16.62 KiB)
[ Info: d4: gradlogpdf with while loop
  33.233 μs (136 allocations: 16.62 KiB)

[ Info: d5: gradlogpdf with for loop
  3.282 μs (26 allocations: 1.42 KiB)
[ Info: d5: gradlogpdf with while loop
  3.471 μs (28 allocations: 1.53 KiB)

Ok, I think this is it. I will leave it with the while loop for your review.

rmsrosa avatar Jan 28 '24 13:01 rmsrosa

Concerning `if iszero(ps1) || iszero(pdfx)`, I meant it to be `if iszero(ps1) || iszero(pdfx1)`, but that is actually a delicate edge case.

I now think the case in which `pdfx1` is zero should be handled implicitly by whatever `gradlogpdf` of the component yields, because `x` might be in the exterior of the support or on its boundary, and in that case the behavior depends on the component distribution itself. So I will change the check to simply `if iszero(ps1)`.

However, if `pdfx1` is zero and the mixture has a single component (e.g. `MixtureModel([Beta(0.5, 0.5)], [1.0])`), then the result would end up being different from `gradlogpdf` of the distribution itself.

I added some more tests for this edge case with a single-component mixture, which currently fail. I marked the PR as draft so I can look into that more carefully.

rmsrosa avatar Jan 29 '24 16:01 rmsrosa