Distributions.jl

Add score function

Open · gdalle opened this issue 3 years ago · 12 comments

In my current research, I need to use the gradient of the log-density, also known as the score function. Sure, I can compute it with automatic differentiation (AD), but in many cases this quantity has an explicit formula. Would it make sense to add it to the API?

gdalle, Mar 28 '22 06:03

How can we set it up so that DistributionsAD.jl (https://github.com/TuringLang/DistributionsAD.jl) works with it and we don't need to define it for every logpdf?

mschauer, Mar 28 '22 07:03

Wait, I just found out that the score function already exists under the name gradlogpdf, but it is not in the docs 🤯

julia> using Distributions

julia> gradlogpdf
gradlogpdf (generic function with 20 methods)
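
A minimal usage sketch, assuming a Distributions.jl version in which gradlogpdf has a method for Normal (the distribution and the numbers are chosen only for illustration). It compares the result with the closed-form score of a normal distribution with respect to the variate, (μ - x) / σ²:

```julia
using Distributions

d = Normal(1.0, 2.0)   # mean μ = 1, standard deviation σ = 2
x = 0.3

# gradlogpdf differentiates logpdf(d, x) with respect to the variate x.
g = gradlogpdf(d, x)

# For a Normal this should agree with the closed form (μ - x) / σ^2.
μ, σ = params(d)
println("gradlogpdf: ", g, "  closed form: ", (μ - x) / σ^2)
```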

It doesn't seem to exist for every distribution, though. Here is a code snippet for dealing with missing methods when using ChainRules:

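using ChainRulesCore, Distributions

# Reverse-mode rule for logpdf(d, x) that reuses gradlogpdf when a method exists.
# Only the tangent with respect to the variate x is computed; d gets NoTangent().
function ChainRulesCore.rrule(::typeof(logpdf), d::D, x::R) where {D<:Distribution,R<:Real}
    if hasmethod(gradlogpdf, (D, R))
        l = logpdf(d, x)
        function pullback(dl)
            # tangents for (logpdf itself, d, x)
            return (NoTangent(), NoTangent(), dl * gradlogpdf(d, x))
        end
        return l, pullback
    else
        # placeholder: the existing DistributionsAD implementation would go here
        # (as written, this branch returns `nothing`)
    end
end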

Of course, the if/else block should be replaced by dispatch, for instance trait-based. And there should be a tangent on the distribution object itself, but I don't know how to do it (yet).
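
One possible shape for the trait-based dispatch mentioned above, as a sketch only. ScoreStyle, ExplicitScore, NoExplicitScore and score_with_fallback are hypothetical names, not part of Distributions.jl, and the fallback uses a central finite difference purely as a stand-in for an AD call, so that the example has no extra dependencies:

```julia
using Distributions

# Hypothetical trait (not part of Distributions.jl): does this distribution
# type come with a hand-written gradlogpdf method?
abstract type ScoreStyle end
struct ExplicitScore <: ScoreStyle end
struct NoExplicitScore <: ScoreStyle end

# Default: assume no explicit formula; opt specific types in one by one.
ScoreStyle(::Type{<:Distribution}) = NoExplicitScore()
ScoreStyle(::Type{<:Normal}) = ExplicitScore()

# Hypothetical entry point: dispatch on the trait instead of an if/else.
score_with_fallback(d::Distribution, x::Real) =
    score_with_fallback(ScoreStyle(typeof(d)), d, x)

# Explicit formula available: call gradlogpdf directly.
score_with_fallback(::ExplicitScore, d, x) = gradlogpdf(d, x)

# No explicit formula: central finite difference as a stand-in for AD.
function score_with_fallback(::NoExplicitScore, d, x; h = 1e-6)
    return (logpdf(d, x + h) - logpdf(d, x - h)) / (2h)
end

println(score_with_fallback(Normal(1.0, 2.0), 0.3))  # goes through gradlogpdf
println(score_with_fallback(Logistic(), 0.5))        # goes through the fallback
```

With this pattern, opting a distribution in takes a single ScoreStyle method, and the branch is selected by dispatch on the distribution type rather than by a runtime hasmethod check.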

gdalle, Mar 28 '22 08:03

This ChainRules (CR) definition is wrong: it is missing the derivatives with respect to d.

I would have suggested gradlogpdf, but I assumed you wanted the gradient with respect to the parameters of d. gradlogpdf computes the gradient with respect to the variate x: https://github.com/JuliaStats/Distributions.jl/blob/c9d6c28f415025bf489ac3bec2f8eec46b0eefbd/src/Distributions.jl#L266
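
To make the distinction concrete for one case (the normal distribution N(μ, σ²), picked only as an illustration): gradlogpdf corresponds to the derivative with respect to x, while a complete rrule for logpdf(d, x) would also have to return the derivatives with respect to the parameters μ and σ:

```math
\begin{aligned}
\log p(x \mid \mu, \sigma) &= -\log\sigma - \tfrac{1}{2}\log(2\pi) - \frac{(x - \mu)^2}{2\sigma^2}, \\
\frac{\partial \log p}{\partial x} &= \frac{\mu - x}{\sigma^2}, \qquad
\frac{\partial \log p}{\partial \mu} = \frac{x - \mu}{\sigma^2}, \qquad
\frac{\partial \log p}{\partial \sigma} = -\frac{1}{\sigma} + \frac{(x - \mu)^2}{\sigma^3}.
\end{aligned}
```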

devmotion, Mar 28 '22 09:03

Yes, that's what I meant by "there should be a tangent on the distribution object itself, but I don't know how to do it (yet)"; I'm not very comfortable with tangents of structs 😅

However the existence of gradlogpdf actually solves my particular problem. And if we find a nice way to integrate it into DistributionsAD, as @mschauer suggests, then we may actually speed up some computations, who knows?

gdalle, Mar 28 '22 09:03

In the meantime, @devmotion, would you be open to adding gradlogpdf to the docs, or is it intentionally absent from the public API?

gdalle, Mar 28 '22 09:03

The plan is to get rid of DistributionsAD rather than to add anything there. Ideally it shouldn't exist; the fixes/workarounds should be moved to Distributions and the AD packages. There are some open PRs here, but they are somewhat stuck because they require breaking changes, and many people complain about breaking releases of Distributions 😅

In general, in my experience such catch-all rules are often not desirable (e.g., in DistributionsAD there's an open issue with a concrete problem caused by such a general rule). As long as AD computes the derivatives correctly and reasonably fast, no rules are needed. Even DistributionsAD defines only a few rules (and all ChainRules definitions were already moved to Distributions). I assume also that in many cases this general definition based on gradlogpdf would be suboptimal since it does not share computations and potentially recomputes many partial results. This makes it even less likely that it's always better than letting AD do its job.
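
To illustrate the point about shared computations, here is a minimal sketch for the normal case. normal_logpdf_with_grads is a hypothetical helper, not a Distributions.jl or ChainRules function: a hand-written rule can compute the standardized value z = (x - μ)/σ once and reuse it for the log-density and all three partial derivatives, whereas a generic rule built on gradlogpdf calls logpdf and gradlogpdf separately and still lacks the parameter derivatives:

```julia
# Hypothetical helper (not a Distributions.jl or ChainRules API): log-density
# of N(μ, σ²) at x plus its partial derivatives with respect to μ, σ and x.
function normal_logpdf_with_grads(μ, σ, x)
    z = (x - μ) / σ                  # shared intermediate, computed once
    logp = -log(σ) - log(2π) / 2 - z^2 / 2
    dμ = z / σ                       # ∂ log p / ∂μ = (x - μ) / σ²
    dσ = (z^2 - 1) / σ               # ∂ log p / ∂σ = -1/σ + (x - μ)² / σ³
    dx = -z / σ                      # ∂ log p / ∂x, the score w.r.t. the variate
    return logp, (dμ, dσ, dx)
end

logp, (dμ, dσ, dx) = normal_logpdf_with_grads(1.0, 2.0, 0.3)
println((logp, dμ, dσ, dx))
```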

devmotion, Mar 28 '22 09:03

> I assume also that in many cases this general definition based on gradlogpdf would be suboptimal since it does not share computations and potentially recomputes many partial results. This makes it even less likely that it's always better than letting AD do its job.

Very good points

mschauer, Mar 28 '22 09:03

Alright then, I had no idea about all of this history! I'll just draft a PR to add gradlogpdf to the Distributions docs, if that's okay.

gdalle, Mar 28 '22 11:03

Even non-API functions can need docstrings (including a note that the function is not part of the API protected by SemVer).
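
As an illustration of such a docstring, a sketch using Julia's admonition syntax (also understood by Documenter.jl); my_gradlogpdf is a hypothetical stand-in, not an existing function:

```julia
"""
    my_gradlogpdf(μ, σ, x)

Gradient of the log-density of a normal distribution with mean `μ` and
standard deviation `σ`, taken with respect to the variate `x`.

!!! note
    This function is internal: it is not part of the public API and is not
    covered by semantic versioning guarantees.
"""
my_gradlogpdf(μ, σ, x) = (μ - x) / σ^2
```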

mschauer, Mar 28 '22 12:03

Is there a reason not to put this one in the API though?

gdalle, Mar 28 '22 12:03

Because in most cases gradlogpdf would have to call an autodiff package to do its work, which would change the order of dependencies.

mschauer, Mar 28 '22 13:03

But I guess the cases that are implemented in Distributions correspond to probability measures with an explicit log-density gradient. For those, I think it would make sense to at least mention that the function gradlogpdf is available, even if it is not defined for all distributions.

gdalle, Mar 28 '22 16:03