AdvancedVI.jl

Multithreaded sampling

arnauqb opened this issue 1 year ago • 4 comments

I tried to implement multithreaded sampling by changing AdvancedVI's estimate_energy_with_samples to:

function estimate_energy_with_samples(prob, samples)
    # Original (serial) implementation:
    #return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    logdensity_fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    # Evaluate the log density of each sample on its own task, then average the results.
    return mean(fetch.([Threads.@spawn logdensity_fn(sample) for sample in eachsample(samples)]))
end

However, while this works with the AutoForwardDiff() AD backend, it fails silently with Zygote. I'm guessing this is because Zygote is not thread-safe here?

Code:

using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using ForwardDiff
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 5)
model = normal_model(data)

##

d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)

ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy())

q, _, stats, _ = AdvancedVI.optimize(
	ℓπ,
	elbo,
	q,
	500;
	adtype = AutoZygote(),
	optimizer = optimizer,
)

##
using PyPlot
fig, ax = PyPlot.subplots()
elbo_trace = [s.elbo for s in stats]  # ELBO estimate at each iteration (avoids shadowing the `elbo` objective)
ax.plot(elbo_trace)
fig

1. Zygote, no threading

[Figure: ELBO per iteration]

2. Zygote, with threading

[Figure: ELBO per iteration]

3. ForwardDiff, with threading

[Figure: ELBO per iteration]
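The spawn-and-fetch change at the top of this post can be checked on its own, outside AdvancedVI and without any AD. This is a minimal sketch under stated assumptions: toy_logdensity, serial_energy and threaded_energy are hypothetical helpers (not AdvancedVI API), and samples are assumed to be stored as the columns of a matrix. It only exercises the forward pass, which gives the same result with or without threads; the failure described above appears once Zygote has to differentiate through the Threads.@spawn calls.

```julia
using Statistics

# Hypothetical toy target: standard normal log density, up to an additive constant.
toy_logdensity(x) = -sum(abs2, x) / 2

# Serial reference: average the log density over the samples (matrix columns).
serial_energy(f, samples::AbstractMatrix) = mean(f, eachcol(samples))

# Threaded version of the same average: one task per sample.
# fetch returns the results in column order, so the two versions
# can differ only by floating-point rounding in the summation.
function threaded_energy(f, samples::AbstractMatrix)
    tasks = [Threads.@spawn f(s) for s in eachcol(samples)]
    return mean(fetch.(tasks))
end

samples = randn(3, 10)
println(serial_energy(toy_logdensity, samples) ≈ threaded_energy(toy_logdensity, samples))
```

Started with julia --threads=4 the tasks run in parallel; with a single thread they run one after another and the result is the same.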

arnauqb, Oct 12 '24

Not sure about this one. Maybe @torfjelde @willtebbutt have more insight?

Red-Portal, Oct 12 '24

Yeah I would be very surprised if Zygote.jl worked with threads like this.

You should probably look into something like Transducers.jl, or something else that provides a parallel reduce (or just define your own threaded_sum(f, args...)). Once you have this, you can define a custom adjoint for it, thus hiding the threading from Zygote.jl.

I would be surprised if something like this doesn't already exist in a package, but I'm not 100% up to date on this. Maybe @devmotion knows of one?
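A minimal sketch of the threaded_sum idea, using only Base Julia. threaded_sum splits the inputs into contiguous chunks and sums each chunk on its own task; threaded_sum_with_pullback returns the value together with a pullback closure, the same (value, pullback) pair a ChainRulesCore.rrule returns. All names here are hypothetical, and the per-element gradient ∇f is passed in by hand; in an actual rule it would come from an AD call such as Zygote.pullback, and the pair would be returned from an rrule so that Zygote treats the whole threaded computation as one opaque call.

```julia
# Sum f(x) over xs, one task per contiguous chunk of xs (xs assumed non-empty).
function threaded_sum(f, xs::AbstractVector; nchunks::Int = Threads.nthreads())
    chunk_len = max(1, cld(length(xs), nchunks))
    chunks = Iterators.partition(eachindex(xs), chunk_len)
    tasks = [Threads.@spawn sum(f, view(xs, idx)) for idx in chunks]
    return sum(fetch.(tasks))
end

# Value and pullback of threaded_sum(f, xs), given the per-element gradient ∇f.
# d/dx_i of sum_j f(x_j) is ∇f(x_i), so the cotangent for x_i is ȳ .* ∇f(x_i).
function threaded_sum_with_pullback(f, ∇f, xs::AbstractVector)
    y = threaded_sum(f, xs)
    function threaded_sum_pullback(ȳ)
        # The per-element gradients are independent, so they can run on tasks too.
        tasks = [Threads.@spawn ȳ .* ∇f(x) for x in xs]
        return fetch.(tasks)
    end
    return y, threaded_sum_pullback
end

xs = collect(1.0:10.0)
y, pb = threaded_sum_with_pullback(abs2, x -> 2x, xs)
println(y)        # 385.0
println(pb(1.0))  # elementwise equal to 2 .* xs
```

Chunking keeps the number of tasks close to the number of threads, which matters when each f call is cheap; for expensive log densities, one task per element (as in the original change) is also reasonable.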

torfjelde, Oct 15 '24

Thanks @torfjelde for your suggestion. I tried implementing a custom rule like this:

using AdvancedVI, ChainRulesCore, LogDensityProblems, Statistics, Zygote

function ChainRulesCore.rrule(
    ::typeof(AdvancedVI.estimate_energy_with_samples), prob, samples
)
    fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    fn_samples =
        fetch.([
            Threads.@spawn Zygote.pullback(fn, sample) for
            sample in AdvancedVI.eachsample(samples)
        ])
    values = [sample[1] for sample in fn_samples]
    pullbacks = [sample[2] for sample in fn_samples]
    function estimate_energy_with_samples_pullback(ȳ)
        grads = [pullback(ȳ_i)[1] for (ȳ_i, pullback) in zip(ȳ, pullbacks)]
        ret = mean(grads)
        return (NoTangent(), NoTangent(), ret)
    end
    return mean(values), estimate_energy_with_samples_pullback
end

This works pretty well, but somehow the variance in the ELBO seems to be a bit lower with ForwardDiff:

[Figure: ELBO per iteration]

so I am wondering if I'm doing something wrong in the custom rule. Thanks for your help!
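One detail worth checking in the rule above: ȳ is the cotangent of a scalar (mean(values)), and a Julia Number iterates as a single element, so zip(ȳ, pullbacks) stops after the first pair and only the first sample's gradient ends up in grads. The cotangent for samples should also have the same shape as samples, with column i equal to ȳ ∇ℓ(s_i) / n, since each sample enters the mean with weight 1/n. The sketch below builds that matrix using only Base; toy_pullbacks is a hypothetical stand-in for the Zygote.pullback results (each returns a 1-tuple, as a Zygote pullback of a one-argument function does), and it assumes samples is a matrix with one sample per column. In the rule, the third element of the returned tuple would then be this matrix instead of mean(grads). Whether this accounts for the variance difference in the plot is not something the sketch can show.

```julia
# Cotangent of energy(samples) = mean of ℓ(s) over the columns s of samples.
# Column s_i enters the mean with weight 1/n, so its cotangent is ȳ * ∇ℓ(s_i) / n.
function energy_pullback(ȳ, pullbacks, samples::AbstractMatrix)
    n = length(pullbacks)
    @assert n == size(samples, 2) "expected one pullback per sample (column)"
    # Each pullback returns a 1-tuple, like a Zygote pullback of a 1-argument function.
    cols = [first(pb(ȳ)) ./ n for pb in pullbacks]
    return reduce(hcat, cols)  # d × n, the same shape as samples
end

# Hypothetical stand-ins for Zygote.pullback(ℓ, s) with ℓ(s) = -sum(abs2, s) / 2,
# whose gradient is -s.
toy_pullbacks(samples) = [ȳ -> (ȳ .* (-s),) for s in eachcol(samples)]

samples = [1.0 2.0; 3.0 4.0]
pbs = toy_pullbacks(samples)

# Pitfall: a scalar ȳ iterates exactly once, so zip pairs it with the first pullback only.
println(length(collect(zip(1.0, pbs))))  # 1

# Elementwise equal to -samples ./ 2 (gradient -s, averaged over n = 2 samples).
println(energy_pullback(1.0, pbs, samples))
```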

arnauqb, Oct 21 '24

From the snippet you shared, it doesn't look like you're using the same RNG for both runs? If so, that alone could explain the difference.

A second thought, though IMO an unlikely one, is that the rules used by the two approaches might differ numerically. But yeah, the RNG should be ruled out first.
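A minimal sketch of ruling out the RNG with common random numbers, using only the Random and Statistics standard libraries. energy_estimate and toy_logdensity are hypothetical, and the reparameterized draw μ .+ σ .* z only stands in for sampling from the mean-field Gaussian. Two RNGs seeded identically give the serial and threaded estimates the same draws, so they agree up to rounding; different seeds give estimates that differ by Monte Carlo noise alone. For the actual comparison, AdvancedVI.optimize appears to accept an RNG as its first positional argument, so passing a freshly seeded RNG to each run would play the same role (an assumption worth checking against the installed version).

```julia
using Random, Statistics

# Hypothetical toy target: standard normal log density, up to an additive constant.
toy_logdensity(x) = -sum(abs2, x) / 2

# Monte Carlo energy estimate with reparameterized draws s = μ + σ .* z.
# All noise is drawn up front from rng, before any tasks are spawned,
# so the draws do not depend on how the tasks are scheduled.
function energy_estimate(rng::AbstractRNG, μ, σ, n; threaded::Bool = false)
    z = randn(rng, length(μ), n)
    samples = μ .+ σ .* z
    if threaded
        tasks = [Threads.@spawn toy_logdensity(s) for s in eachcol(samples)]
        return mean(fetch.(tasks))
    else
        return mean(toy_logdensity, eachcol(samples))
    end
end

μ, σ = zeros(3), ones(3)
a = energy_estimate(Xoshiro(1), μ, σ, 10)                    # serial, seed 1
b = energy_estimate(Xoshiro(1), μ, σ, 10; threaded = true)   # threaded, seed 1
c = energy_estimate(Xoshiro(2), μ, σ, 10; threaded = true)   # threaded, seed 2
println((a ≈ b, a == c))
```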

torfjelde, Oct 22 '24

I'll close this issue for now. Feel free to re-open this if you find that more discussion is needed.

Red-Portal, Nov 22 '25