Issues with constrained parameters depending on each other
Problem
julia> using Turing
julia> @model function buggy_model()
           lb ~ Uniform(0, 0.1)
           ub ~ Uniform(0.11, 0.2)
           x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
       end
buggy_model (generic function with 2 methods)
julia> model = buggy_model();
julia> chain = sample(model, NUTS(), 1000);
┌ Info: Found initial step size
└ ϵ = 3.2
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:01
julia> results = generated_quantities(model, chain); # (×) Breaks!
ERROR: DomainError with -0.05206647177072762:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
DomainError detected in the user `f` function. This occurs when the domain of a function is violated.
For example, `log(-1.0)` is undefined because `log` of a real number is defined to only output real
numbers, but `log` of a negative number is complex valued and therefore Julia throws a DomainError
by default. Cases to be aware of include:
* `log(x)`, `sqrt(x)`, `cbrt(x)`, etc. where `x<0`
* `x^y` for `x<0` floating point `y` (example: `(-1.0)^(1/2) == im`)
...
In contrast, if we use Prior to sample, we're good:
julia> chain_prior = sample(model, Prior(), 1000);
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
julia> results_prior = generated_quantities(model, chain_prior); # (✓) Works because no linking needed
The issue is caused by the fact that we use DynamicPPL.invlink!!(varinfo, model) when constructing a transition (which is what ends up in the chain); it is not an issue with the inference itself.
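To make the failure concrete, here is a minimal sketch (with hand-picked values, not taken from the run above) of what happens when the stored transform is stale:

using Bijectors

# Bounds recorded in the `VarInfo` metadata when it was last written:
lb_stale, ub_stale = 0.08, 0.18
# Bounds corresponding to the accepted transition:
lb_new, ub_new = 0.01, 0.12

y = 2.0  # unconstrained (linked) value of `x`

# Invlinking with the *current* bounds keeps `x` inside (lb_new, ub_new):
x_correct = inverse(Bijectors.Logit(lb_new, ub_new))(y)    # ≈ 0.107
# Invlinking with the *stale* bounds lands in (lb_stale, ub_stale) instead:
x_stale = inverse(Bijectors.Logit(lb_stale, ub_stale))(y)  # ≈ 0.168

# 0.168 lies outside (0.01, 0.12), so a later `logpdf` for `x` takes
# `log` of a negative number and throws the `DomainError` shown above.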
Indeed, to confirm that inference itself is fine, we can use AdvancedHMC.jl directly:
julia> using AdvancedHMC: AdvancedHMC
julia> f = DynamicPPL.LogDensityFunction(model);
julia> DynamicPPL.link!!(f.varinfo, f.model);
julia> chain_ahmc = sample(f, AdvancedHMC.NUTS(0.8), 1000);
[ Info: Found initial step size ϵ = 3.2
Sampling 100%|███████████████████████████████| Time: 0:00:00
iterations: 1000
ratio_divergent_transitions: 0.0
ratio_divergent_transitions_during_adaption: 0.0
n_steps: 7
is_accept: true
acceptance_rate: 0.7879658455930968
log_density: -5.038135476673508
hamiltonian_energy: 7.775565727543868
hamiltonian_energy_error: -0.11294798909710124
max_hamiltonian_energy_error: 0.5539216379943772
tree_depth: 3
numerical_error: false
step_size: 1.1685229504528063
nom_step_size: 1.1685229504528063
is_adapt: false
mass_matrix: DiagEuclideanMetric([1.0, 1.0, 1.0])
julia> function to_constrained(θ)
           lb = inverse(Bijectors.Logit(0.0, 0.1))(θ[1])
           ub = inverse(Bijectors.Logit(0.11, 0.2))(θ[2])
           x = inverse(Bijectors.Logit(lb, ub))(θ[3])
           return [lb, ub, x]
       end
to_constrained (generic function with 1 method)
julia> chain_ahmc_constrained = mapreduce(hcat, chain_ahmc) do t
           to_constrained(t.z.θ)
       end;
julia> chain_ahmc = Chains(
           permutedims(chain_ahmc_constrained),
           [:lb, :ub, :x]
       );
Visualizing the densities of the resulting chains, we also see that the one from Turing.NUTS is incorrect (the blue line), while the other two (Prior and AdvancedHMC.NUTS) coincide.
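For reference, the density overlay can be reproduced with something along these lines (assuming StatsPlots is installed):

julia> using StatsPlots
julia> density(chain[:x]; label="Turing.NUTS (incorrect)");
julia> density!(chain_prior[:x]; label="Prior");
julia> density!(chain_ahmc[:x]; label="AdvancedHMC.NUTS")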
Solution?
Fixing this I think will actually be quite annoying :confused: But I do think it's worth doing.
There are a few approaches:
- Re-evaluate the model for every transition we end up accepting, to get the distributions corresponding to that particular realization (a rough sketch follows after this list).
- Double the memory usage of VarInfo and always store both the linked and the invlinked realizations.
- Use a separate context to capture the invlinked realizations.
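A rough sketch of the first approach (the name extract_realizations also appears in the snippet further down; the body here, in particular the evaluate!! call, is my assumption rather than settled API):

using DynamicPPL, OrderedCollections

# Sketch: re-run the model for an accepted realization so that every
# distribution (and hence every transform) is rebuilt from that
# realization's own values before invlinking.
function extract_realizations(model, varinfo)
    # Re-evaluation refreshes the distributions stored in the metadata.
    _, varinfo = DynamicPPL.evaluate!!(model, varinfo, DynamicPPL.DefaultContext())
    return DynamicPPL.values_as(DynamicPPL.invlink(varinfo, model), OrderedDict)
end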
No matter how we do this, there is the issue that we can't support this properly for externalsampler, etc., which use the LogDensityFunction, without explicit re-evaluation of the model :confused: Though it seems it would still be worth adding proper support for this in the "internal" implementations of the samplers.
It might be worth providing an option to force re-evaluation, combined with, say, a warning if we notice that supports change between two different realizations.
@yebai @devmotion @sunxd3
This can be resolved with something like https://github.com/TuringLang/DynamicPPL.jl/pull/588 + some minor changes to Turing.Inference.getparams, by turning
# Extract parameter values in a simple form from the invlinked `VarInfo`.
DynamicPPL.values_as(DynamicPPL.invlink(vi, model), OrderedDict)
into
vals = if DynamicPPL.is_static(model)
    # Extract parameter values in a simple form from the invlinked `VarInfo`.
    DynamicPPL.values_as(DynamicPPL.invlink(vi, model), OrderedDict)
else
    # Re-evaluate the model completely to get invlinked parameters since
    # we can't trust the invlinked `VarInfo` to be up-to-date.
    extract_realizations(model, deepcopy(vi))
end
This then defaults to the "make sure we're doing everything correctly" approach, but allows the user to avoid all the additional model evaluations by simply doing:
model = DynamicPPL.mark_as_static(model)
before passing model to sample.
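Hypothetical end-to-end usage under that proposal (mark_as_static comes from the PR above and is not a released API; my_model stands in for a model whose supports genuinely never change):

# Hypothetical: the user promises that no variable's support depends on
# another variable, so the cheap invlink path is safe.
static_model = DynamicPPL.mark_as_static(my_model)
chain = sample(static_model, NUTS(), 1000)  # no per-transition re-evaluation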
As noted in the PR, we probably should have something a bit more general to also capture when we need to use a fully blown UntypedVarInfo to allow an arbitrary number of parameters + changing between evaluations. But given that we will (soonTM) have a more flexible approach to UntypedVarInfo which can grow arbitrarily (https://github.com/TuringLang/DynamicPPL.jl/pull/555), this might not be so important.
Combining the aforementioned PRs + https://github.com/TuringLang/DynamicPPL.jl/pull/540, I imagine putting something like the following in our sample:
if DynamicPPL.has_static_constraints(model)
    model = DynamicPPL.mark_as_static(model)
end
and then continuing business as usual. It will be a heuristic, of course, but it should work very well in practice. We could make this a keyword argument to allow it to be disabled.
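As a sketch, the opt-out could look roughly like this (the helper name is illustrative, not actual Turing internals):

# Illustrative only: apply the heuristic before sampling, behind a flag.
function maybe_mark_static(model; detect_static=true)
    if detect_static && DynamicPPL.has_static_constraints(model)
        # Supports don't change between realizations; skip re-evaluation.
        model = DynamicPPL.mark_as_static(model)
    end
    return model
end

chain = sample(maybe_mark_static(model), NUTS(), 1000)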
Just to clarify for my understanding: this seems to be a VarInfo issue, because the distributions in the metadata are evaluated and saved, and then used during invlink. Does that mean that when using VarInfo, we always assume the distribution type and support are consistent?
Then what about SimpleVarInfo? Would directing users to use SimpleVarInfo be an option for a solution? (Of course, we would still need a utility to check whether the model is static.)
It's indeed a VarInfo issue, but the way we handle this with SimpleVarInfo is exactly to re-evaluate the model :shrug: So what I'm suggesting is to always do this by default, because that will always produce the right thing, and then allow specialization in the cases where it makes sense.
A better fix is to remove VarInfo in favour of (generalised) SimpleVarInfo if that is the case!
> A better fix is to remove VarInfo in favour of (generalised) SimpleVarInfo if that is the case!
But that doesn't address the issue!
It's not "a bug" per se on the side of VarInfo; it's a question of whether we have to re-evaluate the model to get the correct transform or not. Replacing VarInfo completely with SimpleVarInfo would just force us to always re-evaluate the model fully whenever we want to invlink, which is obviously very undesirable in many cases.
EDIT: even the new VarNameVector I've been working on will suffer from the same issue.
For example, a good use case is https://github.com/TuringLang/Turing.jl/pull/2099, where we need to (inv)link all the time to go between constrained representations (used by condition) and unconstrained ones (used in sampling). We really shouldn't be doing this through re-evaluation of the model unless needed :confused:
Fixed by https://github.com/TuringLang/Turing.jl/pull/2202#event-12733131234.