AdvancedHMC.jl

non-boolean progress not supported

Open • sethaxen opened this issue 4 months ago • 3 comments

When sampling via the AbstractMCMC interface (as documented at https://turinglang.org/AdvancedHMC.jl/stable/get_started/#Using-the-AbstractMCMC-Interface), the Symbol values of progress described at https://turinglang.org/AbstractMCMC.jl/stable/api/#Progress-logging (e.g. :perchain) are not supported. Here's an MWE:

using AdvancedHMC, ForwardDiff, LogDensityProblems, LinearAlgebra, AbstractMCMC, LogDensityProblemsAD

struct LogTargetDensity
    dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2  # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
function LogDensityProblems.capabilities(::Type{LogTargetDensity})
    return LogDensityProblems.LogDensityOrder{0}()
end

ℓπ = LogTargetDensity(10)
model = AbstractMCMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ))

n_samples, n_adapts = 2_000, 1_000
sampler = NUTS(0.8)

# sample with progress=true
AbstractMCMC.sample(model, sampler, n_adapts + n_samples; n_adapts, progress=true)  # works
# sample with progress=:perchain
AbstractMCMC.sample(model, sampler, n_adapts + n_samples; n_adapts, progress=:perchain)  # errors
ERROR: TypeError: non-boolean (Symbol) used in boolean context
Stacktrace:
 [1] sample(model_or_logdensity::AbstractMCMC.LogDensityModel{…}, sampler::NUTS{…}, N_or_isdone::Int64; kwargs::@Kwargs{…})
   @ AbstractMCMC ~/.julia/packages/AbstractMCMC/mcqES/src/sample.jl:23
 [2] top-level scope
   @ REPL[38]:2
Some type information was truncated. Use `show(err)` to see complete types.

Presumably this happens because lines like this constrain the type of progress to be a Bool: https://github.com/TuringLang/AdvancedHMC.jl/blob/dc8dc1cab540fbd3224bc934bf46d27bf0e48d4b/src/abstractmcmc.jl#L59
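The error class itself can be reproduced without any packages. The function below is a hypothetical stand-in (not AdvancedHMC code) for any code path that uses progress directly as an `if` condition. A Bool works, while a Symbol such as :perchain raises the same TypeError as in the stacktrace above.

```julia
# Hypothetical stand-in for code that uses `progress` directly in an `if`.
# Not taken from AdvancedHMC; it only illustrates the failure mode.
function maybe_show_progress(progress)
    if progress  # `if` requires a Bool; any other type throws a TypeError
        return "progress bar shown"
    else
        return "no progress bar"
    end
end

maybe_show_progress(true)   # returns "progress bar shown"

try
    maybe_show_progress(:perchain)
catch err
    # prints: TypeError: non-boolean (Symbol) used in boolean context
    showerror(stdout, err)
    println()
end
```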

However, even if that wasn't the cause, the following lines override progress if the user didn't provide a callback (no clue why this is the case): https://github.com/TuringLang/AdvancedHMC.jl/blob/dc8dc1cab540fbd3224bc934bf46d27bf0e48d4b/src/abstractmcmc.jl#L72-L75
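A simplified, hypothetical sketch of the pattern those lines follow is shown below. `ProgressCallbackSketch` and `sample_kwargs_sketch` are illustrative names, not AdvancedHMC API. When no callback is given, a default progress callback is built from the user's progress value, and the progress that would be forwarded to AbstractMCMC is then forced to false. The sketch also assumes the callback stores progress in a Bool-typed field, which is one way a Symbol would be rejected.

```julia
# Simplified, hypothetical sketch of the pattern; not AdvancedHMC's source.
# Assumption: the default callback stores `progress` in a Bool-typed field.
struct ProgressCallbackSketch
    progress::Bool
end

function sample_kwargs_sketch(; progress=true, callback=nothing)
    if callback === nothing
        # Build a default callback from the user's `progress` setting...
        callback = ProgressCallbackSketch(progress)
        # ...then override what would be forwarded to AbstractMCMC.
        progress = false
    end
    return (; callback, progress)
end

# No callback: the callback receives `true`, AbstractMCMC would receive `false`.
kw = sample_kwargs_sketch(progress=true)

# With a user callback, `progress` is forwarded unchanged.
kw_user = sample_kwargs_sketch(progress=true, callback=identity)

# A Symbol cannot be converted to the Bool field, so this throws a MethodError:
# sample_kwargs_sketch(progress=:perchain)
```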

sethaxen, Nov 06 '25 14:11

> the following lines override progress if the user didn't provide a callback (no clue why this is the case):

I'd assume that's because the callback may terminate sampling prematurely, but I'd think it would still be useful to show the progress bar and/or allow the callback to change the progress that's being reported.

nsiccha, Nov 20 '25 09:11

It's because HMCProgressCallback itself generates a progress bar, so AbstractMCMC's own progress logging is switched off to avoid drawing a second one.
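Because the callback draws its own bar, one possible way to keep a single bar while still accepting AbstractMCMC's Symbol values is to dispatch on the type of progress. This is a hypothetical sketch, not an existing AdvancedHMC API. A Bool keeps the current behaviour: the HMC callback draws the bar and AbstractMCMC's bar is disabled. A Symbol is passed through to AbstractMCMC unchanged and the callback's bar is turned off. The trade-off is that whatever extra information the HMC callback displays is lost in the Symbol case.

```julia
# Hypothetical helper, not AdvancedHMC API: split one user-facing `progress`
# value into (bar drawn by the HMC callback?, value forwarded to AbstractMCMC).
split_progress(progress::Bool) = (progress, false)      # current behaviour
split_progress(progress::Symbol) = (false, progress)    # defer to AbstractMCMC

split_progress(true)        # (true, false)
split_progress(:perchain)   # (false, :perchain)
```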

penelopeysm, Nov 27 '25 12:11

I don't really see why this logic is placed in a callback rather than inside the definition of AbstractMCMC.step, though.

penelopeysm, Nov 27 '25 12:11