Cholesky numerical stability: Forward transform
This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:
using Bijectors
using Distributions
θ_unconstrained = [
-1.9887091960524537,
-13.499454444466279,
-0.39328331954134665,
-4.426097270849902,
13.101175413857023,
7.66647404712346,
9.249285786544894,
4.714877413573335,
6.233118490809442,
22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)
θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)
# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.
Introduction
The forward transform acts on an upper triangular matrix W whose columns are supposed to be unit vectors, i.e. sum(W[:, j] .^ 2) should be 1 for each j:
julia> s = rand(LKJCholesky(5, 1.0, 'U')).U
5×5 UpperTriangular{Float64, Matrix{Float64}}:
1.0 0.345448 -0.478 0.455158 0.385151
⋅ 0.938438 -0.331921 -0.305083 -0.0469749
⋅ ⋅ 0.813231 -0.397178 0.831726
⋅ ⋅ ⋅ 0.73621 0.0298828
⋅ ⋅ ⋅ ⋅ 0.395968
julia> [sum(s[:, i] .^ 2) for i in 1:5]
5-element Vector{Float64}:
1.0
1.0
1.0
1.0
1.0000000000000002
In the forward transform code, remainder_sq is initialised at 1 and the squares of the elements going down column j are then successively subtracted, so (thanks to the unit-norm constraint) remainder_sq is really the sum of squares of the elements not yet seen.
https://github.com/TuringLang/Bijectors.jl/blob/b44177797cc8d3d1f09a1443350520b2ff4f4874/src/bijectors/corr.jl#L320-L331
Now, in principle, because z^2 = W[i, j]^2 / sum(W[i:end, j] .^ 2), there is no way that z^2 can be larger than 1.
However, because of floating-point imprecision, this isn't always true in practice. It is especially likely to fail when the last off-diagonal element W[j - 1, j] is very small. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained by inverse-transforming a random unconstrained vector, as described in e.g. #279.
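As a standalone illustration (not code from this PR) of how the subtraction-based remainder goes wrong, consider a hypothetical unit-norm column whose trailing entries are tiny, which is the kind of column that inverse-transforming an extreme unconstrained vector can produce:
using LinearAlgebra
w = normalize([0.8, 0.6, 1e-9, 1e-9])  # hypothetical unit-norm column
true_tail = w[3]^2 + w[4]^2            # ≈ 2e-18: the sum of squares not yet seen
computed_tail = 1 - w[1]^2 - w[2]^2    # dominated by rounding error of order 1e-16
# computed_tail can even come out negative, in which case sqrt throws; even when it
# stays positive, w[3] / sqrt(computed_tail) bears little relation to the true ratio
# w[3] / sqrt(true_tail), which cannot meaningfully exceed 1.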
A proposed fix: instead of subtracting successive squares from 1, we could just define remainder_sq to be that sum directly:
  idx = 1
  @inbounds for j in 2:K
-     remainder_sq = 1 - W[1, j]^2
      for i in 2:(j - 1)
+         remainder_sq = sum(W[i:end, j] .^ 2)
          z = W[i, j] / sqrt(remainder_sq)
          y[idx] = atanh(z)
-         remainder_sq -= W[i, j]^2
          idx += 1
      end
  end
In practice, this is implemented by looping in reverse (over (j - 1):-1:2) and incrementing remainder_sq, following @devmotion's suggestion below. At each iteration of the inner loop, remainder_sq is equal to sum(W[i:end, j] .^ 2), so this is equivalent to the snippet directly above.
There is also some finicky index arithmetic (here, starting_idx) because within each column we now fill in y backwards.
- idx = 1
+ starting_idx = 1
  @inbounds for j in 2:K
+     remainder_sq = W[j, j]^2
-     for i in 2:(j - 1)
+     for i in (j - 1):-1:2
+         remainder_sq += W[i, j]^2
          z = W[i, j] / sqrt(remainder_sq)
+         idx = starting_idx + i - 2
          y[idx] = atanh(z)
-         idx += 1
      end
+     starting_idx += length((j - 1):-1:2)
  end
Now, because W[i, j]^2 is part of that sum, z can no longer be larger than 1, and atanh doesn't throw a DomainError.
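To double-check the index bookkeeping, here is a standalone sketch (not part of the PR) confirming that the reverse loop writes to exactly the same positions of y as the original forward loop, using the same first-row handling as the full implementations further down:
# Standalone index check: collect the (j, i, idx) triples each loop would write.
function index_pairs(K)
    fwd = Tuple{Int,Int,Int}[]   # triples written by the original forward loop
    rev = Tuple{Int,Int,Int}[]   # triples written by the new reverse loop
    idx = 1
    starting_idx = 1
    for j in 2:K
        idx += 1           # both versions first write the entry for W[1, j]
        starting_idx += 1
        for i in 2:(j - 1)
            push!(fwd, (j, i, idx))
            idx += 1
        end
        for i in (j - 1):-1:2
            push!(rev, (j, i, starting_idx + i - 2))
        end
        starting_idx += length((j - 1):-1:2)
    end
    return fwd, rev
end
fwd, rev = index_pairs(5)
@assert Set(fwd) == Set(rev)  # same (j, i) ↦ idx mapping, just a different visit order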
A final optimisation is to use asinh instead of atanh, by moving the incrementation of remainder_sq to after the calculation of y[idx]: at that point remainder_sq excludes W[i, j]^2, and atanh(w / sqrt(S + w^2)) is equal to asinh(w / sqrt(S)) (see https://github.com/TuringLang/Bijectors.jl/pull/357#issuecomment-2510261859).
  starting_idx = 1
  @inbounds for j in 2:K
      remainder_sq = W[j, j]^2
      for i in (j - 1):-1:2
-         remainder_sq += W[i, j]^2
          z = W[i, j] / sqrt(remainder_sq)
          idx = starting_idx + i - 2
-         y[idx] = atanh(z)
+         y[idx] = asinh(z)
+         remainder_sq += W[i, j]^2
      end
      starting_idx += length((j - 1):-1:2)
  end
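As a quick standalone sanity check of the identity behind this swap (not part of the PR), take a single element w and the sum of squares S of the entries below it in its column:
w, S = 0.3, 0.5  # hypothetical values for W[i, j] and sum(W[(i + 1):end, j] .^ 2)
@assert atanh(w / sqrt(S + w^2)) ≈ asinh(w / sqrt(S))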
Results (last updated 19 June 2025)
Setup code for this comment
using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using LogExpFunctions
# Existing link implementation from main branch ([email protected])
function _link_chol_lkj_from_upper_old(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2  # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
            remainder_sq -= W[i, j]^2
            idx += 1
        end
    end
    return y
end
# New proposal (this PR)
function _link_chol_lkj_from_upper_new(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2  # {K \choose 2} free parameters
    y = similar(W, N)
    starting_idx = 1
    @inbounds for j in 2:K
        y[starting_idx] = atanh(W[1, j])
        starting_idx += 1
        remainder_sq = W[j, j]^2
        for i in (j - 1):-1:2
            idx = starting_idx + i - 2
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = asinh(z)
            remainder_sq += W[i, j]^2
        end
        starting_idx += length((j - 1):-1:2)
    end
    return y
end
function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end
function test_forward_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        x_again_old = Bijectors._inv_link_chol_lkj(f_old(x.U))[1]
        x_again_new = Bijectors._inv_link_chol_lkj(f_new(x.U))[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x.U - x_again_old)), maximum(abs.(x.U - x_again_new)))
    end
    return samples
end
Impacts of this change
First, let's check the roundtrip transformation on typical samples from LKJCholesky. The numerical accuracy here is actually marginally better than with the existing implementation:
julia> plot_maes(test_forward_bijector(_link_chol_lkj_from_upper_old, _link_chol_lkj_from_upper_new))
On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:
julia> y = rand(Random.Xoshiro(468), 10) * 16;
julia> x = Bijectors._inv_link_chol_lkj(y)[1];
julia> y_old = _link_chol_lkj_from_upper_old(x)
ERROR: DomainError with 1.000207932997037:
atanh(x) is only defined for |x| ≤ 1.
Stacktrace:
[1] atanh_domain_error(x::Float64)
@ Base.Math ./special/hyperbolic.jl:240
[2] atanh
@ ./special/hyperbolic.jl:256 [inlined]
[3] _link_chol_lkj_from_upper_old(W::Matrix{Float64})
@ Main ./REPL[60]:12
[4] top-level scope
@ REPL[111]:1
julia> y_new = _link_chol_lkj_from_upper_new(x)
10-element Vector{Float64}:
1.7139942709891685
4.050190371709019
12.606352576618578
8.239542965781226
7.897855158738417
6.8859283584860345
7.201266901997009
4.588778566499414
5.507106236959234
11.582258611360368
Performance
julia> using Chairmarks
julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_old(_.U)
Benchmark: 3162 samples with 222 evaluations
min 101.914 ns (2 allocs: 144 bytes)
median 112.613 ns (2 allocs: 144 bytes)
mean 127.405 ns (2 allocs: 144 bytes, 0.12% gc time)
max 16.442 μs (2 allocs: 144 bytes, 98.35% gc time)
julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_new(_.U)
Benchmark: 2679 samples with 228 evaluations
min 116.044 ns (2 allocs: 144 bytes)
median 129.934 ns (2 allocs: 144 bytes)
mean 154.996 ns (2 allocs: 144 bytes, 0.18% gc time)
max 20.500 μs (2 allocs: 144 bytes, 98.65% gc time)
Accuracy on pathological samples
It turns out that if you sample enough unconstrained vectors y with really large entries, you can still get numerical inaccuracies of up to 1e-3 on the round-trip invlink/link (although none of them error, so this is probably still an improvement over the existing implementation).
It should be stressed that all of these samples are very large (they're scaled up 16x), so the histogram here represents the outcomes on worst-case samples.
julia> max_err = map(1:5000) do i
           y = rand(Random.Xoshiro(i), 10) * 16
           x = Bijectors._inv_link_chol_lkj(y)[1]
           y_new = _link_chol_lkj_from_upper_new(x)
           maximum(abs.(y - y_new))
       end;
julia> histogram(log10.(max_err))
(Incidentally, the asinh version performs better than atanh on this too, though it's not shown as this comment is already long enough -- the atanh version has a larger median error and a less heavy tail on the left.)
Looping is now implemented in reverse. I had to be careful with the indices, but it's looking good; the original comment has been updated with benchmarks etc. :)
Another point: I wonder if generally it would be better to avoid atanh completely, due to its constrained domain and its derivative 1/(1 - x^2) which might be problematic close to 1 and -1. Note that you could rewrite $$\mathrm{atanh}\left(w_{i,j} / \sqrt{\sum_{k=i}^{j} w_{k,j}^2}\right)$$ as $$\mathrm{asinh}\left(w_{i,j}/\sqrt{\sum_{k={i+1}}^j w_{k,j}^2}\right)$$.
Tried to look into the AD issues. The problem is that the new implementation uses different matrix elements to arrive at the same answer, but ForwardDiff doesn't know that $\sum_i W_{ij}^2 = 1$, so it generates a different Jacobian.
using ForwardDiff
using Bijectors
using LinearAlgebra
using Random
using LogExpFunctions
using Test: @test
dist = LKJCholesky(3, 1, 'U')
# Minimised version of new forward transform
function chol_upper_new(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / W[3, 3]
    # NOTE: If we replace W[3, 3] here with sqrt(1 - W[1, 3]^2 - W[2, 3]^2) then
    # the autodiff works. But the whole problem / the point of this PR was that
    # this is numerically unstable for small values of W[3, 3].
    y3 = asinh(z)
    return [y1, y2, y3]
end
# Minimised version of old forward transform
function chol_upper_old(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / sqrt(1 - W[1, 3]^2)
    y3 = atanh(z)
    return [y1, y2, y3]
end
# Check that they both do the same thing, and they both do the same
# thing as Bijectors.jl forward transform – so they are the same function
for _ in 1:1000
    x_new = rand(dist)
    @test chol_upper_new(x_new.UL) ≈ chol_upper_old(x_new.UL) atol=1e-12
    @test chol_upper_new(x_new.UL) ≈ bijector(dist)(x_new) atol=1e-12
end
# ForwardDiff gives different Jacobians though:
x = rand(Random.Xoshiro(468), dist)
J_FD_new = ForwardDiff.jacobian(chol_upper_new, x.UL)
# 3×9 Matrix{Float64}:
# 0.0 0.0 0.0 1.05605 0.0 0.0 0.0 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 1.0906 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.04432 2.85869
J_FD_old = ForwardDiff.jacobian(chol_upper_old, x.UL)
# 3×9 Matrix{Float64}:
# 0.0 0.0 0.0 1.05605 0.0 0.0 0.0 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 1.0906 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 -2.50772 8.86963 0.0
So it seems like the issue in the failing test might be the indexing by inds (which for an upper-triangular matrix evaluates to [4, 7, 8]):
https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/test/transform.jl#L246-L251
With the old implementation, all the other columns of the Jacobian are zero, but with the new one this isn't true. 😬
Edit: Fixed the test by giving ForwardDiff only a vector of the free parameters in the Cholesky factor, instead of giving it the whole matrix. It only works on 3x3 for now.
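A hypothetical sketch of that idea (3×3 only; the helper name below is made up for illustration and this is not the actual test code): reconstruct the full factor from the three free parameters, so that the unit-norm constraint is part of the function ForwardDiff differentiates.
# Hypothetical helper: build the 3×3 upper Cholesky factor from its free parameters.
function upper_from_free_params(v)
    w12, w13, w23 = v
    return [1.0  w12              w13
            0.0  sqrt(1 - w12^2)  w23
            0.0  0.0              sqrt(1 - w13^2 - w23^2)]
end
x = rand(Random.Xoshiro(468), dist)
v0 = [x.UL[1, 2], x.UL[1, 3], x.UL[2, 3]]
# Both implementations now see the same constrained function of v, so their
# Jacobians agree (up to rounding):
ForwardDiff.jacobian(v -> chol_upper_new(upper_from_free_params(v)), v0)
ForwardDiff.jacobian(v -> chol_upper_old(upper_from_free_params(v)), v0)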
Remaining failing test:
https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/test/ad/chainrules.jl#L86-L95
I finally got the tests to pass, in a rather horrendous way (it's the test_rrule change that I'm not fond of; suggestions are welcome, and I also posted a question about this on Slack). At this point I'm not sure whether we should just remove the call to test_rrule, since it's now so lax, or whether there's a better way to test the rrule than finite differencing.
Also updated all the benchmarks & plots in the top comment.
@devmotion Pinging again in case you want to have another look, but feel free not to if you're busy. Otherwise also cc @sunxd3 for a sanity check -- the top comment contains an explanation of how we arrived at these changes so hopefully that helps in reviewing.
Okay, I have a much better fix for the ChainRules test, thanks to @mcabbott who kindly explained this to me over Slack. I'm just copying the relevant parts of the conversation here since old messages get lost on Slack.
Slack messages
(py) hello! -- I have a problem where I can't coerce FiniteDifferences to behave correctly on a function which is known to have constrained inputs. Here's an example, simplified almost to the point of absurdity. Consider the functions foo and bar:
foo(x) = x[1]^2 + x[2]^2
bar(x) = 1 - x[3]^2
These are obviously different functions. However, if I enforce the constraint that any input to them must be a 3-dimensional unit vector (magnitude equal to 1)
using LinearAlgebra
x = rand(3)
xhat = x / norm(x)
@assert dot(xhat, xhat) ≈ 1
then it follows that they must be equivalent. Of course, FiniteDifferences doesn't understand this and gives me completely different gradients.
@assert foo(xhat) ≈ bar(xhat)
using FiniteDifferences
fdm = central_fdm(5, 1)
FiniteDifferences.grad(fdm, foo, xhat)
FiniteDifferences.grad(fdm, bar, xhat)
I'm mostly expecting the answer to be 'no' because I don't see any way of giving FD this extra information about the constraints, but is there any way at all of working around this?
- I mentioned FiniteDifferences specifically, but I suppose this applies to any autodiff backend really. I had to fix a ForwardDiff test for this too.
- I don't have the option of rewriting the function, in my case I specifically reimplemented foo as bar because the latter is much more numerically stable.
- I've tested the reimplementation manually quite thoroughly and am pretty sure it's correct; it's just that it messes with ChainRulesTestUtils.test_rrule, which attempts to compare its rrule against finite differencing. Changing the rrule implementation doesn't help because, regardless of what the rrule does, FD always returns a bad adjoint. If there's another way of testing the rrule without using FD, happy to hear it!
(ma) You have an input that really lives in a submanifold of R^3. FD doesn’t know this. I think you can either enforce this by mapping any 3-vector to an element of this submanifold. Or you can project the resulting tangent vector afterwards onto the appropriate tangent plane.
First method is:
julia> ForwardDiff.gradient(x -> foo(x/norm(x)), xhat)
3-element Vector{Float64}:
0.42970249069975264
0.6047870015230632
-0.6604418223570914
julia> dot(ans, xhat)
-2.220446049250313e-16
julia> ForwardDiff.gradient(x -> bar(x/norm(x)), xhat)
3-element Vector{Float64}:
0.4297024906997529
0.6047870015230635
-0.660441822357091
Second method is
julia> g2 = ForwardDiff.gradient(foo, xhat) # wrong
3-element Vector{Float64}:
0.770227856610623
1.0840611957601172
0.0
julia> g2 - dot(g2, xhat) * xhat # projected
3-element Vector{Float64}:
0.42970249069975275
0.6047870015230632
-0.6604418223570914
Wait maybe I’m not sure what the question is. You have implemented a rule which is correct i.e. knows about this subspace. test_rrule does not, but you can simply avoid that & test manually?
(py) I guess I'm not sure how to do this in a good way. I could test that the rrule (applied to a tangent) would give the same output as if I were to manually calculate it using the implementation I have, but that would be the self-referential test that ChainRules warns me to not do.
(ma) To write tests, one path should be calling (the backward part of) your rrule on some inputs & outputs, which you generate obeying the subspace.
The other path should I think be FiniteDiff or something — CR docs are warning you that writing a 2nd implementation of the same rule is likely to have the same mistakes. But you will need to compose it with a function which puts the perturbed input back on the subspace, as in x -> foo(x/norm(x)) above.
For the general case… we meant to write a paper explaining all this, but it never got finished. Maybe https://github.com/mcabbott/OddArrays.jl is of interest to you… trying to invent harder test cases.
Sure, I think there's already an existing issue - I'll push that higher up my priority list. Thanks!