StanLogDensityProblems.jl

More flexibly handle NaNs

Open • nsiccha opened this issue 5 months ago • 9 comments

It would be nice to be able to provide some kind of error/NaN handler, e.g. to keep track of how often, and why, chains run into problematic regions.

nsiccha, Jul 18 '25 09:07

In my view, it would make more sense to implement such a handler in its own package, as a new struct that wraps another log density and itself implements the LogDensityProblems interface. Then one could track this information for any LogDensityProblem, not just for a `StanProblem`.
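A minimal sketch of such a wrapper, assuming only the LogDensityProblems.jl interface (`capabilities`, `dimension`, `logdensity`, `logdensity_and_gradient`, `logdensity_gradient_and_hessian`). The name `FailureCounter` and its fields are hypothetical. It forwards every call unchanged and only counts evaluations, thrown errors, and non-finite log densities, so it works for any LogDensityProblem, not just a `StanProblem`.

```julia
import LogDensityProblems

# Hypothetical wrapper (not an existing package): forwards every call to `parent`
# and counts evaluations, thrown errors, and non-finite log densities.
mutable struct FailureCounter{P}
    parent::P
    n_evals::Int
    n_errors::Int
    n_nonfinite::Int
end
FailureCounter(parent) = FailureCounter(parent, 0, 0, 0)

LogDensityProblems.capabilities(::Type{FailureCounter{P}}) where {P} =
    LogDensityProblems.capabilities(P)
LogDensityProblems.dimension(p::FailureCounter) = LogDensityProblems.dimension(p.parent)

# Evaluate `f(parent, x)`, update the counters, and pass errors through unchanged.
function counted_call(p::FailureCounter, f, x)
    p.n_evals += 1
    result = try
        f(p.parent, x)
    catch
        p.n_errors += 1
        rethrow()
    end
    # `first` works for a plain log density as well as for (lp, grad[, hess]) tuples
    isfinite(first(result)) || (p.n_nonfinite += 1)
    return result
end

LogDensityProblems.logdensity(p::FailureCounter, x) =
    counted_call(p, LogDensityProblems.logdensity, x)
LogDensityProblems.logdensity_and_gradient(p::FailureCounter, x) =
    counted_call(p, LogDensityProblems.logdensity_and_gradient, x)
LogDensityProblems.logdensity_gradient_and_hessian(p::FailureCounter, x) =
    counted_call(p, LogDensityProblems.logdensity_gradient_and_hessian, x)
```

Because errors are rethrown, a sampler sees exactly the same behavior as without the wrapper; the counters can be inspected afterwards, e.g. `c.n_errors / c.n_evals`.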

sethaxen, Jul 24 '25 08:07

Hm, I guess that would make sense! Any proposals for a name? 😅

nsiccha, Jul 24 '25 08:07

Depends on how you want to implement it! You could, e.g., implement an object that holds a callback. That would support both modifying how the original object behaves and storing statistics, but it would require users to implement that themselves, unless you wanted to provide some handy default callbacks. In that case, LogDensityProblemsCallbacks might make sense as a name.
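A sketch of the callback variant, assuming a hypothetical `CallbackProblem` type (LogDensityProblemsCallbacks is only a suggested package name, not an existing API). The callback's return value replaces the failed result, so it can change behavior (rethrow versus return NaN); two possible default callbacks, `rethrow_callback` and `nan_callback` (also hypothetical names), are included.

```julia
import LogDensityProblems

# Hypothetical wrapper: when `f(parent, x)` throws `e`, the value returned by
# `on_error(f, parent, x, e)` is used instead.
struct CallbackProblem{P,C}
    parent::P
    on_error::C
end

LogDensityProblems.capabilities(::Type{<:CallbackProblem{P}}) where {P} =
    LogDensityProblems.capabilities(P)
LogDensityProblems.dimension(p::CallbackProblem) = LogDensityProblems.dimension(p.parent)

function call_with_callback(p::CallbackProblem, f, x)
    try
        return f(p.parent, x)
    catch e
        return p.on_error(f, p.parent, x, e)
    end
end

LogDensityProblems.logdensity(p::CallbackProblem, x) =
    call_with_callback(p, LogDensityProblems.logdensity, x)
LogDensityProblems.logdensity_and_gradient(p::CallbackProblem, x) =
    call_with_callback(p, LogDensityProblems.logdensity_and_gradient, x)
LogDensityProblems.logdensity_gradient_and_hessian(p::CallbackProblem, x) =
    call_with_callback(p, LogDensityProblems.logdensity_gradient_and_hessian, x)

# Possible default callbacks.
# `rethrow()` is valid here because the callback runs inside the catch block above.
rethrow_callback(f, parent, x, e) = rethrow()
nan_callback(::typeof(LogDensityProblems.logdensity), parent, x, e) = NaN
nan_callback(::typeof(LogDensityProblems.logdensity_and_gradient), parent, x, e) =
    (NaN, fill(NaN, length(x)))
nan_callback(::typeof(LogDensityProblems.logdensity_gradient_and_hessian), parent, x, e) =
    (NaN, fill(NaN, length(x)), fill(NaN, length(x), length(x)))
```

Statistics can be layered on with a closure, e.g. `failures = []` together with `on_error = (f, parent, x, e) -> (push!(failures, (copy(x), e)); nan_callback(f, parent, x, e))`.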

sethaxen, Jul 24 '25 08:07

Well, I currently have the implementation below lying around. It basically does what StanLogDensityProblems.jl already supports, but it can also log the error message. I guess one could also add a default that "just" stores the failed draws 🤔

```julia
using BridgeStan
using StanLogDensityProblems  # provides `StanProblem`

# Wraps a BridgeStan model with a configurable error handler:
# `error_handler(f, model, z, e)` is called whenever `f(model, z)` throws `e`.
struct BridgeStanProblem{E}
    model::BridgeStan.StanModel
    error_handler::E
end
function nan_on_error_handler end
function rethrow_on_error_handler end

# Convert a StanProblem, mapping its `nan_on_error` type parameter to a handler.
function BridgeStanProblem(p::StanProblem{M,nan_on_error}) where {M,nan_on_error}
    handler = nan_on_error ? nan_on_error_handler : rethrow_on_error_handler
    return BridgeStanProblem(p.model, handler)
end

function StanLogDensityProblems.LogDensityProblems.capabilities(::Type{<:BridgeStanProblem})
    return StanLogDensityProblems.LogDensityProblems.LogDensityOrder{2}()  # can do gradient
end

function StanLogDensityProblems.LogDensityProblems.dimension(prob::BridgeStanProblem)
    return Int(BridgeStan.param_unc_num(prob.model))
end

# Shared evaluation path: call `f(model, z)` and route any exception to the handler.
function evaluate_with_handler(prob::BridgeStanProblem, f, x)
    m = prob.model
    z = convert(Vector{Float64}, x)
    try
        return f(m, z)
    catch e
        return prob.error_handler(f, m, z, e)
    end
end

function StanLogDensityProblems.LogDensityProblems.logdensity(prob::BridgeStanProblem, x)
    return evaluate_with_handler(prob, BridgeStan.log_density, x)
end
function StanLogDensityProblems.LogDensityProblems.logdensity_and_gradient(
    prob::BridgeStanProblem, x
)
    return evaluate_with_handler(prob, BridgeStan.log_density_gradient, x)
end
# Order 2 is advertised above, so the Hessian has to be provided as well.
function StanLogDensityProblems.LogDensityProblems.logdensity_gradient_and_hessian(
    prob::BridgeStanProblem, x
)
    return evaluate_with_handler(prob, BridgeStan.log_density_hessian, x)
end

function rethrow_on_error_handler(f, m, z, e)
    @error "Error evaluating $f for $(typeof(m)): $e"
    rethrow()  # still inside the caller's catch block, so the original backtrace is kept
end

# NaN-filled results with the same shape as the corresponding BridgeStan function.
nan_result(::typeof(BridgeStan.log_density), z) = NaN
nan_result(::typeof(BridgeStan.log_density_gradient), z) = (NaN, fill(NaN, length(z)))
nan_result(::typeof(BridgeStan.log_density_hessian), z) =
    (NaN, fill(NaN, length(z)), fill(NaN, length(z), length(z)))

function nan_on_error_handler(f, m, z, e)
    @warn "Error evaluating $f for $(typeof(m)): $e"
    return nan_result(f, z)
end
silent_nan(f, m, z, e) = nan_result(f, z)
```
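The default that "just" stores the failed draws could be a callable handler with the same `(f, m, z, e)` signature as the handlers above. `FailedDrawRecorder` is a hypothetical name, not part of StanLogDensityProblems.jl or BridgeStan; this sketch records each failing input and its exception, then delegates the return value to another handler such as `silent_nan`.

```julia
# Hypothetical handler: records every failed draw and its exception, then
# delegates the return value (NaN, rethrow, ...) to `fallback`.
struct FailedDrawRecorder{H}
    fallback::H
    draws::Vector{Vector{Float64}}
    errors::Vector{Any}
end
FailedDrawRecorder(fallback) = FailedDrawRecorder(fallback, Vector{Float64}[], Any[])

function (h::FailedDrawRecorder)(f, m, z, e)
    push!(h.draws, copy(z))  # copy: `convert` may hand us the caller's own vector
    push!(h.errors, e)
    return h.fallback(f, m, z, e)
end
```

With the code above, `BridgeStanProblem(model, FailedDrawRecorder(silent_nan))` would evaluate silently while keeping `draws` and `errors` for later inspection.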

nsiccha, Jul 24 '25 08:07

I did use StanLogDensityProblems.jl as my template, which is why the annotation for the order is slightly weird (as in the original code) 😅.

nsiccha, Jul 24 '25 08:07

I guess this can easily be modified to work as you proposed. I'll do so at some uncertain point in the future, probably when I start a new project, so that I don't have to copy the code around anymore.

nsiccha, Jul 24 '25 09:07

Ah, cool. Actually, if you do implement something that allows custom handling of errors etc. in its own package, I'd drop the `nan_on_error` field here and instead document how to use your package to get the same behavior. I agree it's not ideal that we suppress the error message entirely, but if we also accepted a user-provided logger, the code would get more complex than I think it should be for a very simple glue package like this.

sethaxen, Jul 24 '25 09:07

> I'd actually drop the nan_on_error field here

I was going to say that handling errors inside the wrapped LogDensityProblem (as is done here) would stand in the way of LogDensityProblemsCallbacks handling them 👍

nsiccha, Jul 24 '25 09:07

> I agree it's not ideal we suppress the error message entirely

I'm often quite happy with the default! I only recently added the `silent_nan` handler above because all the printed errors got on my nerves.

nsiccha, Jul 24 '25 09:07