Bijectors.jl

Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability

Open. acertain opened this pull request.

Example of the previous failure: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf. It now returns -79.30685281944005, i.e. log(2) - 80.
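To see where the -Inf comes from, here is a minimal, dependency-free sketch. It is not the PR's code and does not use the Bijectors.jl API. It assumes the bijector of Uniform(a, b) is the scaled logit, so its inverse maps y to x = a + (b - a)·σ(y), where σ is the logistic function. The helper names `logistic`, `naive_logjac_inv` and `stable_logjac_inv` are hypothetical. The naive version forms 1 - σ(y) explicitly. The stable version uses the identity log σ(y) + log(1 - σ(y)) = -|y| - 2·log1p(exp(-|y|)).

```julia
# Hypothetical helpers, not Bijectors.jl API. Assumption: the inverse of the
# Uniform(a, b) bijector maps y to x = a + (b - a) * σ(y), σ = logistic function.
logistic(y) = 1 / (1 + exp(-y))

# Naive log|dx/dy|: for large y, σ(y) rounds to 1.0, so 1 - σ(y) == 0 and log gives -Inf.
function naive_logjac_inv(a, b, y)
    s = logistic(y)
    return log((b - a) * s * (1 - s))
end

# Stable log|dx/dy| via log σ(y) + log(1 - σ(y)) = -|y| - 2 log1p(exp(-|y|)),
# which never forms 1 - σ(y) and therefore cannot cancel to zero.
stable_logjac_inv(a, b, y) = log(b - a) - abs(y) - 2 * log1p(exp(-abs(y)))

naive_logjac_inv(-1.0, 1.0, 80.0)   # -Inf
stable_logjac_inv(-1.0, 1.0, 80.0)  # ≈ log(2) - 80, about -79.30685
```

The stable form depends only on |y|. This matches the exact Jacobian, since σ(y)(1 - σ(y)) is symmetric in y.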

Formula taken from Stan; I didn't check its correctness.
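The formula can be checked by hand. The following short derivation assumes the bijector of Uniform(a, b) is the scaled logit, so that its inverse is x = a + (b - a)σ(y). It confirms the symmetric closed form that a numerically stable implementation can use:

```latex
\begin{aligned}
x &= a + (b-a)\,\sigma(y), \qquad \sigma(y) = \frac{1}{1+e^{-y}},\\
\frac{dx}{dy} &= (b-a)\,\sigma(y)\bigl(1-\sigma(y)\bigr) = (b-a)\,\sigma(y)\,\sigma(-y),\\
\log\sigma(y) &= -\log\bigl(1+e^{-y}\bigr), \qquad
\log\sigma(-y) = -\log\bigl(1+e^{y}\bigr) = -y - \log\bigl(1+e^{-y}\bigr),\\
\log\left|\frac{dx}{dy}\right| &= \log(b-a) - y - 2\log\bigl(1+e^{-y}\bigr) \qquad (y \ge 0),\\
&= \log(b-a) - |y| - 2\,\operatorname{log1p}\bigl(e^{-|y|}\bigr) \qquad \text{for all } y, \text{ since } \sigma(y)\,\sigma(-y) \text{ is even in } y.
\end{aligned}
```

For a = -1, b = 1, y = 80 this gives log 2 - 80 - 2 log1p(e^{-80}) ≈ -79.30685, which agrees with the value quoted above.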

acertain, Aug 17 '24 22:08

@devmotion Thanks for the suggestions; I've implemented them.

acertain, Aug 19 '24 21:08

@acertain this seems to break a correctness test for Julia 1.6. Do you want to take a look?

yebai, Aug 29 '24 10:08

We have already dropped support for Julia 1.6, and the remaining tests pass (apart from the known Enzyme crashes).

I've bumped the minor version since this technically exposes a new function.


I'm wondering how we can write a good test for this. For example, we can't do this:

d = Uniform(-1, 1); b = bijector(d); y = 80; x = inverse(b)(y)

@test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true)

Thanks to this PR, the LHS evaluates to a finite value, but the comparison fails because x evaluates to 1.0 and logpdf_with_trans returns an infinite value. (In exact arithmetic x is about 1 - 3.6e-35, but the largest Float64 below 1.0 is 1 - 2^-53 ≈ 1 - 1.1e-16, so x rounds to exactly 1.0.)

Maybe we should just check isfinite() on the LHS terms? Or we could do the equality check only if (b ∘ inverse(b))(y) == y?
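Both options can be sketched without Bijectors.jl. The sketch uses hypothetical stand-ins for the Uniform(a, b) bijector: `to_real` plays the role of b and `to_interval` the role of inverse(b). It assumes the bijector is the scaled logit, and it is an illustration, not the package's test code. One deliberate change from the proposal: the round-trip condition uses ≈ instead of the exact ==, because the round trip need not be bit-exact even for moderate y.

```julia
# Hypothetical stand-ins (not Bijectors.jl API), assuming the Uniform(a, b)
# bijector is logit((x - a) / (b - a)) and to_interval is its inverse.
to_real(a, b, x) = log((x - a) / (b - x))
to_interval(a, b, y) = a + (b - a) / (1 + exp(-y))

# Log-density of y for x ~ Uniform(a, b), computed from y with the stable term:
# -log(b - a) + [log(b - a) - |y| - 2 log1p(exp(-|y|))]; the log(b - a) terms cancel.
logdens_from_y(a, b, y) = -abs(y) - 2 * log1p(exp(-abs(y)))

# The same density computed from x, analogous to logpdf_with_trans(d, x, true):
# with s = (x - a) / (b - a) it equals log(s) + log(1 - s), which is -Inf once x rounds to b.
function logdens_from_x(a, b, x)
    s = (x - a) / (b - a)
    return log(s) + log1p(-s)
end

# Option 1: the y-based value must always be finite.
# Option 2: compare against the x-based value only when the round trip y -> x -> y holds.
function check(a, b, y; rtol=1e-8)
    x = to_interval(a, b, y)
    lhs = logdens_from_y(a, b, y)
    isfinite(lhs) || return false
    roundtrip_ok = isapprox(to_real(a, b, x), y; rtol=rtol)
    return !roundtrip_ok || isapprox(lhs, logdens_from_x(a, b, x); rtol=rtol)
end

check(-1.0, 1.0, 0.5)   # true: round trip holds, so the two values are compared
check(-1.0, 1.0, 80.0)  # true: x rounds to 1.0, so only finiteness is checked
```

The weakness of option 2 shows up at y = 80. There, the comparison is silently skipped, so that case is covered only by the finiteness check.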

The existing tests already check numerical accuracy on a wide range of distributions and non-pathological values, so we can already be fairly confident that this PR does not introduce regressions for those cases.

https://github.com/TuringLang/Bijectors.jl/blob/d342371da2ab090b5b519265c1951c9322a39879/test/interface.jl#L78-L85

penelopeysm, Nov 30 '24 00:11

It's probably easiest to compare against BigFloats. For example, on the master branch the example above gives:

julia> using Bijectors

julia> d = Uniform(big(-1.0), big(1.0));

julia> b = bijector(d);

julia> y = big(80.0);

julia> x = inverse(b)(y)
0.9999999999999999999999999999999999639029722430916965537574328529994522280756386

julia> logpdf_with_trans(d, x, true)
-80.000000000000000000000000000000000036097027594005675893754704116470597853965
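The BigFloat idea can be turned into a self-contained test. The sketch below evaluates the x-based density in BigFloat, where x stays strictly below the upper bound, and compares the Float64 y-based value against it. It assumes the Uniform(a, b) bijector is the scaled logit. The helper names are hypothetical and stand in for inverse(b) and logpdf_with_trans. They are not Bijectors.jl API.

```julia
# Hypothetical stand-ins (not Bijectors.jl API), assuming the inverse of the
# Uniform(a, b) bijector is x = a + (b - a) * σ(y).
to_interval(a, b, y) = a + (b - a) / (1 + exp(-y))

# Density of y computed from x, mirroring logpdf_with_trans(d, x, true):
# log(1 / (b - a)) + log|dx/dy| = log(s) + log(1 - s) with s = (x - a) / (b - a).
function logdens_from_x(a, b, x)
    s = (x - a) / (b - a)
    return log(s) + log1p(-s)
end

# The same density computed from y with the stable Jacobian term (Float64).
logdens_from_y(a, b, y) = -abs(y) - 2 * log1p(exp(-abs(y)))

# Reference: evaluate the x-based formula in BigFloat (default precision is
# 256 bits unless changed with setprecision), then compare after converting to Float64.
function matches_bigfloat_reference(a, b, y; rtol=1e-12)
    x_big = to_interval(big(a), big(b), big(y))
    ref = logdens_from_x(big(a), big(b), x_big)
    return isapprox(logdens_from_y(a, b, y), Float64(ref); rtol=rtol)
end

all(matches_bigfloat_reference(-1.0, 1.0, y) for y in (-80.0, -5.0, 0.0, 5.0, 80.0))  # true
```

At y = 80 the BigFloat reference is about -80 - 3.6e-35, consistent with the logpdf_with_trans value in the session above. The Float64 x-based formula, by contrast, returns -Inf there.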

devmotion, Nov 30 '24 00:11