DistributionsAD.jl
Reparameterization attached to `Distribution`
Overview
Since distributions have to be re-implemented here and the focus is on AD, I was wondering if it would be of any interest to add reparameterization to `Distribution`. In an AD context you usually want to work in ℝ (unconstrained space) rather than in constrained space, e.g. when optimizing the parameters of a `Distribution`.
A simple example is Normal(μ, σ). One might want to perform a maximum likelihood estimate (MLE) of μ and σ by gradient descent (GD). This requires differentiating the logpdf wrt. μ and σ and then updating the parameters of the Normal accordingly. But for the distribution to be valid we simultaneously need to ensure that σ > 0. Usually we accomplish this by instead differentiating the function
(μ, logσ) -> logpdf(Normal(μ, exp(logσ)), x)
# instead of
(μ, σ) -> logpdf(Normal(μ, σ), x)
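As a dependency-free illustration of the `(μ, logσ)` approach, here is a minimal sketch of MLE by gradient ascent on the average log-likelihood (equivalently, GD on its negative). The gradients are derived by hand rather than with an AD package, and the names `normal_logpdf`, `grad_loglik` and `mle_normal` are hypothetical.

```julia
using Random, Statistics

# Log-density of Normal(μ, exp(logσ)) at x, written in terms of the
# unconstrained parameter logσ ∈ ℝ.
normal_logpdf(μ, logσ, x) = -((x - μ)^2 * exp(-2logσ) + log(2π)) / 2 - logσ

# Hand-derived gradient of Σᵢ logpdf(xᵢ) wrt (μ, logσ):
#   ∂/∂μ    = Σᵢ (xᵢ - μ) / σ²
#   ∂/∂logσ = Σᵢ ((xᵢ - μ)² / σ² - 1)
function grad_loglik(μ, logσ, xs)
    σ2 = exp(2logσ)
    gμ = sum(x -> (x - μ) / σ2, xs)
    glogσ = sum(x -> (x - μ)^2 / σ2 - 1, xs)
    return gμ, glogσ
end

# Gradient ascent on the average log-likelihood in (μ, logσ) space.
# Every real logσ maps to a valid σ = exp(logσ) > 0, so no step can
# leave the parameter space.
function mle_normal(xs; steps = 2_000, lr = 0.1)
    μ, logσ = 0.0, 0.0
    n = length(xs)
    for _ in 1:steps
        gμ, glogσ = grad_loglik(μ, logσ, xs)
        μ += lr * gμ / n
        logσ += lr * glogσ / n
    end
    return μ, exp(logσ)
end

xs = 2.0 .+ 3.0 .* randn(MersenneTwister(1), 1_000)
μ_hat, σ_hat = mle_normal(xs)  # should approach mean(xs) and std(xs; corrected = false)
```

The optimizer never sees σ itself, which is exactly the bookkeeping the proposal below wants to move into the distribution.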
The proposal is to also allow something like
reparam(μ, σ) = μ, exp(σ)
Normal(μ, logσ, reparam)
which in the MLE case allows us to differentiate
(μ, σ) -> logpdf(Normal(μ, σ, reparam), x)
Why?
As you can see, in the case of a univariate Normal this doesn't offer much advantage over the current approach. But the current approach is a special case of what we can then support (by letting reparam equal identity), and I believe there are certainly cases where this is very useful:
- Distributions with parameters consisting of `Array`s can be updated in-place rather than by reconstructing the distribution.
- Specialized implementations of different parameterizations can be provided for possible performance gains.
- Abstracting away optimization from the user becomes significantly easier. Take the MLE using GD again: if we were to wrap this entire process in some `mle` function, we'd require the user to also provide the transformation of `σ` as an argument. If there are a lot of functions depending on this parameterization, it quickly becomes tedious and a bit difficult (speaking from experience) to remember to pass and perform the transformation in each such function. Alternatively you pass around the unconstrained parameters as additional arguments, but again, this is tedious and you still need to ensure you perform the transformation in each method. For an example, see the implementation of ADVI in Turing.jl: https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L74-L77, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L88, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L119.
- IMO, this works much better with the Tracker.jl and Flux.jl "framework"/approach. At the moment one cannot do something like `Normal(param(μ), param(σ))`, `Tracker.back!` through a computation depending on `μ, σ`, and then update the parameters: if one did this naively, it's possible to step into the `σ < 0` region. For very complex cases I think it's easier to attach the parameters to the structs that depend on them, rather than putting everything into a huge array and then packing/unpacking everywhere. We can then use `Flux.@treelike` to further simplify our lives. The example below shows a pattern that arises in e.g. auto-encoders:
W, b = param(W_init), param(b_init)
μ, logσ = param(μ_init), param(logσ_init)
for i = 1:num_steps
nn = Dense(W, b)
d = MvNormal(μ, exp.(logσ)) # Diagonal MvNormal
x = rand(d)
y = nn(x)
# Do computation using `y`
# ...
Tracker.back!(...)
update!(W, b)
update!(μ, logσ)
end
# VS.
Flux.@treelike MvNormal
nn = Dense(param(W_init), param(b_init))
d = MvNormal(param(μ_init), param(logσ_init), (μ, σ) -> (μ, exp.(σ)))
for i = 1:num_steps
x = rand(d)
y = nn(x)
# Do computation using `y`
# ...
Tracker.back!(...)
update!(Flux.params(nn)) # more general
update!(Flux.params(d))
end
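The second pattern does not depend on Flux or Tracker specifically. As a stdlib-only illustration (assuming a hypothetical `DiagNormalAD` type with hand-derived gradients for the `exp` case only; this is not the proposed API), the unconstrained parameters live inside the distribution and a gradient step mutates them in place:

```julia
# Hypothetical diagonal Gaussian that stores its *unconstrained* parameters
# together with the map `f` from unconstrained to constrained scale.
struct DiagNormalAD{V<:AbstractVector{<:Real}, F}
    μ::V
    s::V      # unconstrained; the standard deviations are f.(s)
    f::F
end

scale(d::DiagNormalAD) = d.f.(d.s)

function logpdf(d::DiagNormalAD, x::AbstractVector)
    σ = scale(d)
    z = (x .- d.μ) ./ σ
    return -sum(abs2, z) / 2 - sum(log, σ) - length(x) * log(2π) / 2
end

# Hand-derived gradient of logpdf wrt (μ, s), only for the case f = exp.
function grad_logpdf(d::DiagNormalAD{V, typeof(exp)}, x::AbstractVector) where {V}
    σ = scale(d)
    z = (x .- d.μ) ./ σ
    return z ./ σ, z .^ 2 .- 1
end

# Gradient-ascent step that mutates the arrays stored inside `d`,
# so the distribution never has to be reconstructed.
function update!(d::DiagNormalAD, gμ, gs; lr = 0.01)
    d.μ .+= lr .* gμ
    d.s .+= lr .* gs
    return d
end

d = DiagNormalAD(zeros(2), zeros(2), exp)
x = [1.0, -1.0]
update!(d, grad_logpdf(d, x)...)
```

Whatever real values `d.s` takes, `scale(d)` stays positive, so the in-place step cannot produce an invalid distribution.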
Example implementation
using Distributions, StatsFuns, Random
abstract type ParameterizedDistribution{F, S, P} <: Distribution{F, S} end
# maybe?
transformation(::ParameterizedDistribution{F, S, P}) where {F, S, P} = P
struct NormalAD{T<:Real, P} <: ParameterizedDistribution{Univariate, Continuous, P}
μ::T
σ::T
end
NormalAD(μ::T, σ::T) where {T<:Real} = NormalAD{T, identity}(μ, σ)
NormalAD(μ::T, σ::T, f::Function) where {T<:Real} = NormalAD{T, f}(μ, σ)
# convenience; probably don't want to do this in an actual implementation
Base.identity(args...) = identity.(args)
function Distributions.logpdf(d::NormalAD{<:Real, P}, x::Real) where {P}
μ, σ = P(d.μ, d.σ)
z = (x - μ) / σ
return -(z^2 + log2π) / 2 - log(σ)
end
function Distributions.rand(rng::AbstractRNG, d::NormalAD{T, P}) where {T, P}
μ, σ = P(d.μ, d.σ)
return μ + σ * randn(rng)
end
julia> # Standard: μ ∈ ℝ, σ ∈ ℝ⁺
d1 = NormalAD(0.0, 1.0)
NormalAD{Float64,identity}(μ=0.0, σ=1.0)
julia> d2 = Normal(0.0, 1.0)
Normal{Float64}(μ=0.0, σ=1.0)
julia> x = randn()
-0.028232023381049923
julia> logpdf(d1, x) == logpdf(d2, x)
true
julia> # Real-valued: μ ∈ ℝ, σ ∈ ℝ using `exp`
d3 = NormalAD(0.0, 0.0, (μ, σ) -> (μ, exp(σ)))
NormalAD{Float64,getfield(Main, Symbol("##3#4"))()}(μ=0.0, σ=0.0)
julia> logpdf(d3, x) == logpdf(d2, x)
true
julia> # Real-valued: μ ∈ ℝ, σ ∈ ℝ using `softplus`
d4 = NormalAD(0.0, invsoftplus(1.0), (μ, σ) -> (μ, softplus(σ)))
NormalAD{Float64,getfield(Main, Symbol("##9#10"))()}(μ=0.0, σ=0.541324854612918)
julia> logpdf(d4, x) == logpdf(d2, x)
true
Together with Tracker.jl
julia> using Tracker
julia> μ = param(0.0)
0.0 (tracked)
julia> σ = param(0.0)
0.0 (tracked)
julia> d_tracked = NormalAD(μ, σ, (μ, σ) -> (μ, exp(σ)))
NormalAD{Tracker.TrackedReal{Float64},getfield(Main, Symbol("##5#6"))()}(μ=0.0 (tracked), σ=0.0 (tracked))
julia> lp = logpdf(d_tracked, x)
-0.9193370567767668 (tracked)
julia> Tracker.back!(lp)
julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(-0.028232023381049923, -0.9992029528558118)
julia> x = rand(d_tracked)
-1.6719800201542028 (tracked)
julia> Tracker.back!(x)
julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(0.9717679766189501, -2.6711829730100147)
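These numbers can be checked by hand. With μ = 0 and logσ = 0 (so σ = 1), ∂logpdf/∂μ = (x - μ)/σ² and ∂logpdf/∂logσ = z² - 1 with z = (x - μ)/σ; for a sample x = μ + σε the gradients are 1 and σε. The second pair of gradients equals the first pair plus (1, σε), e.g. 0.97176... = -0.02823... + 1, which is consistent with Tracker accumulating gradients across the two `back!` calls. The stdlib-only sketch below (function names are hypothetical) checks the formulas against central finite differences:

```julia
# Log-density of Normal(μ, exp(logσ)) at x, parameterized by (μ, logσ).
lp(μ, logσ, x) = -(((x - μ) / exp(logσ))^2 + log(2π)) / 2 - logσ

# Analytic gradient wrt (μ, logσ): ((x - μ) / σ², z² - 1) with z = (x - μ) / σ.
function grad_lp(μ, logσ, x)
    σ = exp(logσ)
    z = (x - μ) / σ
    return (z / σ, z^2 - 1)
end

# Reparameterized sample x = μ + exp(logσ) * ε; its gradient wrt (μ, logσ) is (1, σ * ε).
grad_sample(μ, logσ, ε) = (1.0, exp(logσ) * ε)

# Central finite differences, used only to check the formulas above.
function fd_grad(f, μ, logσ; h = 1e-6)
    gμ = (f(μ + h, logσ) - f(μ - h, logσ)) / 2h
    gs = (f(μ, logσ + h) - f(μ, logσ - h)) / 2h
    return (gμ, gs)
end

x = -0.028232023381049923   # the value of `x` in the session above
g = grad_lp(0.0, 0.0, x)     # gradient of logpdf at μ = 0, logσ = 0
```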
Alternative approach: wrap Distribution
An alternative approach is to do something similar to `TransformedDistribution` in Bijectors.jl, where you simply wrap a distribution in another type. Then you could require the user to provide a `reparam` method which takes what's returned from `Distributions.params(d::Distribution)` and applies the reparameterization correctly.
This requires significantly less work, but isn't as nice, nor as easy to extend/work with, IMO.
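A minimal sketch of the wrapping approach, assuming hypothetical `MyNormal`, `Reparameterized` and `materialize` definitions (a local `params` plays the role of `Distributions.params`): the wrapper stores the unconstrained parameters and rebuilds the underlying distribution whenever it is needed.

```julia
# Minimal stand-in for a distribution type, defined here so the sketch is self-contained.
struct MyNormal{T<:Real}
    μ::T
    σ::T
end
params(d::MyNormal) = (d.μ, d.σ)
logpdf(d::MyNormal, x) = -(((x - d.μ) / d.σ)^2 + log(2π)) / 2 - log(d.σ)

# Hypothetical wrapper: unconstrained parameters θ plus a user-supplied `reparam`
# that maps them to the constrained parameters of the wrapped type D.
struct Reparameterized{D, P<:Tuple, F}
    θ::P
    reparam::F
end
Reparameterized{D}(θ::Tuple, reparam) where {D} =
    Reparameterized{D, typeof(θ), typeof(reparam)}(θ, reparam)

# Build the underlying distribution on demand and forward to it.
materialize(r::Reparameterized{D}) where {D} = D(r.reparam(r.θ...)...)
logpdf(r::Reparameterized, x) = logpdf(materialize(r), x)

r = Reparameterized{MyNormal}((0.0, 0.0), (μ, s) -> (μ, exp(s)))
```

Every method on the wrapper has to go through `materialize`, which is part of why this feels less natural to extend than putting the transformation inside the distribution.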
I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions? I had imagined something along the lines of an interface like
a_positive, a_unconstrained = positive(inv_link_or_link_or_whatever, a_positive_init)
Then we're just talking about the generic parameter handling / transformation problem, rather than anything inherently probabilistic.
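One possible reading of such an interface, as a hypothetical sketch (`Positive`, `positive` and `value` are illustrative names, and the signature differs from the one quoted above): the parameter is stored unconstrained and mapped back through a user-chosen link pair, with no reference to any distribution type.

```julia
# Hypothetical positive parameter: stored as an unconstrained real together with
# the maps between unconstrained (ℝ) and constrained (ℝ⁺) space.
struct Positive{T<:Real, F, G}
    unconstrained::T
    to_constrained::F     # ℝ → ℝ⁺, e.g. exp
    to_unconstrained::G   # ℝ⁺ → ℝ, e.g. log
end

# Construct from an initial *positive* value.
positive(init::Real; to_constrained = exp, to_unconstrained = log) =
    Positive(to_unconstrained(init), to_constrained, to_unconstrained)

# The constrained value, always positive for any real `unconstrained`.
value(p::Positive) = p.to_constrained(p.unconstrained)

p = positive(2.5)
```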
Also, could we please try to think about how this plays with Zygote rather than Tracker, as Tracker's days are numbered?
Oops, didn't mean to close
> I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions?
I see what you're saying, but I think it's just too closely related. And I don't think it's far-fetched to say that "reparameterization of a Distribution is related to Distributions.jl". Also, in some cases it can simplify certain computations, e.g. the entropy of a DiagMvNormal when exp is used to enforce the positivity constraint on the variance. My main motivation is that you end up performing the transformations "behind the scenes" rather than the user having to do this in every method that needs it: you do it right once in the implementation of the Distribution and never again. And the standard case is simply an instance of the more general reparameterizable Distribution, so a user who doesn't care doesn't have to care. Other than more work, I think the only downside is that it's more difficult to check whether or not the parameters are valid.
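To make the entropy remark concrete: for a diagonal Gaussian, H = n/2 (1 + log 2π) + Σᵢ log σᵢ, so with σ = exp.(s) the last term is just Σᵢ sᵢ and no `log(exp(...))` round trip is needed. The function names below are hypothetical.

```julia
# Entropy of a diagonal Gaussian with standard deviations σ:
#   H = n/2 * (1 + log(2π)) + Σᵢ log σᵢ
entropy_diag(σ::AbstractVector) = length(σ) * (1 + log(2π)) / 2 + sum(log, σ)

# Same quantity when σ = exp.(s) for unconstrained s: Σᵢ log σᵢ = Σᵢ sᵢ.
entropy_diag_logscale(s::AbstractVector) = length(s) * (1 + log(2π)) / 2 + sum(s)
```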
> Also, could we please try to think about how this plays with Zygote rather than Tracker, as Tracker's days are numbered?
But I think Zygote also intends to support AD wrt. parameters of a struct, right? I can't find the issue right now, but I swear I saw @MikeInnes discussing something like this somewhere. If so, I think my argument using Tracker.jl still holds?
I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.
A few comments I have.
- Doing constrained optimization by transforming the constrained variables is just one way of doing constrained optimization. There are optimization algorithms that can efficiently handle box constraints, semidefinite constraints, linear constraints, etc.
- I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. `x -> logpdf(Normal(1.0, exp(x)), 1.0)` is pretty efficient.
- However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky factor of the covariance or precision. I think these should be possible with multiple dispatch, by providing things like `MvNormal(mu, Covariance(A))` or `MvNormal(mu, Precision(A))`. If `A` is a `Cholesky` we can also construct the `PDMat` directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. `L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x)`. I believe the distribution (re-)construction in this case should not allocate since we are not factorizing `A`.
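The dispatch-based constructors can be sketched without touching Distributions.jl. `Covariance`, `Precision` and `MyMvNormal` are hypothetical stand-ins (the real `MvNormal` wraps its covariance in a PDMats.jl type); the point is that passing a `Cholesky` skips the factorization entirely. The `Precision` method here simply inverts and factorizes, which is the simple route rather than the efficient one.

```julia
using LinearAlgebra

# Hypothetical tag types, used only to select a constructor by dispatch.
struct Covariance{M}
    A::M
end
struct Precision{M}
    A::M
end

# Hypothetical Gaussian that stores the Cholesky factorization of its covariance.
struct MyMvNormal{V<:AbstractVector, C<:Cholesky}
    μ::V
    chol::C   # Σ = chol.L * chol.L'
end

# Dense covariance: factorize once at construction.
MyMvNormal(μ, c::Covariance{<:AbstractMatrix}) = MyMvNormal(μ, cholesky(Symmetric(c.A)))
# Already a Cholesky: store it as-is, no new factorization.
MyMvNormal(μ, c::Covariance{<:Cholesky}) = MyMvNormal(μ, c.A)
# Precision Λ = Σ⁻¹: convert to a covariance and factorize (simple, not optimal).
MyMvNormal(μ, p::Precision{<:AbstractMatrix}) = MyMvNormal(μ, cholesky(Symmetric(inv(p.A))))

function logpdf(d::MyMvNormal, x::AbstractVector)
    z = d.chol.L \ (x .- d.μ)   # whitened residual, so dot(z, z) = (x - μ)' Σ⁻¹ (x - μ)
    return -(dot(z, z) + logdet(d.chol) + length(x) * log(2π)) / 2
end
```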
Since we are discussing changes to Distributions, pinging @matbesancon.
> I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. x -> logpdf(Normal(1.0, exp(x)), 1.0) is pretty efficient.
That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach to certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction).
It also doesn't solve the issue of "interoperability" with the parts of the ecosystem in which Distributions.jl is often used, e.g. Tracker/Zygote. It of course works, but for larger models it can be quite a hassle compared to tying the parameters to the Distribution instance rather than keeping track of them through variables outside of the Distribution.
> However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky of the covariance or precision. I think these should be possible with multiple dispatch. Providing things like MvNormal(mu, Covariance(A)) or MvNormal(mu, Precision(A)). If A is a Cholesky we can also construct the PDMat directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x). I believe the distribution (re-)construction in this case should not allocate since we are not factorizing A.
I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?
> That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach to certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction).
I think this is one of the key aspects of this discussion. I'm personally more of a fan of the functional approach, but I appreciate that there are merits to both approaches. I'm not really sure which way the community is leaning here, perhaps @MikeInnes or @oxinabox can comment? If I remember correctly, Zygote's recommended mode of operation is now the functional style?
I started out preferring the more functional style, but have recently grown quite fond of the Flux approach. Granted, I've recently been using more neural networks where I think this approach is particularly useful.
Also, it's worth noting that with what's proposed here you can do both (which is why I like it!:) )
On an earlier point:
> I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.
I had this discussion with @matbesancon in the context of ChainRules.
My recollection is that:
- while he is happy about AD for derivatives, he absolutely does not want it in Distributions.jl;
- ChainRules.jl (not ChainRulesCore) is just adding `@requires` for these cases.
(Rewriting this for ChainRules is still a little way off; we need to continue improving struct support for that, I think.)
> I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?
@torfjelde You can still do lazy transformations by multiple dispatch, like you said using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)) which internally also stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) where possible. For example, logdet(Exp(Σ)) = tr(Σ).
Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.
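A minimal sketch of such a lazy wrapper, assuming a hypothetical `Exp` type (and a hypothetical `invquad` helper) that are not part of Distributions.jl or PDMats.jl: `logdet` dispatches to tr(Σ) using det(exp(Σ)) = exp(tr(Σ)), and the quadratic form uses exp(Σ)⁻¹ = exp(-Σ) instead of inverting a matrix.

```julia
using LinearAlgebra

# Hypothetical lazy wrapper: represents the matrix exponential exp(S)
# without forming it, so functions can dispatch on it.
struct Exp{M<:AbstractMatrix}
    S::M
end

# det(exp(S)) = exp(tr(S)), hence logdet(exp(S)) = tr(S): no exp, no factorization.
LinearAlgebra.logdet(E::Exp) = tr(E.S)

# exp(S)⁻¹ = exp(-S), so v' exp(S)⁻¹ v needs no matrix inverse
# (it still computes one matrix exponential).
invquad(E::Exp, v::AbstractVector) = dot(v, exp(-E.S) * v)

# Fallback: materialize the matrix when no specialized method exists.
materialize(E::Exp) = exp(E.S)
```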
If we are talking about modifying the distribution in-place (no AD), we can do that using the lazy function wrapper. Note that we always have to define the fields of the distribution according to its struct definition. So we have one of two scenarios:
- We tap into the innermost constructors of, for example, `PDMat` and `MvNormal` to define our distribution `dist` once while keeping a handle to `Σ` that we can modify in-place from outside, affecting the next `logpdf(dist, x)` result.
- We call an outer constructor that copies, does linear algebra, or calls other functions that render our handles to `Σ` independent of the returned distribution struct.
This is fundamentally a constructor definition problem. It is a question of how we can construct the distribution while enabling in-place modification of the inputs. Lazy function wrappers take us some of the way. Note that at the end of the day we still need to satisfy the field type signature of the distribution struct, so we may need to modify the type parameters of the distribution struct to accept more generic matrix types like a lazy matrix-valued function which sub-types AbstractMatrix. Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making 2 versions of the same struct, one persistent and one lazy.
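The two scenarios can be illustrated with a hypothetical `DiagGauss` type (not a Distributions.jl type): when the constructor stores the caller's arrays, an in-place change through the handle is visible to the next `logpdf` call; when it copies, the handle is detached.

```julia
# Hypothetical diagonal Gaussian storing per-dimension variances.
struct DiagGauss{V<:AbstractVector{<:Real}}
    μ::V
    σ²::V
end

# Calls the default constructor directly: the struct keeps references to the caller's arrays.
aliasing(μ, σ²) = DiagGauss(μ, σ²)
# Outer-constructor style that copies: the caller's handles no longer affect the struct.
copying(μ, σ²) = DiagGauss(copy(μ), copy(σ²))

logpdf(d::DiagGauss, x) =
    -(sum(abs2.(x .- d.μ) ./ d.σ²) + sum(log, d.σ²) + length(x) * log(2π)) / 2
```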
So in summary, the anonymous function and dispatch-based laziness approach enables us to:
- Think about ways to make various functions more efficient, e.g. `logdet`,
- Avoid the need for an AD version of every distribution,
- Keep handles to the inputs passed to the outer constructor if we get laziness right, which enables in-place modification.
Note that at this point, it is not a question of whether we need arbitrary re-parameterization, just the API choice. I am leaning towards not having a struct for every distribution for AD purposes only, and instead using anonymous functions and dispatch-based laziness to gain any efficiency and/or flexibility benefits. Ironically, we already implement an AD distribution for MvNormal here to work around some Distributions-PDMats complexity. But for a long-term solution we should try to live within the boundaries of Distributions.jl and PDMats.jl.
Pinging @ChrisRackauckas in case he has opinions on this.
> @torfjelde You can still do lazy transformations by multiple dispatch, like you said using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)) which internally also stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) where possible. For example, logdet(Exp(Σ)) = tr(Σ).
Yeah, I understood that, but it would still require always building an explicit type Exp that could do this, in contrast to the user just passing in the exp function and us wrapping every use of σ in it (this approach wouldn't work for every case, but in the univariate case it would be "one implementation works for all transformations").
But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:) (You might have already realized this!) This is basically a "you know what you're doing"-type. Well, it's going to be rather annoying to have to specify different behavior for all combinations of the different parameters, e.g. if you want to apply log to μ and exp to σ you have to implement Normal{Log, Exp}, Normal{<:Real, Exp} and Normal{Log, <:Real} in addition to the existing implementation. Granted, the same issue is a problem in what I'm proposing if you require a separate transform for each parameter and you want specific behavior for exp on σ.
I think I'm coming around to your suggestion!:) It still seems like making this compatible with the current Distribution is going to be, uhmm, slightly challenging.
> Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.
You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.
> Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making 2 versions of the same struct, one persistent and one lazy.
"to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃
> You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.
Well, in your proposal IIUC, P acts on all the arguments together, not on each one individually. So from its type alone we don't know that it applies exp to the covariance, and hence cannot do any magical specialization on P. This means we still need to rely on Exp for the dispatch-based lazy specialization of, for example, logdet.
> "to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃
True, but if we hit a wall, we can decide to temporarily branch off until the obstacle is removed. This is what we do now for MvNormal and arguably with this whole package.
> But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:)
Yes this is a nice generic way of defining lazy wrappers. Exp can be alias for Lazy{exp}.
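A minimal sketch of such a generic wrapper, assuming hypothetical names `Lazy`, `value` and `normal_logpdf`: the function lives in the type parameter, so `log` can dispatch on `Lazy{exp}` and skip the exp/log round trip, and `Exp` is just an alias.

```julia
# Hypothetical generic lazy wrapper: Lazy{f}(x) stands for f(x). The function
# is stored as a type parameter, so methods can dispatch on it.
struct Lazy{F, T}
    x::T
end
Lazy{F}(x::T) where {F, T} = Lazy{F, T}(x)

# Materialize the wrapped value; plain numbers pass through unchanged.
value(l::Lazy{F}) where {F} = F(l.x)
value(x::Real) = x

# `Exp` as an alias for `Lazy{exp}`.
const Exp{T} = Lazy{exp, T}

# Specialization: log(exp(x)) == x, so the round trip is skipped entirely.
Base.log(l::Lazy{exp}) = l.x

# A Normal log-density that accepts σ either as a plain number or lazily.
normal_logpdf(μ, σ, x) = -(((x - μ) / value(σ))^2 + log(2π)) / 2 - log(σ)
```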
I completely agree with the last comment:)
One thing though: this "wrapping"-approach means that if we want type-stability we'd have to allow all the parameters of a distribution to take on different types, e.g. Beta(a::T, b::T) can't be used since you might want to do Beta(a::Lazy{T, f1}, b::Lazy{T, f2}) where f1 and f2 are two different functions.
It still seems like the best approach, but it's worth noting that this might be a big hurdle to overcome, as we'd basically need to re-define most distributions to accommodate something like this.
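To illustrate the type-stability point, here is a sketch with hypothetical `SameTypeBeta` and `MyBeta` types (not Distributions.jl's `Beta`): giving each field its own type parameter keeps the struct concrete when one argument is a `Lazy` wrapper and the other a plain `Float64`, whereas a single shared `T` has no matching constructor.

```julia
# Minimal lazy wrapper (hypothetical), repeated so this block is self-contained.
struct Lazy{F, T}
    x::T
end
Lazy{F}(x::T) where {F, T} = Lazy{F, T}(x)
value(l::Lazy{F}) where {F} = F(l.x)
value(x::Real) = x

# One shared type parameter: both fields must have the same type.
struct SameTypeBeta{T}
    α::T
    β::T
end

# One type parameter per field: each argument keeps its own concrete type.
struct MyBeta{A, B}
    α::A
    β::B
end

# Mean of a Beta(α, β) distribution, computed from the materialized parameters.
mean_beta(d::MyBeta) = value(d.α) / (value(d.α) + value(d.β))

d = MyBeta(Lazy{exp}(0.0), 3.0)   # α = exp(0.0) = 1.0, β = 3.0
```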
And I think something like the following works okay as a "default" where we just allow the type itself to specify how to handle the unconstrained-to-constrained transformation:
import Distributions: Normal  # required to add methods to the `Normal` constructor
abstract type Constrained end
struct Unconstrained{T} <: Constrained
val::T
end
value(c::Unconstrained) = c.val
Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), exp(value(σ)))
Could also do something like Unconstrained{T, F} where F is a callable. Then we can use
value(c::Unconstrained{T, F}) where {T, F} = F(c.val)
# when `F = identity` we have a default treatment
Normal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} = Normal(value(μ), exp(value(σ)))
# in this case we have to assume that `value` takes care of the transformation
Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), value(σ))
Need to think about this further, but doesn't seem like a horrible approach.
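A runnable, stdlib-only version of the `Unconstrained{T, F}` variant, assuming a hypothetical `MyNormal` in place of `Distributions.Normal` (so no methods are added to a type owned by another package) and a hand-written `softplus`:

```julia
# Hypothetical stand-in for Distributions.Normal.
struct MyNormal{T<:Real}
    μ::T
    σ::T
end

abstract type Constrained end

# Unconstrained value `val` plus a callable F mapping it to constrained space.
# F is stored as a type parameter, so it must be a singleton callable
# (named functions and non-capturing anonymous functions qualify).
struct Unconstrained{T, F} <: Constrained
    val::T
end
Unconstrained(val::T) where {T} = Unconstrained{T, identity}(val)
Unconstrained(val::T, f) where {T} = Unconstrained{T, f}(val)

value(c::Unconstrained{T, F}) where {T, F} = F(c.val)

# F = identity on both arguments: the type decides the transformation (exp for σ).
MyNormal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} =
    MyNormal(value(μ), exp(value(σ)))

# Otherwise `value` is assumed to apply the user-supplied transformation.
MyNormal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = MyNormal(value(μ), value(σ))

softplus(x) = log1p(exp(x))   # hand-written here; StatsFuns also provides one

d1 = MyNormal(Unconstrained(0.0), Unconstrained(0.0))
d2 = MyNormal(Unconstrained(1.5), Unconstrained(log(expm1(2.0)), softplus))
```

Dispatch picks the `identity` method for `d1` because it is strictly more specific; `d2` mixes `identity` and `softplus`, so it falls through to the generic method.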