Add Stochastic Gradient HMC
Part of #60
a bump on this
AdvancedHMC.jl documentation for PR #428 is available at: https://TuringLang.github.io/AdvancedHMC.jl/previews/PR428/
The basic SGHMC algorithm is intuitive, but in the AbstractMCMC tests there are Inf values in the final result after it is transformed back to the original space. Gonna need some advice on what's going wrong here @sunxd3
Let me take a look. I haven't dealt with this part of the code before, so give me a bit of time.
The following version passes the test
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    if haskey(kwargs, :nadapts)
        throw(
            ArgumentError(
                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
            ),
        )
    end

    i = state.i + 1
    t_old = state.transition
    adaptor = state.adaptor
    κ = state.κ
    metric = state.metric

    # Reconstruct hamiltonian.
    h = Hamiltonian(metric, model)

    # Compute gradient of log density.
    logdensity_and_gradient = Base.Fix1(
        LogDensityProblems.logdensity_and_gradient, model.logdensity
    )
    θ = copy(t_old.z.θ)
    grad = last(logdensity_and_gradient(θ))

    # Update latent variables and velocity according to
    # equation (15) of Chen et al. (2014)
    v = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
    θ .+= newv

    # Make new transition.
    z = phasepoint(h, θ, v)
    t = transition(rng, h, κ, z)

    # Adapt h and spl.
    tstat = stat(t)
    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
    tstat = merge(tstat, (is_adapt=isadapted,))

    # Compute next sample and state.
    sample = Transition(t.z, tstat)
    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

    return sample, newstate
end
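For reference, the velocity and position update in the step above can be written out as follows. This is read directly off the code, not off the paper: $\theta$ is the position, $v$ the velocity from the previous state, $\eta$ is `learning_rate`, $\alpha$ is `momentum_decay`, $\log p$ is the model's log density, and $\epsilon$ is a fresh standard normal draw.

$$
\begin{aligned}
v' &= (1 - \alpha)\, v + \eta\, \nabla \log p(\theta) + \sqrt{2 \eta \alpha}\; \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \\
\theta' &= \theta + v'.
\end{aligned}
$$

Note that the gradient is evaluated at $\theta$ before the position is moved, and the position is moved with the new velocity $v'$.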
I made three updates:
- copy `θ`
- update `θ` using `newv`
- create `t` using the updated `θ`
I was a bit uneasy about the in-place update, thinking there might be some mismatch between the parameters and logp. So I made the above changes, and I think they at least make the test case pass.
I know the logic is copied from Turing.jl, and the order of updating v and θ was fine there. But I had to switch the order to make it work here. Ideally I would get to the bottom of it, but I thought you might know better.
@sunxd3 Thanks a lot! I see what was going wrong: the parameters had not been updated to the right state, and changing the parameter theta in place is not suitable here. The issue has been fixed now!
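To make the in-place concern concrete, here is a minimal sketch of how mutating a shared array can leave a cached log density out of sync with its parameters. `ToyPoint` and `toy_logp` are hypothetical stand-ins invented for this illustration; they are not AdvancedHMC types, and this is not claimed to be the exact failure path in the PR.

```julia
# Hypothetical stand-in for a cached phase point: a position plus its log density.
struct ToyPoint
    θ::Vector{Float64}
    logp::Float64
end

# Toy log density (standard normal up to a constant), assumed for illustration.
toy_logp(θ) = -sum(abs2, θ) / 2

θ0 = [1.0, 2.0]
old = ToyPoint(θ0, toy_logp(θ0))
newv = [0.5, 0.5]

# Safe: work on a copy, so the cached point stays consistent with its logp.
θ = copy(old.θ)
θ .+= newv
copy_keeps_cache_valid = old.logp == toy_logp(old.θ)

# Unsafe: `alias` is the same array as `old.θ`, so this mutates the cached point
# while its stored logp still refers to the old position.
alias = old.θ
alias .+= newv
alias_breaks_cache = old.logp != toy_logp(old.θ)
```

The struct is immutable, but its vector field is not, which is why the `copy` in the step function matters.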
If the CI checks are all green, this PR should be ready now.
Looks fine to me. Now that the tests pass, I think the algorithm is very likely to be correctly implemented.
Still curious why this works (I am referring to the order of updating theta first).
The order of updating theta differs between AHMC and Turing, yet both seem to work fine? The Turing one seems like it should be the correct one, though: in each step, theta is updated using only the velocity from the previous step.
That is what confuses me, because in this PR, the order can't be reversed. (I think it will fail if we use the old velocity.)
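One way to poke at the ordering question is a toy comparison of the two variants on a 1D standard normal target. This is only a sketch under stated assumptions: `grad_logp`, `toy_step`, and `run_toy` are hypothetical helpers written for this comment, the gradient is evaluated before the position moves in both variants, and the run below switches friction and noise off (`α = 0`), which is not the setting used in the PR's tests.

```julia
using Random

# Toy target: standard normal, log p(θ) = -θ^2 / 2, so the gradient is -θ.
grad_logp(θ) = -θ

# One SGHMC-style update; `use_new_velocity` selects which velocity moves θ.
function toy_step(rng, θ, v; η, α, use_new_velocity)
    g = grad_logp(θ)  # gradient at the current position
    newv = (1 - α) * v + η * g + sqrt(2 * η * α) * randn(rng)
    newθ = θ + (use_new_velocity ? newv : v)
    return newθ, newv
end

# Run the chain and return the energy-like quantity η θ^2 + v^2 at the end.
function run_toy(; η=0.1, α=0.0, nsteps=200, use_new_velocity, seed=1)
    rng = Random.MersenneTwister(seed)
    θ, v = 1.0, 0.0
    for _ in 1:nsteps
        θ, v = toy_step(rng, θ, v; η=η, α=α, use_new_velocity=use_new_velocity)
    end
    return η * θ^2 + v^2
end

energy_new = run_toy(; use_new_velocity=true)   # θ moved with the new velocity
energy_old = run_toy(; use_new_velocity=false)  # θ moved with the old velocity
```

In this frictionless, noiseless limit the old-velocity variant is a plain explicit Euler step and its energy grows every step, while the new-velocity variant is a symplectic-Euler-style step and stays bounded. That may be related to why the order matters here, but with friction and noise switched on the picture can differ, so treat it as a hint rather than an explanation.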
Tag @yebai for knowledge and review. I think this is ready (certainty of correctness is quite high but not 100 percent).
@ErikQQY version bump maybe? (We would need to bump the minor version because SGHMC will be exported. Probably lump several changes together or do something like https://github.com/TuringLang/Turing.jl/pull/2517)
I think there might be some ambiguity when using AdvancedHMC inside Turing, since both packages export SGHMC. We need to take care of that.
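If the concern is the name clash from two packages exporting the same identifier, a minimal sketch of the situation looks like this. `PkgA` and `PkgB` are hypothetical inline modules standing in for the two packages; nothing here uses the real AdvancedHMC or Turing APIs.

```julia
# Two hypothetical modules that both export a type named `SGHMC`.
module PkgA
export SGHMC
struct SGHMC end
end

module PkgB
export SGHMC
struct SGHMC end
end

using .PkgA, .PkgB

# With both modules loaded, an unqualified `SGHMC` would be ambiguous,
# so each one is referred to by its qualified name instead.
a = PkgA.SGHMC()
b = PkgB.SGHMC()
```

Qualifying the name on the caller's side is one option; whether one of the packages should stop exporting the name is a separate decision for this PR.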
Codecov Report
Attention: Patch coverage is 92.00000% with 4 lines in your changes missing coverage. Please review.
Project coverage is 76.09%. Comparing base (5a562e0) to head (7f0c811).
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/abstractmcmc.jl | 91.11% | 4 Missing :warning: |
Additional details and impacted files
Coverage Diff
| | main | #428 | +/- |
|---|---|---|---|
| Coverage | 75.44% | 76.09% | +0.64% |
| Files | 21 | 21 | |
| Lines | 1230 | 1280 | +50 |
| Hits | 928 | 974 | +46 |
| Misses | 302 | 306 | +4 |