Bisection with ForwardDiff
Hi!
The changes in Roots 2 broke root finding with non-standard number types in Bijectors: https://github.com/TuringLang/Bijectors.jl/pull/219
The reason seems to be that in Roots < 2, find_zero(f, (a, b)) was guaranteed to use zero tolerances but picked a different algorithm, A42, for non-FloatXY number types, whereas now Bisection is used for such types with non-zero tolerances. In Bijectors we don't deal with BigFloat (which the deprecation warning refers to) but rather with e.g. Dual{Float64}, TrackedReal{Float64}, etc., i.e., non-standard number types that are nevertheless based on FloatXY. It seems the different tolerances cause our tests to fail.
Of course, we would like to use the same tolerances for all these standard and non-standard FloatXY types. The new default behaviour seems a bit unfortunate, but I guess there was a good reason for changing it. What's the best way to ensure from our side that the same tolerances are always used and the results are consistent? Should we set the tolerances to zero explicitly? From what I understand this might be inefficient if someone actually uses BigFloat. Or do you recommend explicitly choosing A42, which was used for BigFloat before, instead of Bisection?
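For readers following along: Roots' find_zero takes tolerance keywords (e.g. xatol, xrtol), and zero tolerances mean bisection only terminates once the bracket cannot shrink any further in floating point. A minimal sketch of that stopping rule in plain Julia (not Roots' implementation):

```julia
# Bisection with zero tolerances: stop only when the midpoint collides
# with an endpoint, i.e. the bracket is one ulp wide and cannot shrink.
function bisect_to_limit(f, a::T, b::T) where {T<:AbstractFloat}
    fa, fb = f(a), f(b)
    sign(fa) * sign(fb) <= 0 || error("not a bracketing interval")
    while true
        m = a / 2 + b / 2
        (m == a || m == b) && return m   # bracket exhausted
        fm = f(m)
        iszero(fm) && return m
        if sign(fm) * sign(fa) < 0
            b = m
        else
            a, fa = m, fm
        end
    end
end

root = bisect_to_limit(x -> x^2 - 2, 1.0, 2.0)   # within an ulp or two of sqrt(2)
```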
Oops. The main reason was that the choice of algorithm happened behind the scenes, as it depended on the number type, and this was made more explicit. I'd hope to change Roots, if I can, to make this work as before for your package, but I will have to see if it can be done. Thanks for letting me know about this.
Hmm, the simplest thing would be to use A42 in all these cases. The default is set up to be the more robust option (for Float64 it can work through infinities in the function and/or endpoints), but A42 usually converges in just a few steps, whereas bisection can need many iterations for some input values and for BigFloat input types. As far as I can tell, for each of your test cases the difference between bisection and A42 with the default tolerances is either less than or equal to 2eps(alpha), as expected (this holds for all the examples with the types that are erroring), or the two different values are both exact zeros.
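To make the iteration-count point concrete: bisecting down to a one-ulp bracket takes roughly as many steps as the number type has precision bits, so the cost grows sharply for BigFloat. A rough count in plain Julia (not Roots' internals):

```julia
# Count bisection steps until the bracket cannot shrink any further.
function bisect_steps(f, a, b)
    fa = f(a)
    steps = 0
    while true
        m = a / 2 + b / 2
        (m == a || m == b) && return steps
        steps += 1
        fm = f(m)
        iszero(fm) && return steps
        if sign(fm) * sign(fa) < 0
            b = m
        else
            a, fa = m, fm
        end
    end
end

g(x) = x^2 - 2
n64  = bisect_steps(g, 1.0, 2.0)            # roughly the 52 precision bits of Float64
nbig = bisect_steps(g, big"1.0", big"2.0")  # roughly the 256 bits of default BigFloat
```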
That advice aside, I'm really unclear why this slight difference propagates upward. I'll have to dig a bit more into that. The examples with the Dual type have big differences in the other components.
The bisection carries extra values in the second and third component of Dual.
Do you happen to have a case that you know fails under the newer Roots? I can dig through the test cases in tests/interface.jl, but if it were easy for you to point me to a problem, that would be helpful.
Is the following simple enough or should I dig deeper?
With Roots 1.4.1
julia> using Bijectors, ForwardDiff, FiniteDifferences
julia> FiniteDifferences.jacobian(central_fdm(5, 1), x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
([0.14500346236335704 -0.09418785126629267 -0.8549965376369001],)
julia> ForwardDiff.gradient(x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
3-element Vector{Float64}:
0.14500346236323014
-0.09418785126627265
-0.8549965376367699
With Roots 2.0.0
julia> using Bijectors, ForwardDiff, FiniteDifferences
julia> FiniteDifferences.jacobian(central_fdm(5, 1), x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
([0.14500346236335704 -0.09418785126629267 -0.8549965376369001],)
julia> ForwardDiff.gradient(x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
3-element Vector{Float64}:
0.20378439032586593
-0.40756878065173185
0.0
I'm not sure why, but it seems there's some issue with dual numbers in Roots 2.0.0.
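A toy example of why plain bisection and forward-mode duals can interact badly: every midpoint is a convex combination of the bracket endpoints, so the partials of the returned value are mixed from the endpoint partials and never involve df/dα. A sketch with a minimal hand-rolled dual type (hypothetical, not ForwardDiff's Dual):

```julia
# Minimal forward-mode dual number, for illustration only.
struct TinyDual <: Real
    v::Float64   # value
    p::Float64   # partial derivative w.r.t. one parameter
end
Base.:+(a::TinyDual, b::TinyDual) = TinyDual(a.v + b.v, a.p + b.p)
Base.:-(a::TinyDual, b::TinyDual) = TinyDual(a.v - b.v, a.p - b.p)
Base.:/(a::TinyDual, n::Integer)  = TinyDual(a.v / n, a.p / n)

# Bisection that only inspects values; midpoints just average the endpoints.
function bisect(f, a::TinyDual, b::TinyDual; n = 40)
    fa = f(a)
    m = a
    for _ in 1:n
        m = a / 2 + b / 2
        fm = f(m)
        iszero(fm.v) && return m
        if sign(fm.v) * sign(fa.v) < 0
            b = m
        else
            a, fa = m, fm
        end
    end
    return m
end

# Solve x - c = 0, where the parameter c is seeded with dc/dc = 1.
c = TinyDual(1.25, 1.0)
root = bisect(x -> x - c, TinyDual(0.0, 0.0), TinyDual(2.0, 0.0))
# root.v reaches 1.25, but root.p is 0: the true sensitivity droot/dc = 1
# never appears, because midpoints only mix the (zero) endpoint partials.
```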
Awesome, thanks! I'll see what I can track down.
#315 addresses this by reverting to the old behavior, but the underlying issue remains.
There are issues with Bisection with ForwardDiff that seem to depend on the number of iterations used. In the modified example below, basic_bisection seems fine when the number of steps is capped at 16, but not at 17:
using Revise
using Roots
using ForwardDiff, FiniteDifferences
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
    # Compute the initial bracket (see above).
    Δ = 2 * abs(wt_u_hat)
    lower = float(wt_y - Δ)
    upper = float(wt_y + Δ)
    # Handle empty brackets (https://github.com/TuringLang/Bijectors.jl/issues/204)
    if lower == upper
        return lower
    end
    f(α) = α + wt_u_hat * tanh(α + b) - wt_y
    #return find_zero(f, (lower, upper), Bisection())
    #return basic_bisection(f, lower, upper; n = 17)
    return find_zero(f, (lower, upper))
end
# basic bisection
function basic_bisection(f, u, v; n = 17)
    fu, fv = f(u), f(v)
    a0 = u
    ctr = 1
    while ctr < n  # alternative stopping rule: abs(v - u) >= 4eps(u)
        ctr += 1
        a0 = u / 2 + v / 2
        f0 = f(a0)
        iszero(f0) && return a0
        if sign(f0) * sign(fu) < 0
            v, fv = a0, f0
        else
            u, fu = a0, f0
        end
    end
    return a0
end
function main()
    u = FiniteDifferences.jacobian(central_fdm(5, 1), x -> find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
    v = ForwardDiff.gradient(x -> find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
    [u[1]' v]
end
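As a value-level sanity check (no dual numbers involved): with the bracket from find_alpha at (3.1, 10.2, 4.3), 17 bisection steps already narrow the root to about (upper - lower) / 2^16 ≈ 6e-4, so the Float64 value itself is fine at either cap, which suggests the discrepancy comes from how the duals are propagated. A quick standalone re-run of the capped bisection (sketch):

```julia
# Standalone Float64 version of the capped bisection above.
function bisect_capped(f, u, v; n = 17)
    fu = f(u)
    a0 = u
    for _ in 2:n                 # matches the ctr loop: n - 1 bisection steps
        a0 = u / 2 + v / 2
        f0 = f(a0)
        iszero(f0) && return a0
        if sign(f0) * sign(fu) < 0
            v = a0
        else
            u, fu = a0, f0
        end
    end
    return a0
end

wt_y, wt_u_hat, b = 3.1, 10.2, 4.3
h(α) = α + wt_u_hat * tanh(α + b) - wt_y
lower, upper = wt_y - 2 * abs(wt_u_hat), wt_y + 2 * abs(wt_u_hat)
r17 = bisect_capped(h, lower, upper; n = 17)
r60 = bisect_capped(h, lower, upper; n = 60)
# r17 and r60 agree to about the final bracket width, ≈ 6e-4
```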
In the modified example below, basic_bisection seems fine when the number of steps is capped at 16, but not 17:
Ooh, that's very weird. Offhand, I don't know what could cause this behaviour.
#315 addresses this by reverting to the old behavior,
It seems that this is sufficient to fix the example above, at least. With Roots 2.0.1 I get:
julia> using Bijectors, ForwardDiff, FiniteDifferences
julia> FiniteDifferences.jacobian(central_fdm(5, 1), x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
([0.14500346236335704 -0.09418785126629267 -0.8549965376369001],)
julia> ForwardDiff.gradient(x -> Bijectors.find_alpha(x[1], x[2], x[3]), [3.1, 10.2, 4.3])
3-element Vector{Float64}:
0.14500346236323014
-0.09418785126627265
-0.8549965376367699
This does not fix the issue with basic_bisection above, but maybe we should also define a dedicated forward-mode rule for ForwardDiff in Bijectors that uses the partial derivatives directly (as we do for ChainRules in https://github.com/TuringLang/Bijectors.jl/blob/f987c1aa7fc5d402ba3fa34ca617be5ccef0e290/src/chainrules.jl#L2-L8). That would circumvent the ForwardDiff issues on our side.
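For reference, such a rule would follow the implicit function theorem, as in the linked ChainRules definition: with f(α) = α + wt_u_hat * tanh(α + b) - wt_y = 0 at the root, dα/dθ = -(∂f/∂θ)/(∂f/∂α) for each parameter θ. A plain-Julia sketch (not Bijectors' actual rule; the ForwardDiff Dual-unpacking boilerplate is omitted, and find_alpha is re-implemented with naive bisection):

```julia
# Implicit-function-theorem gradient of the root α(wt_y, wt_u_hat, b) of
# f(α) = α + wt_u_hat * tanh(α + b) - wt_y.  Sketch, not Bijectors' code.
function find_alpha(wt_y, wt_u_hat, b)
    f(α) = α + wt_u_hat * tanh(α + b) - wt_y
    lo, hi = wt_y - 2 * abs(wt_u_hat), wt_y + 2 * abs(wt_u_hat)
    for _ in 1:200                        # naive Float64 bisection
        m = lo / 2 + hi / 2
        sign(f(m)) == sign(f(lo)) ? (lo = m) : (hi = m)
    end
    return lo / 2 + hi / 2
end

function find_alpha_grad(wt_y, wt_u_hat, b)
    α = find_alpha(wt_y, wt_u_hat, b)
    s = sech(α + b)^2
    dfdα = 1 + wt_u_hat * s               # ∂f/∂α at the root
    # dα/dθ = -(∂f/∂θ) / (∂f/∂α):
    return (1 / dfdα,                     # ∂f/∂wt_y     = -1
            -tanh(α + b) / dfdα,          # ∂f/∂wt_u_hat = tanh(α + b)
            -wt_u_hat * s / dfdα)         # ∂f/∂b        = wt_u_hat * sech(α + b)^2
end

grad = find_alpha_grad(3.1, 10.2, 4.3)
# grad ≈ (0.145003, -0.094188, -0.854997), matching the jacobian above
```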
Ref: https://github.com/TuringLang/Bijectors.jl/pull/219/commits/3641c0ae342476c235ff381cd71b308f1cba18ef (works also with Roots 2.0.0)
Closing. PR #329 added some documentation about how Bisection() is not amenable to this approach and suggests an alternative. Thanks for the report. If you have a link to how to define a forward-mode rule for ForwardDiff, that would be appreciated. I added rrule and frules in PR #326.
Thanks, I'll close this issue then :slightly_smiling_face: