Multithreaded sampling
I have tried to implement multithreaded sampling by changing `estimate_energy_with_samples` to:
```julia
function estimate_energy_with_samples(prob, samples)
    # return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    logdensity_fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    return mean(fetch.([Threads.@spawn logdensity_fn(sample) for sample in eachsample(samples)]))
end
```
However, while this works with the `AutoForwardDiff()` AD backend, it fails silently with Zygote. I am guessing this is because Zygote is not thread-safe here?
Code:
```julia
using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using ForwardDiff
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 5)
model = normal_model(data)

##
d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)
ℓπ = DynamicPPL.LogDensityFunction(model)
# 10 Monte Carlo samples per ELBO estimate
elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy())
q, _, stats, _ = AdvancedVI.optimize(
    ℓπ,
    elbo,
    q,
    500;
    adtype = AutoZygote(),  # swap for AutoForwardDiff() to compare backends
    optimizer = optimizer,
)

##
using PyPlot
fig, ax = PyPlot.subplots()
elbo_trace = [s.elbo for s in stats]  # separate name so the `elbo` objective is not overwritten
ax.plot(elbo_trace)
fig
```
[Plots of the ELBO trace, not preserved: 1. Zygote, no threading; 2. Zygote with threading; 3. ForwardDiff with threading]
Not sure about this one. Maybe @torfjelde @willtebbutt have more insight?
Yeah I would be very surprised if Zygote.jl worked with threads like this.
You should probably look into something like Transducers.jl, or another package that defines a parallel way to perform a reduce (or just define your own `threaded_sum(f, args...)`). Once you have this, you can define a custom adjoint for it, thus hiding the threading from Zygote.jl.
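A minimal sketch of the `threaded_sum` idea in plain Base Julia. The name comes from the comment above; the single-collection signature `threaded_sum(f, xs)`, the `nchunks` keyword, and the chunking scheme are assumptions for illustration. Each task does a serial sum over a contiguous block, so there is one task per chunk rather than one per sample. This alone does not make Zygote differentiate through the tasks; the point is that a function like this would then get its own `ChainRulesCore.rrule`, so Zygote never has to trace the threading.

```julia
# Hypothetical helper: sum of f(x) over xs, computed with one task per chunk of xs.
function threaded_sum(f, xs::AbstractVector; nchunks::Int = Threads.nthreads())
    isempty(xs) && throw(ArgumentError("xs must be non-empty"))
    chunksize = cld(length(xs), nchunks)
    # One task per contiguous block of indices; each task runs a serial sum.
    tasks = [Threads.@spawn(sum(f, view(xs, idx)))
             for idx in Iterators.partition(eachindex(xs), chunksize)]
    # Combine the per-chunk partial sums on the calling thread.
    return sum(fetch, tasks)
end
```

For example, `threaded_sum(abs2, [1.0, 2.0, 3.0])` returns `14.0`. Summing fixed chunks in a fixed order keeps the result independent of thread scheduling.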
I would be surprised if something like this doesn't already exist in a package, but I'm not 100% up to date on this. Maybe @devmotion knows of one?
Thanks @torfjelde for your suggestion. I have tried implementing a custom rule like this:
```julia
using AdvancedVI, ChainRulesCore, LogDensityProblems, Statistics, Zygote

function ChainRulesCore.rrule(
    ::typeof(AdvancedVI.estimate_energy_with_samples), prob, samples
)
    fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    # Forward pass: one task per sample, each running its own Zygote.pullback.
    fn_samples = fetch.([
        Threads.@spawn Zygote.pullback(fn, sample) for
        sample in AdvancedVI.eachsample(samples)
    ])
    values = [sample[1] for sample in fn_samples]
    pullbacks = [sample[2] for sample in fn_samples]
    function estimate_energy_with_samples_pullback(ȳ)
        grads = [pullback(ȳ_i)[1] for (ȳ_i, pullback) in zip(ȳ, pullbacks)]
        ret = mean(grads)
        return (NoTangent(), NoTangent(), ret)
    end
    return mean(values), estimate_energy_with_samples_pullback
end
```
This works pretty well, but somehow the variance in the ELBO seems to be a bit lower with ForwardDiff (plot not preserved), so I am wondering if I'm doing something wrong in the custom rule. Thanks for your help!
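One concrete thing to check in the rule above: `ȳ` is the cotangent of the scalar `mean(values)`, so it is a single number, and iterating a number in Julia yields it exactly once. `zip(ȳ, pullbacks)` therefore stops after the first pair, and the returned gradient comes from the first sample only rather than from all of them, which could plausibly make the Zygote run noisier than the ForwardDiff one. The tangent for `samples` should also have the same shape as `samples`: for E = mean(f(x_i)), column i is ȳ · ∇f(x_i) / n. The sketch below checks both points in plain Julia. It assumes `samples` is a `d × n` matrix with one sample per column (as `eachsample` suggests), uses hand-written closures as stand-ins for the pullbacks that `Zygote.pullback` returns, and `energy_pullback` is a hypothetical name.

```julia
# Hypothetical corrected backward pass for E = mean(f(x_i) for columns x_i of samples).
# Each element of `pullbacks` stands in for the second output of Zygote.pullback(f, x_i):
# it maps a scalar cotangent to a 1-tuple holding the gradient with respect to x_i.
function energy_pullback(ȳ::Real, pullbacks::AbstractVector, n::Integer)
    # Every sample gets the same cotangent ȳ / n, because ∂E/∂f(x_i) = 1/n.
    cols = [pb(ȳ / n)[1] for pb in pullbacks]
    # Stack per-sample gradients into a d × n matrix, matching the shape of `samples`.
    return reduce(hcat, cols)
end

# Toy setup: f(x) = -sum(abs2, x) / 2 has gradient -x, so its pullback is c -> (c .* (-x),).
samples = [1.0 2.0 3.0; 4.0 5.0 6.0]  # d = 2, n = 3, one sample per column
pullbacks = [(c -> (c .* (-x),)) for x in eachcol(samples)]

length(collect(zip(1.0, pullbacks)))  # 1: zipping with a scalar ȳ uses only the first pullback
grad = energy_pullback(1.0, pullbacks, size(samples, 2))  # equals -samples ./ 3
```

Inside the real rule, the same idea would replace the `zip` line with one call per stored pullback, passing `ȳ / n` to each and stacking the results.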
From the snippet you shared, it doesn't look like you're using the same RNG for both runs. If so, that alone could explain the difference.
A second possibility, though IMO an unlikely one, is that the rules used by the two approaches produce slightly different numerics. But yeah, the RNG should be ruled out first.
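To see why the RNG has to be fixed before comparing backends, here is a small Base-only illustration that does not use AdvancedVI: a 10-sample Monte Carlo estimate of E[f(x)] is reproducible under the same seed and changes with the seed. The seeds, the integrand, and the helper name `mc_estimate` are arbitrary choices for illustration. For the actual comparison, the same freshly seeded RNG would be passed to both optimization runs, assuming the installed AdvancedVI version accepts an RNG as the first argument of `optimize`.

```julia
using Random, Statistics

# Hypothetical helper: Monte Carlo estimate of E[f(x)] for x ~ N(0, I_d), using n draws from rng.
mc_estimate(rng::AbstractRNG, f; d::Int = 3, n::Int = 10) = mean(f, eachcol(randn(rng, d, n)))

f(x) = -sum(abs2, x) / 2  # toy integrand; its exact expectation is -d/2

a = mc_estimate(Xoshiro(1), f)
b = mc_estimate(Xoshiro(1), f)  # same seed: identical draws, so b == a
c = mc_estimate(Xoshiro(2), f)  # different seed: a different estimate of the same quantity
```

With only 10 samples per estimate, the gap between `a` and `c` is pure Monte Carlo noise, which is the kind of run-to-run difference that can look like a difference between AD backends.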
I'll close this issue for now. Feel free to re-open this if you find that more discussion is needed.