AdvancedHMC.jl

Add Stochastic Gradient HMC

Open ErikQQY opened this issue 8 months ago • 13 comments

Part of #60

ErikQQY avatar Apr 23 '25 11:04 ErikQQY

a bump on this

sunxd3 avatar Apr 30 '25 21:04 sunxd3

AdvancedHMC.jl documentation for PR #428 is available at: https://TuringLang.github.io/AdvancedHMC.jl/previews/PR428/

github-actions[bot] avatar May 18 '25 09:05 github-actions[bot]

The basic SGHMC algorithm is intuitive, but in the AbstractMCMC tests there are `Inf` values in the final result after being transformed back to the original space. Gonna need some advice on what's going wrong here @sunxd3

ErikQQY avatar May 18 '25 09:05 ErikQQY

Let me take a look. I haven't dealt with this part of the code before, so give me a bit of time.

sunxd3 avatar May 18 '25 12:05 sunxd3

The following version passes the test

function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    if haskey(kwargs, :nadapts)
        throw(
            ArgumentError(
                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
            ),
        )
    end

    i = state.i + 1
    t_old = state.transition
    adaptor = state.adaptor
    κ = state.κ
    metric = state.metric

    # Reconstruct hamiltonian.
    h = Hamiltonian(metric, model)

    # Compute gradient of log density.
    logdensity_and_gradient = Base.Fix1(
        LogDensityProblems.logdensity_and_gradient, model.logdensity
    )
    θ = copy(t_old.z.θ)
    grad = last(logdensity_and_gradient(θ))

    # Update latent variables and velocity according to
    # equation (15) of Chen et al. (2014)
    v = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
    θ .+= newv

    # Make new transition.
    z = phasepoint(h, θ, v)
    t = transition(rng, h, κ, z)

    # Adapt h and spl.
    tstat = stat(t)
    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
    tstat = merge(tstat, (is_adapt=isadapted,))

    # Compute next sample and state.
    sample = Transition(t.z, tstat)
    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

    return sample, newstate
end

I made three updates:

  1. copy θ
  2. update θ using the new velocity `newv`
  3. create the transition `t` using the updated θ

I was a bit uneasy about the in-place update, thinking there might be some mismatch between the parameters and the log probability. So I made the above changes, and I think they at least make the test case pass.

I know the logic is copied from Turing.jl, where the order of updating v and θ was okay as it was. But I had to switch the order to make it work here. Ideally I'd get to the bottom of this, but I thought you might know better.
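The update being discussed (equation (15) of Chen et al. 2014) can be sketched in plain NumPy. This is a hypothetical stand-in, not AdvancedHMC code: `sghmc_step`, `grad_logp`, and the toy standard-normal target are all assumptions for illustration, mirroring the order used in the Julia snippet above (refresh velocity first, then move the position with the new velocity).

```python
import numpy as np

def sghmc_step(rng, theta, v, grad_logp, eta=1e-3, alpha=0.05):
    """One naive SGHMC step: friction + noise on the velocity, then move theta.

    eta is the learning rate, alpha the momentum decay, matching the
    `spl.learning_rate` / `spl.momentum_decay` roles in the snippet above.
    """
    noise = np.sqrt(2.0 * eta * alpha) * rng.standard_normal(theta.shape)
    v_new = (1.0 - alpha) * v + eta * grad_logp(theta) + noise
    theta_new = theta + v_new  # position moves with the *new* velocity
    return theta_new, v_new

# Toy target: standard normal, so grad log p(theta) = -theta.
rng = np.random.default_rng(0)
theta, v = np.zeros(2), np.zeros(2)
for _ in range(1000):
    theta, v = sghmc_step(rng, theta, v, lambda t: -t)
print(np.all(np.isfinite(theta)))  # prints True: the chain stays finite
```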

sunxd3 avatar May 21 '25 15:05 sunxd3

@sunxd3 Thanks a lot! I see what's going wrong here: the parameters had not been updated to the right state, and the in-place mutation of the parameter theta is not suitable here. The issue has been fixed now!

If the CI checks are all green, this PR should be ready now.

ErikQQY avatar May 22 '25 04:05 ErikQQY

Looks fine to me. Now that the tests pass, I think the algorithm is very likely implemented correctly.

Still curious why this works (I am referring to the order of updating theta first).

sunxd3 avatar May 22 '25 07:05 sunxd3

Still curious why this works (I am referring to the order of updating theta first).

While the order of updating theta differs between AHMC and Turing, it seems they both work fine? The Turing one should also be correct, though: in each step, theta is only updated using the velocity from the previous step.

ErikQQY avatar May 22 '25 14:05 ErikQQY

That is what confuses me, because in this PR the order can't be reversed. (I think it would fail if we used the old velocity.)
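A tiny numeric check (again a hypothetical NumPy sketch, not AdvancedHMC code) shows the two orders are genuinely different discretizations: moving theta with the new velocity versus the old one yields different iterates even with identical noise, so the choice is not a no-op.

```python
import numpy as np

def step(theta, v, grad, eta, alpha, noise, use_new_v):
    # Shared velocity refresh; only the position update differs.
    v_new = (1 - alpha) * v + eta * grad(theta) + noise
    theta_new = theta + (v_new if use_new_v else v)
    return theta_new, v_new

grad = lambda t: -t  # standard-normal target
rng = np.random.default_rng(1)
noises = rng.standard_normal((50, 1)) * np.sqrt(2 * 1e-2 * 0.1)

a = b = (np.ones(1), np.zeros(1))
for n in noises:
    a = step(*a, grad, 1e-2, 0.1, n, use_new_v=True)   # theta uses new v
    b = step(*b, grad, 1e-2, 0.1, n, use_new_v=False)  # theta uses old v
print(np.allclose(a[0], b[0]))  # prints False: trajectories diverge
```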

sunxd3 avatar May 22 '25 15:05 sunxd3

Tagging @yebai for awareness and review. I think this is ready (my certainty of correctness is quite high, but not 100 percent).

sunxd3 avatar May 26 '25 09:05 sunxd3

@ErikQQY version bump, maybe? (We would need to bump the minor version because SGHMC will be exported. Probably lump several changes together, or do something like https://github.com/TuringLang/Turing.jl/pull/2517.)

sunxd3 avatar May 26 '25 09:05 sunxd3

Probably lump several changes together or do something like https://github.com/TuringLang/Turing.jl/pull/2517

I think there might be some method ambiguity when using AdvancedHMC inside Turing: since both packages export SGHMC, we need to take care of that.

ErikQQY avatar May 26 '25 16:05 ErikQQY

Codecov Report

Attention: Patch coverage is 92.00000% with 4 lines in your changes missing coverage. Please review.

Project coverage is 76.09%. Comparing base (5a562e0) to head (7f0c811).

Files with missing lines Patch % Lines
src/abstractmcmc.jl 91.11% 4 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #428      +/-   ##
==========================================
+ Coverage   75.44%   76.09%   +0.64%     
==========================================
  Files          21       21              
  Lines        1230     1280      +50     
==========================================
+ Hits          928      974      +46     
- Misses        302      306       +4     

:umbrella: View full report in Codecov by Sentry.

codecov[bot] avatar Jun 23 '25 17:06 codecov[bot]