
Cholesky numerical stability: Forward transform

Open penelopeysm opened this issue 1 year ago • 7 comments

This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:

using Bijectors
using Distributions

θ_unconstrained = [
	-1.9887091960524537,
	-13.499454444466279,
	-0.39328331954134665,
	-4.426097270849902,
	13.101175413857023,
	7.66647404712346,
	9.249285786544894,
	4.714877413573335,
	6.233118490809442,
	22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)

# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.

Introduction

The forward transform acts on an upper triangular matrix, W, whose columns are supposed to be unit vectors, i.e. sum(W[:, j] .^ 2) should be 1 for each j:

julia> s = rand(LKJCholesky(5, 1.0, 'U')).U
5×5 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.345448  -0.478      0.455158   0.385151
  ⋅   0.938438  -0.331921  -0.305083  -0.0469749
  ⋅    ⋅         0.813231  -0.397178   0.831726
  ⋅    ⋅          ⋅         0.73621    0.0298828
  ⋅    ⋅          ⋅          ⋅         0.395968

julia> [sum(s[:, i] .^ 2) for i in 1:5]
5-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0000000000000002

In the forward transform code, remainder_sq is initialised to 1 and the square of each element going down column j is successively subtracted, so at each step remainder_sq represents the sum of squares of the elements not yet seen.

https://github.com/TuringLang/Bijectors.jl/blob/b44177797cc8d3d1f09a1443350520b2ff4f4874/src/bijectors/corr.jl#L320-L331

In principle, because z^2 = W[i, j]^2 / sum(W[i:end, j] .^ 2), z^2 can never be larger than 1.

However, because of floating-point imprecision, this sometimes fails to hold. It is especially likely to fail when the trailing elements of the column (in particular the diagonal W[j, j]) are very small, so that the true remainder is comparable to the rounding error accumulated by the subtractions. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained by inverse-transforming a random unconstrained vector, as described in e.g. #279.
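
To make the cancellation concrete, here is a toy sketch (hypothetical values, not taken from the failing example above). When the entries still to be processed are tiny, the quantity obtained by subtracting the already-seen squares from 1 consists almost entirely of rounding noise, whereas directly summing the not-yet-seen squares stays accurate; dividing by the square root of the noisy remainder is what later produces |z| > 1.

# A hypothetical unit-norm column [a, b, c] whose last entry is tiny.
a, c = 0.6, 1e-9
b = sqrt(1 - a^2 - c^2)

# The remainder left after processing a and b, computed both ways:
remainder_subtract = 1 - a^2 - b^2   # old approach: subtract the squares already seen
remainder_sum      = c^2             # new approach: sum the squares not yet seen

# remainder_sum is ≈ 1e-18, but remainder_subtract typically retains only rounding
# noise of order eps() ≈ 1e-16 and can even come out ≤ 0, so a later division by
# sqrt(remainder_subtract) is meaningless.
(remainder_subtract, remainder_sum)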

A proposed fix: instead of subtracting successive squares from 1, compute remainder_sq directly as that sum:

    idx = 1
    @inbounds for j in 2:K
-       remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
+           remainder_sq = sum(W[i:end, j] .^ 2)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
-           remainder_sq -= W[i, j]^2
            idx += 1
        end
    end

In practice, this is implemented by looping in reverse (over (j - 1):-1:2) and incrementing remainder_sq, following @devmotion's suggestion below. Hopefully it can be seen that at each loop iteration, remainder_sq is equal to sum(W[i:end, j] .^ 2), just as in the snippet above.

There is also some finicky messing with indices (here, starting_idx) because we are filling in y in a backwards manner on each iteration through the loop.

-   idx = 1
+   starting_idx = 1
    @inbounds for j in 2:K
+       remainder_sq = W[j, j]^2
-       for i in 2:(j - 1)
+       for i in (j - 1):-1:2
+           remainder_sq += W[i, j]^2
            z = W[i, j] / sqrt(remainder_sq)
+           idx = starting_idx + i - 2
            y[idx] = atanh(z)
-           idx += 1
        end
+       starting_idx += length((j - 1):-1:2)
    end
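
Since the inner loop now runs backwards, it's worth checking that the starting_idx bookkeeping fills y in the same (column-major, strictly upper-triangular) order as the old sequential idx += 1. A minimal sketch of such a check, where fill_order_old and fill_order_new are hypothetical helpers that just record which (i, j) lands in each slot of y (the first-row handling mirrors the full implementation further down):

function fill_order_old(K)
    order = Tuple{Int,Int}[]
    for j in 2:K
        push!(order, (1, j))            # y[idx] = atanh(W[1, j]); idx += 1
        for i in 2:(j - 1)
            push!(order, (i, j))        # y[idx] = atanh(z); idx += 1
        end
    end
    return order
end

function fill_order_new(K)
    N = ((K - 1) * K) ÷ 2
    order = Vector{Tuple{Int,Int}}(undef, N)
    starting_idx = 1
    for j in 2:K
        order[starting_idx] = (1, j)    # y[starting_idx] = atanh(W[1, j])
        starting_idx += 1
        for i in (j - 1):-1:2
            order[starting_idx + i - 2] = (i, j)   # idx = starting_idx + i - 2
        end
        starting_idx += length((j - 1):-1:2)
    end
    return order
end

fill_order_old(5) == fill_order_new(5)  # true: same slots, filled in a different order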

Because W[i, j]^2 is itself included in remainder_sq at the point of the division, z can no longer be larger than 1, and atanh no longer throws a DomainError.

A final optimisation is to use asinh instead of atanh, by moving the incrementation of remainder_sq so that it happens only after the calculation of y[idx] (see https://github.com/TuringLang/Bijectors.jl/pull/357#issuecomment-2510261859).

    starting_idx = 1
    @inbounds for j in 2:K
        remainder_sq = W[j, j]^2
        for i in (j - 1):-1:2
-           remainder_sq += W[i, j]^2
            z = W[i, j] / sqrt(remainder_sq)
            idx = starting_idx + i - 2
-           y[idx] = atanh(z)
+           y[idx] = asinh(z)
+           remainder_sq += W[i, j]^2
        end
        starting_idx += length((j - 1):-1:2)
    end
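
The reason this works is that once remainder_sq is incremented only after computing y[idx], the denominator at the point of the division is r = sqrt(W[i+1, j]^2 + ... + W[j, j]^2), i.e. it no longer contains W[i, j]^2 itself, and atanh(w / sqrt(w^2 + r^2)) = asinh(w / r). A quick sketch of this identity with illustrative numbers:

# atanh(w / s) with s = sqrt(w^2 + r^2) equals asinh(w / r):
#   atanh(w / s) = log((s + w) / (s - w)) / 2
#                = log((s + w)^2 / (s^2 - w^2)) / 2
#                = log((s + w) / r)
#                = asinh(w / r)
w, r = 0.3, 0.4
s = sqrt(w^2 + r^2)             # s = 0.5
atanh(w / s) ≈ asinh(w / r)     # true; both equal log(2)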

Results (last updated 19 June 2025)

Setup code for this comment

using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using LogExpFunctions

# Existing link implementation from the main branch
function _link_chol_lkj_from_upper_old(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
            remainder_sq -= W[i, j]^2
            idx += 1
        end
    end
    return y
end

# New proposal (this PR)
function _link_chol_lkj_from_upper_new(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    starting_idx = 1
    @inbounds for j in 2:K
        y[starting_idx] = atanh(W[1, j])
        starting_idx += 1
        remainder_sq = W[j, j]^2
        for i in (j - 1):-1:2
            idx = starting_idx + i - 2
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = asinh(z)
            remainder_sq += W[i, j]^2
        end
        starting_idx += length((j - 1):-1:2)
    end
    return y
end

function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end

function test_forward_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        x_again_old = Bijectors._inv_link_chol_lkj(f_old(x.U))[1]
        x_again_new = Bijectors._inv_link_chol_lkj(f_new(x.U))[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x.U - x_again_old)), maximum(abs.(x.U - x_again_new)))
    end
    return samples
end

Impacts of this change

First, let's check the roundtrip transformation on typical samples from LKJCholesky. Numerical accuracy here is actually marginally better than with the existing implementation:

julia> plot_maes(test_forward_bijector(_link_chol_lkj_from_upper_old, _link_chol_lkj_from_upper_new))
[Scatter plot: log10 of the maximum absolute roundtrip error for the old implementation (x-axis) vs the new implementation (y-axis), with a y = x reference line.]

On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:

julia> y = rand(Random.Xoshiro(468), 10) * 16;

julia> x = Bijectors._inv_link_chol_lkj(y)[1];

julia> y_old = _link_chol_lkj_from_upper_old(x)
ERROR: DomainError with 1.000207932997037:
atanh(x) is only defined for |x| ≤ 1.
Stacktrace:
 [1] atanh_domain_error(x::Float64)
   @ Base.Math ./special/hyperbolic.jl:240
 [2] atanh
   @ ./special/hyperbolic.jl:256 [inlined]
 [3] _link_chol_lkj_from_upper_old(W::Matrix{Float64})
   @ Main ./REPL[60]:12
 [4] top-level scope
   @ REPL[111]:1

julia> y_new = _link_chol_lkj_from_upper_new(x)
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371709019
 12.606352576618578
  8.239542965781226
  7.897855158738417
  6.8859283584860345
  7.201266901997009
  4.588778566499414
  5.507106236959234
 11.582258611360368

Performance

julia> using Chairmarks

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_old(_.U)
Benchmark: 3162 samples with 222 evaluations
 min    101.914 ns (2 allocs: 144 bytes)
 median 112.613 ns (2 allocs: 144 bytes)
 mean   127.405 ns (2 allocs: 144 bytes, 0.12% gc time)
 max    16.442 μs (2 allocs: 144 bytes, 98.35% gc time)

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_new(_.U)
Benchmark: 2679 samples with 228 evaluations
 min    116.044 ns (2 allocs: 144 bytes)
 median 129.934 ns (2 allocs: 144 bytes)
 mean   154.996 ns (2 allocs: 144 bytes, 0.18% gc time)
 max    20.500 μs (2 allocs: 144 bytes, 98.65% gc time)

Accuracy on pathological samples

It turns out that if you sample for long enough with really large values of y, you can still get numerical inaccuracies of up to 1e-3 on the roundtrip invlink/link (although none of them error, so this is still an improvement over the existing implementation).

It should be noted that all of these samples are very large (they're scaled up by a factor of 16), so the histogram below represents the outcomes on worst-case samples.

julia> max_err = map(1:5000) do i
           y = rand(Random.Xoshiro(i), 10) * 16
           x = Bijectors._inv_link_chol_lkj(y)[1]
           y_new = _link_chol_lkj_from_upper_new(x)
           maximum(abs.(y - y_new))
       end;

julia> histogram(log10.(max_err))
[Histogram of log10 of the maximum absolute roundtrip error over 5000 such samples.]

(Incidentally, the asinh version performs better than the atanh version here, though the comparison is not shown as this comment is already long enough -- the atanh version has a larger median and a less heavy left tail.)

penelopeysm avatar Dec 01 '24 02:12 penelopeysm

Looping is implemented in reverse now. I had to be careful with the indices, but it's looking good; the original comment has been updated with benchmarks etc. :)

penelopeysm avatar Dec 01 '24 19:12 penelopeysm

Another point: I wonder if generally it would be better to avoid atanh completely, due to its constrained domain and its derivative 1/(1 - x^2) which might be problematic close to 1 and -1. Note that you could rewrite $$\mathrm{atanh}\left(w_{i,j} / \sqrt{\sum_{k=i}^{j} w_{k,j}^2}\right)$$ as $$\mathrm{asinh}\left(w_{i,j}/\sqrt{\sum_{k={i+1}}^j w_{k,j}^2}\right)$$.

devmotion avatar Dec 01 '24 21:12 devmotion

Tried to look into the AD issues. The problem is that the new implementation uses different matrix elements to arrive at the same answer, but ForwardDiff doesn't know that $\sum_i W_{ij}^2 = 1$, so it generates a different Jacobian.

using ForwardDiff
using Bijectors
using LinearAlgebra
using Random
using LogExpFunctions
using Test: @test

dist = LKJCholesky(3, 1, 'U')

# Minimised version of new forward transform
function chol_upper_new(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / W[3, 3]
    # NOTE: If we replace W[3, 3] here with sqrt(1 - W[1, 3]^2 - W[2, 3]^2) then
    # the autodiff works. But the whole problem / the point of this PR was that
    # this is numerically unstable for small values of W[3, 3].
    y3 = asinh(z)
    return [y1, y2, y3]
end

# Minimised version of old forward transform
function chol_upper_old(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / sqrt(1 - W[1, 3]^2)
    y3 = atanh(z)
    return [y1, y2, y3]
end

# Check that they both do the same thing, and they both do the same
# thing as Bijectors.jl forward transform – so they are the same function
for _ in 1:1000
    x_new = rand(dist)
    @test chol_upper_new(x_new.UL) ≈ chol_upper_old(x_new.UL) atol=1e-12
    @test chol_upper_new(x_new.UL) ≈ bijector(dist)(x_new) atol=1e-12
end

# ForwardDiff gives different Jacobians though:
x = rand(Random.Xoshiro(468), dist)

J_FD_new = ForwardDiff.jacobian(chol_upper_new, x.UL)
# 3×9 Matrix{Float64}:
#  0.0  0.0  0.0  1.05605  0.0  0.0  0.0     0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  1.0906  0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  0.0     1.04432  2.85869

J_FD_old = ForwardDiff.jacobian(chol_upper_old, x.UL)
# 3×9 Matrix{Float64}:
#  0.0  0.0  0.0  1.05605  0.0  0.0   0.0      0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0   1.0906   0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  -2.50772  8.86963  0.0

So it seems like the issue in the failing test might be the indexing by inds (which for a 3×3 upper-triangular matrix evaluates to [4, 7, 8]):

https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/test/transform.jl#L246-L251

With the old implementation, all the other columns of the Jacobian are zero, but with the new one this isn't true (column 9 is also nonzero). 😬
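
For reference, a quick way to see where [4, 7, 8] comes from: they are the column-major linear indices of the strictly upper-triangular entries of a 3×3 matrix, i.e. exactly the Jacobian columns that are allowed to be nonzero.

julia> findall(vec([i < j for i in 1:3, j in 1:3]))
3-element Vector{Int64}:
 4
 7
 8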


Edit: Fixed the test by giving ForwardDiff only a vector of the free parameters in the Cholesky factor, instead of giving it the whole matrix. It only works on 3×3 for now; a sketch of the idea is below.
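
A minimal sketch of that idea for the 3×3 case (build_chol_upper is a hypothetical helper for illustration, not the actual test code): reconstruct the full factor from its three free entries, so that the unit-norm column constraint is baked into the function that ForwardDiff differentiates.

# Hypothetical helper: build the 3×3 upper Cholesky factor of a correlation matrix
# from its three free (strictly upper-triangular) entries, enforcing unit-norm columns.
function build_chol_upper(v::AbstractVector)
    W = zeros(eltype(v), 3, 3)
    W[1, 1] = 1
    W[1, 2] = v[1]
    W[2, 2] = sqrt(1 - v[1]^2)
    W[1, 3] = v[2]
    W[2, 3] = v[3]
    W[3, 3] = sqrt(1 - v[2]^2 - v[3]^2)
    return W
end

# Free parameters of the sample `x` from above.
v = [x.UL[1, 2], x.UL[1, 3], x.UL[2, 3]]

# Differentiating with respect to the free parameters only, both implementations
# now agree on the Jacobian (up to floating-point error).
ForwardDiff.jacobian(v -> chol_upper_new(build_chol_upper(v)), v)
ForwardDiff.jacobian(v -> chol_upper_old(build_chol_upper(v)), v)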

penelopeysm avatar Dec 07 '24 21:12 penelopeysm

Remaining failing test:

https://github.com/TuringLang/Bijectors.jl/blob/f52a9c52ede1f43155239447601387eb1dafe394/test/ad/chainrules.jl#L86-L95

penelopeysm avatar Dec 08 '24 12:12 penelopeysm

Bijectors.jl documentation for PR #357 is available at: https://TuringLang.github.io/Bijectors.jl/previews/PR357/

github-actions[bot] avatar Jun 18 '25 22:06 github-actions[bot]

I finally got the tests to pass, in a rather horrendous way (it's the test_rrule change that I'm not fond of; suggestions are welcome -- I also posted a question about this on Slack). At this point I'm not sure whether we may as well just remove the call to test_rrule, since it's so lax, or whether there's a better way to test it than finite differencing.

Also updated all the benchmarks & plots in the top comment.

@devmotion Pinging again in case you want to have another look, but feel free not to if you're busy. Otherwise also cc @sunxd3 for a sanity check -- the top comment contains an explanation of how we arrived at these changes so hopefully that helps in reviewing.

penelopeysm avatar Jun 19 '25 01:06 penelopeysm

Okay, I have a much better fix for the ChainRules test, thanks to @mcabbott who kindly explained this to me over Slack. I'm just copying the relevant parts of the conversation here since old messages get lost on Slack.

Slack messages

(py) hello! -- I have a problem where I can't coerce FiniteDifferences to behave correctly on a function which is known to have constrained inputs. Here's an example, simplified almost to the point of absurdity. Consider the functions foo and bar:

foo(x) = x[1]^2 + x[2]^2
bar(x) = 1 - x[3]^2

These are obviously different functions. However, if I enforce the constraint that any input to them must be a 3-dimensional unit vector (magnitude equal to 1)

using LinearAlgebra
x = rand(3)
xhat = x / norm(x)
@assert dot(xhat, xhat) ≈ 1

then it follows that they must be equivalent. Of course, FiniteDifferences doesn't understand this and gives me completely different gradients.

@assert foo(xhat) ≈ bar(xhat)

using FiniteDifferences
fdm = central_fdm(5, 1)
FiniteDifferences.grad(fdm, foo, xhat)
FiniteDifferences.grad(fdm, bar, xhat)

I'm mostly expecting the answer to be 'no' because I don't see any way of giving FD this extra information about the constraints, but is there any way at all of working around this?

  • I mentioned FiniteDifferences specifically, but I suppose this applies to any autodiff backend really. I had to fix a ForwardDiff test for this too.
  • I don't have the option of rewriting the function, in my case I specifically reimplemented foo as bar because the latter is much more numerically stable.
  • I've tested the reimplementation manually quite thoroughly and am pretty sure it's correct, it's just that it messes with ChainRulesTestUtils.test_rrule which attempts to compare its rrule against finite differencing. Changing the rrule implementation doesn't help because regardless of what the rrule does, FD always returns a bad adjoint. If there's another way of testing the rrule without using FD, happy to hear it!

(ma) You have an input that really lives in a submanifold of R^3. FD doesn’t know this. I think you can either enforce this by mapping any 3-vector to an element of this submanifold. Or you can project the resulting tangent vector afterwards onto the appropriate tangent plane.

First method is:

julia> ForwardDiff.gradient(x -> foo(x/norm(x)), xhat)
3-element Vector{Float64}:
  0.42970249069975264
  0.6047870015230632
 -0.6604418223570914

julia> dot(ans, xhat)
-2.220446049250313e-16

julia> ForwardDiff.gradient(x -> bar(x/norm(x)), xhat)
3-element Vector{Float64}:
  0.4297024906997529
  0.6047870015230635
 -0.660441822357091

Second method is

julia> g2 = ForwardDiff.gradient(foo, xhat)  # wrong
3-element Vector{Float64}:
 0.770227856610623
 1.0840611957601172
 0.0

julia> g2 - dot(g2, xhat) * xhat  # projected
3-element Vector{Float64}:
 0.42970249069975275
 0.6047870015230632
-0.6604418223570914

Wait maybe I’m not sure what the question is. You have implemented a rule which is correct i.e. knows about this subspace. test_rrule does not, but you can simply avoid that & test manually?

(py) I guess I'm not sure how to do this in a good way. I could test that the rrule (applied to a tangent) would give the same output as if I were to manually calculate it using the implementation I have, but that would be the self-referential test that ChainRules warns me to not do.

(ma) To write tests, one path should be calling (the backward part of) your rrule on some inputs & outputs, which you generate obeying the subspace. The other path should I think be FiniteDiff or something — CR docs are warning you that writing a 2nd implementation of the same rule is likely to have the same mistakes. But you will need to compose it with a function which puts the perturbed input back on the subspace, as in x -> foo(x/norm(x)) above. For the general case… we meant to write a paper explaining all this, but it never got finished. Maybe https://github.com/mcabbott/OddArrays.jl is of interest to you… trying to invent harder test cases.

penelopeysm avatar Jun 19 '25 11:06 penelopeysm

Sure, I think there's already an existing issue - I'll push that higher up my priority list. Thanks!

penelopeysm avatar Jul 11 '25 10:07 penelopeysm