AdvancedVI.jl
Basic rewrite of the package (2023 edition)
Hi, this is the initial pull request for the rewrite of AdvancedVI, as a successor to #25.
The following panel will be updated in real-time, reflecting the discussions happening below.
Roadmap
- [x] Change the gradient computation interface such that different algorithms can directly manipulate the gradients.
- [x] Migrate to the LogDensityProblems interface.
- :x: ~~Migrate to AbstractDifferentiation.~~ Not mature enough yet.
- [x] Use the ADTypes interface.
- [x] Use Functors.jl for flattening/unflattening variational parameters.
- [x] Add more interfaces for calling optimize. (see #32)
- [x] Add pre-packaged variational families.
  - [x] Location-scale family
  - :x: ~~Reduce memory usage of the full-rank parameterization~~ (Seems like there's an unfavorable compute-memory trade-off. See this thread.)
- [x] Migrate to Optimisers.jl.
- :x: ~~Implement minibatch subsampling (probably requires changes upstream, e.g., in DynamicPPL, too)~~ (separate issue)
- [x] Add callback option (#5)
- [x] Add control variate interface
- :x: ~~Add BBVI (score gradient)~~ not urgent
- [x] Tests
- [x] Benchmarks
  - [x] Compare performance against the current version.
  - ~~Compare against competing libraries (e.g., NumPyro, Stan, and probably a bare-bones Julia/C++ implementation).~~
- :x: ~~Support GPU computation (although Bijectors will be a bottleneck for this).~~ (separate issue)
Topics to Discuss
- :x: ~~Should we use AbstractDifferentiation?~~ Not now.
- :heavy_check_mark: Should we migrate to Optimisers? (probably yes)
- :heavy_check_mark: ~~Should we call restructure inside of optimize such that the flattening/unflattening is completely abstracted away from the user? Then, in the current state of things, Flux will have to be added as a dependency; otherwise we'll have to roll our own implementation of destructure.~~ destructure is now part of Optimisers, which is much more lightweight (see the short example after this list).
- :heavy_check_mark: ~~Should we keep TruncatedADAGrad, DecayedADAGrad? I think these are quite outdated and would advise people against using them. So how about deprecating them?~~ Planning to deprecate.
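For reference, here is a minimal example of Optimisers.destructure on a toy parameter container (the container below is purely illustrative, not an AdvancedVI type):

```julia
using Optimisers

params = (μ = zeros(2), L = [1.0 0.0; 0.0 1.0])
λ, restructure = Optimisers.destructure(params)  # λ == [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]
params′ = restructure(λ .+ 1)                    # same structure, updated values
```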
Demo
using Turing
using Bijectors
using Optimisers
using ForwardDiff
using ADTypes
using LinearAlgebra
using LogDensityProblems
import DynamicPPL
import AdvancedVI as AVI

μ_y, σ_y = 1.0, 1.0
μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]

Turing.@model function normallognormal()
    y ~ LogNormal(μ_y, σ_y)
    z ~ MvNormal(μ_z, Σ_z)
end

model = normallognormal()
b = Bijectors.bijector(model)
b⁻¹ = inverse(b)

prob = DynamicPPL.LogDensityFunction(model)
d = LogDensityProblems.dimension(prob)

μ = randn(d)
L = Diagonal(ones(d))
q = AVI.MeanFieldGaussian(μ, L)

n_max_iter = 10^4
q, stats = AVI.optimize(
    AVI.ADVI(prob, b⁻¹, 10),
    q,
    n_max_iter;
    adbackend = AutoForwardDiff(),
    optimizer = Optimisers.Adam(1e-3)
)
I'll have a look at the PR itself later, but for now:
AdvancedVI.jl naively reconstructs/deconstructs MvNormal from its variational parameters. This is okay for the mean-field parameterization, but for a full-rank or non-diagonal covariance parameterization it is a little more complicated, since MvNormal in Distributions.jl asks for a PDMat. So the variational parameters must first be converted to a matrix, then to a PDMat, and then fed to MvNormal. For high-dimensional problems, I'm not sure this is ideal.
In relation to the topic above, I'm starting to believe that implementing our own custom distribution (just as @theogf previously did in #25) might be a good idea in terms of performance, especially for reparameterization-trick-based methods. However, instead of reinventing the wheel (by implementing every distribution in existence) or tying ourselves to a small number of specific distributions (a custom MvNormal, that is), I think implementing a single general LocationScale distribution would be feasible, where the user provides the underlying univariate base distribution. With a single general object, we could then support distributions like the multivariate Laplace that are not even supported in Distributions.jl.
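To make this concrete, here is a rough sketch (none of these names exist in the package; it is only meant to illustrate the idea) of such a general location-scale family, sampled via the reparameterization z = μ + L·ε with ε drawn i.i.d. from a user-supplied univariate base distribution:

```julia
using Distributions, LinearAlgebra, Random

# `base` supplies the standardized noise: Normal() recovers a Gaussian family,
# Laplace() yields a multivariate Laplace-like family, and so on.
struct LocationScale{L<:AbstractVector,S<:AbstractMatrix,B<:ContinuousUnivariateDistribution} <: ContinuousMultivariateDistribution
    location::L
    scale::S    # e.g. Diagonal for mean-field, LowerTriangular for full-rank
    base::B
end

Base.length(q::LocationScale) = length(q.location)

# Reparameterized sampling: z = location + scale * ε, ε ~ base (i.i.d.).
function Distributions._rand!(rng::AbstractRNG, q::LocationScale, x::AbstractVector{<:Real})
    ε = rand(rng, q.base, length(q))
    x .= q.location .+ q.scale * ε
    return x
end

# (logpdf, entropy, etc. would be defined analogously in terms of `base` and `scale`.)
```

A full-rank Gaussian would then just be `LocationScale(μ, LowerTriangular(L), Normal())`, and the multivariate Laplace-like family `LocationScale(μ, LowerTriangular(L), Laplace())`.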
Maybe we should make this into a discussion. I feel like there are several different approaches we can take here.
For flattening the parameters, @theogf has proposed ParameterHandling.jl, but it currently does not work well with AD. The current alternative is ModelWrappers.jl, but it comes with many dependencies, which is potentially a governance concern.
For this one in particular, we have an implementation in DynamicPPL that could potentially be moved to its own package if we really want to: https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/utils.jl#L434-L473
But this has a couple of issues:
1. Requires 2n memory, since we can't release the original object (we need it as the first argument for construction, since these things often depend on runtime information, e.g. the dimensionality of an MvNormal).
2. Can't specialize on which parameters we actually want, e.g. maybe we only want to learn the mean parameter of an MvNormal.
(1) can be addressed by instead taking a closure approach a la Functors.jl:
function flatten(d::MvNormal{<:AbstractVector,<:Diagonal})
    dim = length(d)
    function MvNormal_unflatten(x)
        return MvNormal(x[1:dim], Diagonal(x[dim+1:end]))
    end
    return vcat(d.μ, diag(d.Σ)), MvNormal_unflatten
end
For (2), we have a couple of immediate options: a) Define "wrapper" distributions. b) Take a contextual dispatch approach.
For (a) we'd have something like:
abstract type WrapperDistribution{V,F,D<:Distribution{V,F}} <: Distribution{V,F} end

# HACK: Probably shouldn't do this.
inner_dist(x::WrapperDistribution) = x.inner

# TODO: Specialize further on `x` to avoid hitting default implementations?
Distributions.logpdf(d::WrapperDistribution, x) = logpdf(inner_dist(d), x)
# Etc.

struct MeanParameterized{V,F,D<:Distribution{V,F}} <: WrapperDistribution{V,F,D}
    inner::D
end
MeanParameterized(inner::Distribution{V,F}) where {V,F} = MeanParameterized{V,F,typeof(inner)}(inner)

function flatten(d::MeanParameterized{<:Any,<:Any,<:MvNormal})
    μ = mean(d.inner)
    function MeanParameterized_MvNormal_unflatten(x)
        return MeanParameterized(MvNormal(x, d.inner.Σ))
    end
    return μ, MeanParameterized_MvNormal_unflatten
end
Pros:
- It's fairly simple to implement.

Cons:
- Requires wrapping all the distributions all the time.
- Nice until we have other sorts of nested distributions, in which case this can get real ugly real fast.
For (b) we'd have something like:

struct MeanOnly end

function flatten(::MeanOnly, d::MvNormal)
    μ = mean(d)
    function MvNormal_meanonly_unflatten(x)
        return MvNormal(x, d.Σ)
    end
    return μ, MvNormal_meanonly_unflatten
end
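A quick usage sketch of the contextual-dispatch `flatten` above (assuming the definitions from this snippet are in scope):

```julia
using Distributions, LinearAlgebra

d = MvNormal(zeros(2), Diagonal(ones(2)))
θ, unflatten = flatten(MeanOnly(), d)  # θ == mean(d) == [0.0, 0.0]
d′ = unflatten([1.0, 2.0])             # new mean, covariance carried over from `d`
```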
Pros:
- Cleaner, as it avoids nesting.
- Can easily support "wrapper" distributions, since it can just pass the context downwards.

Cons:
- Somewhat unclear to me how to make all this composable, e.g. how do we handle arbitrary structs containing distributions?
Hi @torfjelde
Maybe we should make this into a discussion. I feel like there are several different approaches we can take here.
Should we proceed here or create a separate issue?
Whatever approach we take, I think the key would be to avoid inverting or even computing the covariance matrix, provided that we operate with a Cholesky factor. None of the steps of ADVI require any of these, except for the STL estimator, where we do need to invert the Cholesky factor.
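To illustrate (a sketch, not the package code): with a location-scale family parameterized by a triangular scale factor L, the closed-form entropy term only needs the log-diagonal of L, and the score term needed by STL only needs triangular solves with L, never Σ or Σ⁻¹:

```julia
using LinearAlgebra

# Entropy of N(μ, L*L'): d/2 * log(2πe) + log|det L|; only the diagonal of L is used.
gaussian_entropy(L::LowerTriangular) = size(L, 1) * (log(2π) + 1) / 2 + sum(log, diag(L))

# Score ∇_z log q(z) for q = N(μ, L*L'): -Σ⁻¹(z - μ) = -(L' \ (L \ (z - μ))),
# i.e. two triangular solves instead of forming or inverting Σ.
gaussian_score(L::LowerTriangular, μ, z) = -(L' \ (L \ (z - μ)))
```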
Created a discussion: https://github.com/TuringLang/AdvancedVI.jl/discussions/46
@torfjelde Hi, I have significantly changed the sketch for the project structure.
- As you previously suggested, the ELBO objective is now formed in a modular way.
- I've also migrated to AbstractDifferentiation instead of rolling our own custom differentiation glue functions.

Any comments on the new structure? Also, do you approve the use of AbstractDifferentiation?
Also, do you approve the use of AbstractDifferentiation?
@devmotion what are your current thoughts on AbstractDifferentiation?
I've now added the pre-packaged location-scale family. Overall, to the user, the basic interface looks like the following:
# Requires the packages used in the earlier demo, plus Flux (for destructure and ADAM)
# and LinearAlgebra (for diag and norm).
μ_y, σ_y = 1.0, 1.0
μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]

Turing.@model function normallognormal()
    y ~ LogNormal(μ_y, σ_y)
    z ~ MvNormal(μ_z, Σ_z)
end

model = normallognormal()
b = Bijectors.bijector(model)
b⁻¹ = inverse(b)

prob = DynamicPPL.LogDensityFunction(model)
d = LogDensityProblems.dimension(prob)

μ = randn(d)
L = Diagonal(ones(d))
q = AVI.MeanFieldGaussian(μ, L)

λ₀, restructure = Flux.destructure(q)
function rebuild(λ′)
    restructure(λ′)
end

λ = AVI.optimize(
    AVI.ADVI(prob, b⁻¹, 10),
    rebuild,
    10000,
    λ₀;
    optimizer = Flux.ADAM(1e-3),
    adbackend = AutoForwardDiff()
)
q = restructure(λ)

μ = q.transform.outer.a
L = q.transform.inner.a
Σ = L*L'

μ_true = vcat(μ_y, μ_z)
Σ_diag_true = vcat(σ_y, diag(Σ_z))

@info("VI Estimation Error",
    norm(μ - μ_true),
    norm(diag(Σ) - Σ_diag_true),)
Some additional notes to the comments above:
- Should we call restructure inside of optimize such that the flattening/unflattening is completely abstracted away from the user? Then, in the current state of things, Flux will have to be added as a dependency; otherwise we'll have to roll our own implementation of destructure.
- Should we keep TruncatedADAGrad, DecayedADAGrad? I think these are quite outdated and would advise people against using them. So how about deprecating them?
- We should probably migrate to Optimisers.jl. The current optimization infrastructure is quite old.
@devmotion what are your current thoughts on AbstractDifferentiation?
It's the package everyone seems to want to use at some point in the future, but basically nobody uses it right now because it has severe limitations and performance problems compared with using the AD systems more directly. These issues are explained and discussed in a few of the open bugs in the repo, but the main problem is that nobody (who knows how to fix them - which I assume is quite a few people, so that's not the limiting factor) seems to have the time and money to work on them.
But yeah, these issues are the main reason why we don't use it in SciML (yet) and why e.g. LogDensityProblems does not use it.
(Generally, I also think it would be good to aim for consistency in the Turing ecosystem and move to a more unified AD experience. I think I've mentioned before that it would be good to extract the AD types to a common package and reuse them - but maybe we could also just re-use https://github.com/SciML/ADTypes.jl. This would complement our LogDensityProblems interfaces, since (I think I mentioned this in the LogDensityProblemsAD repo at some point) the Val + keyword args interface is flexible enough but does not provide a way to store and pass around these AD options in a compact way. Maybe the more direct approach would be for LogDensityProblems to adopt ADTypes - as long as AbstractDifferentiation is not ready, which is the ultimate goal currently.)
@devmotion That's unfortunate... One thing I noted for AdvancedVI.jl is that we can't reuse all of the nice machinery behind LogDensityProblemsAD, and rolling our own AD interface definitely does not sound ideal. On the other hand, it seems to me that LogDensityProblemsAD should work pretty well for the rest of the Turing ecosystem. So I think it's really AdvancedVI.jl that needs its own solution.
So I think it's really AdvancedVI.jl that needs its own solution.
Could we at least share a common user interface for specifying the AD backend (i.e., ADTypes or something like it more aligned to our needs if necessary)?
ADTypes looks quite sensible for gradually unifying AD interfaces in Turing. It is very lightweight (no dependencies) and has a clear scope. We can help improve it if certain features are missing.
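For example, dispatching on the ADTypes backend types could look like the following (the `value_and_gradient!` helper name here is hypothetical, not an existing AdvancedVI or ADTypes function):

```julia
using ADTypes, ForwardDiff, DiffResults

# Hypothetical glue function; each AD extension would add a method for its backend type.
function value_and_gradient!(::ADTypes.AutoForwardDiff, f, λ::AbstractVector, out::DiffResults.MutableDiffResult)
    ForwardDiff.gradient!(out, f, λ)  # fills both the value and the gradient in `out`
    return out
end

λ   = randn(5)
out = DiffResults.GradientResult(λ)
value_and_gradient!(AutoForwardDiff(), x -> sum(abs2, x), λ, out)
DiffResults.value(out), DiffResults.gradient(out)
```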
@zuhengxu feel free to add your comments and suggestions here, since the NormalisingFlows package will likely benefit from this rewrite.
Thanks for looping me in! I'm putting my comments here (my understanding of the aforementioned problems is superficial, so let me know if I said anything incorrect).
It's funny that I had just taken the identical approach of using AbstractDifferentiation.jl for the ease of switching between different AD systems. But as explained by @devmotion, AbstractDifferentiation.jl may lead to some performance deficiencies. I haven't dug into those "severe limitations and performance issues", but one thing I did realize is that we are missing many nice configuration options for some of the AD backends (e.g., compiled tapes for ReverseDiff.jl). Is this one of the main issues with AbstractDifferentiation.jl?
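For context, this is the kind of backend-specific configuration that is hard to expose through a generic abstraction (a standalone ReverseDiff example, independent of the AdvancedVI API):

```julia
using ReverseDiff

f(x) = sum(abs2, x)
x = randn(10)

tape  = ReverseDiff.GradientTape(f, x)
ctape = ReverseDiff.compile(tape)    # compile the tape once ...
g = similar(x)
ReverseDiff.gradient!(g, ctape, x)   # ... and reuse it for every subsequent gradient
```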
My second concern is about Flux.destructure --- the current parameter-flattening method in NormalizingFlows.jl.
As I mentioned in https://github.com/TuringLang/NormalizingFlows.jl/pull/1/commits/5a7df3dd5461979f9643444cb4fa609068df2406#r1225209163, this is ridiculously slow when destructuring a long flow; the same problem applies to the flatten(x) function suggested by @torfjelde. For reference, the execution of the following code took over an hour and I just killed the program for lack of patience:
using Bijectors, Flux, Random, Distributions, LinearAlgebra

function create_planar_flow(n_layers::Int, q₀)
    d = length(q₀)
    Ts = ∘([PlanarLayer(d) for _ in 1:n_layers]...)
    return transformed(q₀, Ts)
end

flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) # create 20 layers planar flow
theta, re = Flux.destructure(flow) # this is sooooooo slow
I wonder if we have better ways of handling the parameter flattening/unflattening, with better performance and without losing support for those AD systems (at least ForwardDiff, ReverseDiff, and Zygote).
I'm happy to go with ADTypes.jl, but then we need to propagate this through the rest of the Turing.jl ecosystem.
And I agree that it would be nice if LogDensityProblemsAD.jl also started using the same, but I don't know whether that will get pushback or not.
For reference, the excution of the following code took over an hour and I just kill the program for the lack of patience
This doesn't have anything to do with the usage of destructure; it's because you're doing ∘(...).
This then causes a ComposedFunction which is nested n_layers deep (ComposedFunction(ComposedFunction(ComposedFunction(...)), ...); you get the gist). This blows up the compilation times.
For "large" compositions, you should instead use something like https://github.com/oschulz/FunctionChains.jl. Then you could use a Vector to represent the composition, which would make the compilation time waaaaay better (because we now "iterate" over the recursion at runtime rather than at compile time).
I also wonder if we could deal with this by subtyping Function, due to https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues :thinking:
But we can discuss this issue more somewhere else :+1:
What's the reasoning behind moving away from the current design here?
In particular I mean, why don't we keep the algorithms and stuff like

struct ELBO <: VariationalObjective end

function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...)
    return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...)
end

? Maybe you've already spoken about this somewhere; if so, you can just point me there :)
But IMO the current design is still the right direction to go; it just needs more work.
Hi @torfjelde, I think I brought it up in the first proposal. The gist is that VI is basically a gradient estimator/objective + an SGD algorithm. So the additional distinction for an algorithm seems redundant to me. (This is also true for the current codebase, since it doesn't do anything.) I think most of the things that I can currently imagine can be handled within the AbstractVariationalObjective abstraction. But we could re-introduce it in the future if the extra modularity turns out to be necessary.
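For concreteness, the mental model is roughly the following sketch, where Optimisers.jl supplies the SGD part and the objective supplies the gradient estimate (`init` and `estimate_gradient` are hypothetical names here, not the actual API):

```julia
using Optimisers, Random

function fit(rng::AbstractRNG, objective, λ₀::AbstractVector, restructure;
             n_iter = 10_000, optimizer = Optimisers.Adam(1e-3))
    λ         = copy(λ₀)
    opt_state = Optimisers.setup(optimizer, λ)
    obj_state = init(objective)                      # hypothetical: objective-specific state
    for _ in 1:n_iter
        # hypothetical: returns the gradient estimate and the updated objective state
        g, obj_state = estimate_gradient(rng, objective, obj_state, λ, restructure)
        opt_state, λ = Optimisers.update!(opt_state, λ, g)
    end
    return λ
end
```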
Hi @torfjelde, I've added the interface for stateful objectives. Like Optimisers.jl, the state is supposed to be passed around as a tuple. For ADVI, I think only the control variates will need state, so that's what's currently done.
Please let me know what you think about this. I think we could settle this before moving further. As for me, I'm unsure whether we'll need the algorithm distinction. Do you have any specific use in mind?
Once we get this done, I'll start adding docstrings, tests, and then I think this PR could be changed from draft to a proper PR since it will subsume the current codebase.
Hi @torfjelde I think I brought it up in the first proposal.
The gist is that VI is basically a gradient estimator/objective + SGD algorithm.
But what about, say, changing certain parameters during the optimization process? E.g. what if I want to change the number of samples used for the ELBO estimate based on the number of steps we've taken or the gradient magnitude? Do I put this in the /objective/? That seems a bit weird, no?
Wouldn't it make sense to put this in some "configuration" object? That is effectively what the algorithm is.
So the additional distinction for an algorithm seems redundant to me.
You've added different estimators to the "objective" ADVI + control variates. But this is generally not what ADVI means, right?
If I were to use ADVI, I'd expect it to be a simple transformation from constrained to unconstrained + MvNormal, making use of the closed-form entropy of the MvNormal. That's it. I agree you can then start to make changes to the objective, e.g. using a StraightThroughEstimator, etc., but I'd be opposed to referring to this as ADVI.
Of course you could remove this from the ADVI objective you have, and then create a separate one which does basically the same as ADVI but with the option of a straight-through estimator. But because we don't have a notion of "algorithm"/configuration, we now have to once more implement the ELBO, but in a slightly different way.
By introducing a form of "config", e.g. AbstractAlgorithm, you can define further behaviors that extend beyond the objective, etc. You're right that in the end, it really just comes down to estimation of the objective, but ideally we'd maximize the code-sharing between the different approaches.
Honestly, might be worth having a brief call about this or something? I'm worried that coupling the configuration and the objective is unnecessarily restrictive going forward.
I think most of the things are there except for convergence diagnostics/early termination stuff, which I think isn't very urgent.
I think most of the things are there except for convergence diagnostics/early termination stuff, which I think isn't very urgent.
Thanks, @Red-Portal -- leaving convergence diagnostics and other new functionality as future work sounds good. Can you fix the failing CI tests?
Hi @yebai, one of the failing tests is due to a critical bug on ReverseDiff's side. I could nudge the test so that it doesn't fail anymore.
That looks like a bad bug since it involves common linear algebra operations; it's ok to mark the ReverseDiff test as broken for now. There is also a ForwardDiff error. Can you fix that?
@yebai Those are coming from the fp32 tests against a full-rank Gaussian posterior. I think those are not critical; Maybe I'll just remove fp32 tests.
@yebai Those are coming from the fp32 tests against a full-rank Gaussian posterior. I think those are not critical; Maybe I'll just remove fp32 tests.
Often the code runs successfully on x64 platforms but fails on x86. This is caused by inconsistent element types in the code. It would be good to identify and fix them.
@yebai I think the x86 error is a ReTest bug. I found that the failing function takes an Int64, but the caller shoves in an integer literal, which is an Int32 on x86. I'll file an issue. (update: filed here)
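For illustration, this is the kind of mismatch involved (plain Julia; on a 32-bit build, integer literals default to Int32):

```julia
f(n::Int64) = n
f(1)  # fine on 64-bit Julia; on x86 (32-bit), 1 isa Int32, so this throws a MethodError
```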
Codecov Report
Patch coverage: 80.83% and project coverage change: +19.72% :tada:
Comparison is base (f84a306) 61.11% compared to head (de4284e) 80.83%. Report is 1 commit behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #45 +/- ##
===========================================
+ Coverage 61.11% 80.83% +19.72%
===========================================
Files 9 9
Lines 144 167 +23
===========================================
+ Hits 88 135 +47
+ Misses 56 32 -24
| Files Changed | Coverage Δ | |
|---|---|---|
| ext/AdvancedVIEnzymeExt.jl | 0.00% <0.00%> (ø) | |
| src/AdvancedVI.jl | 10.00% <10.00%> (-62.00%) | :arrow_down: |
| src/objectives/elbo/entropy.jl | 70.00% <70.00%> (ø) | |
| src/objectives/elbo/advi.jl | 80.00% <80.00%> (ø) | |
| ext/AdvancedVIForwardDiffExt.jl | 87.50% <87.50%> (ø) | |
| src/distributions/location_scale.jl | 89.85% <89.85%> (ø) | |
| ext/AdvancedVIReverseDiffExt.jl | 100.00% <100.00%> (ø) | |
| ext/AdvancedVIZygoteExt.jl | 100.00% <100.00%> (ø) | |
| src/optimize.jl | 100.00% <100.00%> (ø) | |
:umbrella: View full report in Codecov by Sentry.
I'll also remove the tests for FullMonteCarlo since it's useless.
Pull Request Test Coverage Report for Build 5946033255
- 133 of 165 (80.61%) changed or added relevant lines in 9 files are covered.
- No unchanged relevant lines lost coverage.
- Overall coverage increased (+19.5%) to 80.606%
| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| ext/AdvancedVIForwardDiffExt.jl | 7 | 8 | 87.5% |
| src/objectives/elbo/entropy.jl | 7 | 10 | 70.0% |
| src/objectives/elbo/advi.jl | 23 | 28 | 82.14% |
| ext/AdvancedVIEnzymeExt.jl | 0 | 7 | 0.0% |
| src/distributions/location_scale.jl | 55 | 62 | 88.71% |
| src/AdvancedVI.jl | 3 | 12 | 25.0% |
| Total: | 133 | 165 | 80.61% |

| Totals | |
|---|---|
| Change from base Build 5324928173: | 19.5% |
| Covered Lines: | 133 |
| Relevant Lines: | 165 |
💛 - Coveralls
Currently working on enabling the inference tests for Zygote, which is currently broken for Stacked bijectors.
Would it be possible to split this PR into a sequence of PRs @Red-Portal ? The current version is somewhat difficult to review properly due to the number of changes made :confused: