Manifolds.jl

Implement second order approximation to geodesic for the Connection Manifold

Open Nimrais opened this issue 9 months ago • 24 comments

Currently, the package implements the exponential map through an ODE solver (ODEExponentialRetraction), which uses OrdinaryDiffEq under the hood. This implementation is quite computationally intensive. It would be beneficial to have a more computationally efficient approximation. One popular approach, especially in information geometry, is to use a second-order approximation. For reference, see equation 16 in section 5.3 of this paper: https://proceedings.mlr.press/v119/lin20d/lin20d.pdf. This approximation is relatively easy to implement and doesn't require any heavy dependencies.

It does not seem like too much work; it should follow from a careful implementation of something like this:

function exp_secondorder(
   ...
    Γ,
    p0,
    v0
)   
    Δ = similar(p0)  # Preallocate Δ with same type/size as p0
    Manifolds.@einsum Δ[k] = -0.5 * Γ[k,i,j] * v0[i] * v0[j]
    return p0 + v0 + Δ
end
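For context, a sketch of where the `-0.5 * Γ[k,i,j] * v0[i] * v0[j]` term comes from, assuming `p0` and `v0` are expressed in one coordinate chart and `Γ` holds the Christoffel symbols of the connection in that chart (the symbol $\gamma$ for the geodesic is introduced here only for illustration). A geodesic with $\gamma(0) = p_0$ and $\dot\gamma(0) = v_0$ satisfies the geodesic equation, and a Taylor expansion at $t = 0$ truncated after the quadratic term gives the update:

$$
\ddot\gamma^k(t) + \Gamma^k_{ij}\bigl(\gamma(t)\bigr)\,\dot\gamma^i(t)\,\dot\gamma^j(t) = 0
\quad\Longrightarrow\quad
\gamma^k(t) = p_0^k + t\,v_0^k - \frac{t^2}{2}\,\Gamma^k_{ij}(p_0)\,v_0^i v_0^j + O(t^3).
$$

Setting $t = 1$ recovers `p0 + v0 + Δ`, so the approximation can only be expected to be accurate for small `v0`.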

I want to propose implementing this approximation for the Manifolds.jl package if you find it a reasonable proposal. I can file a PR myself.

Nimrais avatar Feb 07 '25 12:02 Nimrais

Good idea. In the context of this paper I think it would be appropriate to make a new <:AbstractRetractionMethod struct for this method. I'd suggest using the affine_connection! function instead of getting the coefficients of the Christoffel symbol, it can often be more efficient.

mateuszbaran avatar Feb 07 '25 14:02 mateuszbaran

Hi, this sounds like a very good idea! We would have to come up with a good name; SecondOrderRetraction, for example, already has another meaning (see Absil, Mahony, Sepulchre p. 107).

One caveat here is that the formula you currently propose does not necessarily work for all manifolds: in the paper above you need the Fisher-Rao metric (which we do not have as a type in Manifolds.jl), and you have to work in what they call the natural parametrisation with respect to it.

kellertuer avatar Feb 07 '25 14:02 kellertuer

One caveat here is, that the current formula you propose does not necessarily work for all manifolds, namely in the paper above you need the Fisher-Rao metric.

Do you mean the formula above that I wrote in the issue (it seems a generic rule that they just computed for Fisher-Rao geometry), or are you talking about the equation (16) that indeed involves the Fisher information?

Nimrais avatar Feb 07 '25 14:02 Nimrais

Ah, their Equation (16) seems to be more general, sure. But I am not sure what they would require, especially for v0 (in Manifolds.jl usually denoted X). I would think it at least requires an AbstractBasis of T_pM, and you would need some coordinates of X in which you would get the Christoffel symbols. And then how would you add that to p (p0)?

So from the theory, you might be right, this may be doable fairly generically; in practice, I am not fully convinced yet.

kellertuer avatar Feb 07 '25 14:02 kellertuer

Ahh, sure, how would you add that to p. Exponential family manifolds always have the ambient R^n around them, so we can always interpret it as just the Euclidean +. So here you are right: the formula with the + works only for manifolds with some ambient space around them in which this + is indeed defined.

Nimrais avatar Feb 07 '25 14:02 Nimrais

And for a lot of manifolds this should work, but my most prominent counterexamples are the fixed rank matrices (represented by their three SVD factors and not by their matrix in the embedding) and Grassmann, where we represent the equivalence class by one Stiefel element, so + would be defined but it is not so clear what it means.

So in theory, this is super straight forward, if you have an embedding, in practice – we need a basis for the tangent space and might have to be a bit careful.

kellertuer avatar Feb 07 '25 14:02 kellertuer

But we don't need the ambient space here, p0 + v0 + Δ is in a parametrization.

So in theory, this is super straight forward, if you have an embedding, in practice – we need a basis for the tangent space and might have to be a bit careful.

We specifically have InducedBasis for that.

mateuszbaran avatar Feb 07 '25 16:02 mateuszbaran

I do not see how an induced Basis allows us to add the coordinates with respect to that basis (is that what you mean?) to a point. For example on the Sphere, the coordinates in the induced basis are in d-dimensions, while p0 is in d+1 dimensions. The point p0 would never be in a parametrisation?

kellertuer avatar Feb 07 '25 17:02 kellertuer

I mean you add parameters of p in the selected parametrization and coordinates of a tangent vector in the induced basis.

mateuszbaran avatar Feb 07 '25 17:02 mateuszbaran

Ah, so we need at least a chart, if not an atlas and then additionally a basis (induced by the chart) of the tangent space? Yes, that would work, though I feel our support for (concrete) atlases is also still a bit limited.

edit: but right, I was not careful enough in thinking in parametrizations. Maybe the same way we use c for coordinates (I try to unify that every now and then) we should use some letter for parameters?

kellertuer avatar Feb 07 '25 17:02 kellertuer

Yes, ideally an atlas would be needed.

I think our support for that is good enough, see for example EmbeddedTorus. We have an example of working with InducedBasis and in-chart computations in a tutorial. Let me know if there is anything specific you think is missing.

mateuszbaran avatar Feb 07 '25 17:02 mateuszbaran

I mostly meant actually available concrete charts / atlases for manifolds. The interface is fine, as far as I can see, but we only have very few manifolds with an actual atlas available.

kellertuer avatar Feb 07 '25 17:02 kellertuer

Sure, we need more atlases. But IMO this is the right way of implementing this feature.

mateuszbaran avatar Feb 07 '25 17:02 mateuszbaran

Ah yes, sorry if I was unclear on that. In parameters x for p, and in coordinates c of the induced basis for the tangent vector X, Δ being similar to c and the “Christoffelistic” being with respect to that basis – this is the completely generic right approach for this. That is all the “be careful with the details” was referring to before and this should be careful enough. 👍

kellertuer avatar Feb 07 '25 17:02 kellertuer

@Nimrais Do you need some help implementing the approach I've sketched? I can implement an example for sphere with StereographicAtlas so you could just focus on extending support to the manifolds and atlases you need.

mateuszbaran avatar Feb 13 '25 10:02 mateuszbaran

@mateuszbaran Ah, I haven't started working on a PR yet, so it would be cool if you could provide an example. But I planned to start with the sphere tomorrow anyway.

Nimrais avatar Feb 13 '25 10:02 Nimrais

Nice! This is interesting and would be cool to have. If we do not feel confident enough to do that fully generically, we could also first model it in ExponentialFamilyManifolds and move it over once it has “stabilized”.

Have you already looked at the Atlas (with exactly one chart) you would need?

kellertuer avatar Feb 13 '25 11:02 kellertuer

Here is a quick sketch:

using Manifolds
using ManifoldsBase

using Manifolds: get_chart_index, affine_connection, StereographicAtlas

using LinearAlgebra

using Plots

struct ChartSecondOrderRetraction{TA<:AbstractAtlas} <: AbstractRetractionMethod
    A::TA
end

function ManifoldsBase._retract!(M::AbstractManifold, q, p, X, m::ChartSecondOrderRetraction)
    return retract_chart_second_order!(M, q, p, X, m)
end

function Manifolds.affine_connection!(::Sphere, Zc, ::StereographicAtlas, i, a, Xc, Yc)
    # from https://math.stackexchange.com/questions/2179547/riemann-curvature-tensor-of-mathbbs-rn-with-stereographic-projection

    denom = 1 + dot(a, a)
    Zc .= (-2 * dot(a, Xc)) .* Yc
    Zc .+= (-2 * dot(a, Yc)) .* Xc
    Zc .+= 2 .* a .* Xc .* Yc
    Zc ./= denom
    return Zc
end

function retract_chart_second_order!(M::AbstractManifold, q, p, X, m::ChartSecondOrderRetraction)
    i = get_chart_index(M, m.A, p)
    p_chart = get_parameters(M, m.A, i, p)
    X_chart = get_coordinates(M, p, X, induced_basis(M, m.A, i))
    ac = affine_connection(M, m.A, i, p_chart, X_chart, X_chart)
    q_chart = p_chart + X_chart - ac / 2

    # TODO: check if q_chart is still in the domain of the chart; if not, either trim or throw an error?
    return get_point!(M, q, m.A, i, q_chart)
end

struct ChartFirstOrderRetraction{TA<:AbstractAtlas} <: AbstractRetractionMethod
    A::TA
end

function ManifoldsBase._retract!(M::AbstractManifold, q, p, X, m::ChartFirstOrderRetraction)
    return retract_chart_first_order!(M, q, p, X, m)
end

function retract_chart_first_order!(M::AbstractManifold, q, p, X, m::ChartFirstOrderRetraction)
    i = get_chart_index(M, m.A, p)
    p_chart = get_parameters(M, m.A, i, p)
    X_chart = get_coordinates(M, p, X, induced_basis(M, m.A, i))
    q_chart = p_chart + X_chart

    # TODO: check if q_chart is still in the domain of the chart; if not, either trim or throw an error?
    return get_point!(M, q, m.A, i, q_chart)
end

function test()
    M = Sphere(2)
    p = [0.701018775197982, 0.08599885936644822, -0.7079384669641785]
    X = [0.34932267994844124, -0.5313371180231259, 0.28136254837437347]

    ts = 10.0 .^ (-8:1:0)
    diffs_1 = similar(ts)
    diffs_2 = similar(ts)
    rm1 = ChartFirstOrderRetraction(StereographicAtlas())
    rm2 = ChartSecondOrderRetraction(StereographicAtlas())
    for (i, ti) in enumerate(ts)
        q_1 = retract(M, p, ti * X, rm1)
        q_2 = retract(M, p, ti * X, rm2)
        q_ref = exp(M, p, ti * X)
        diffs_1[i] = max(eps(), distance(M, q_1, q_ref))
        diffs_2[i] = max(eps(), distance(M, q_2, q_ref))
    end

    xlims = (minimum(ts), maximum(ts))
    fig = plot(ts, diffs_1, xscale=:log10, xlims=xlims, yscale=:log10, label="first order retraction")
    plot!(fig, ts, diffs_2, xscale=:log10, xlims=xlims, yscale=:log10, label="second order retraction")
    xlabel!(fig, "t")
    ylabel!(fig, "error")
    return fig
end

test()

The affine_connection! implementation above is incorrect and needs to be fixed, but I've left it in to show the general idea of how to do it. The code also does not yet handle leaving the chart domain well. There is also a small demonstration of how to check whether this works, though it may not be ideal, because the "first order" retraction appears to actually have second-order convergence in the plot.
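For reference, a sketch of the expression one would expect `affine_connection!` to compute, under an assumption not stated in the thread: that the chart is the standard stereographic projection with parameters $a$, so that the round metric is conformally flat, $g_{ij} = e^{2\varphi}\delta_{ij}$ with $\varphi(a) = \log 2 - \log(1 + \lVert a\rVert^2)$. For such a metric the Christoffel symbols are $\Gamma^k_{ij} = \delta^k_i\,\partial_j\varphi + \delta^k_j\,\partial_i\varphi - \delta_{ij}\,\partial_k\varphi$, and with $\partial_i\varphi = -2a_i/(1 + \lVert a\rVert^2)$ the contraction with two coordinate vectors $X$ and $Y$ becomes:

$$
\Gamma^k_{ij}\,X^i Y^j
= \frac{-2\,\langle a, Y\rangle\,X^k \;-\; 2\,\langle a, X\rangle\,Y^k \;+\; 2\,\langle X, Y\rangle\,a^k}{1 + \lVert a\rVert^2}.
$$

Under that assumption, the first two terms match the sketch above, while the third term would be `2 * dot(Xc, Yc) .* a` rather than the elementwise product `2 .* a .* Xc .* Yc`.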

I hope this works as a starting point 🙂 .

mateuszbaran avatar Feb 13 '25 14:02 mateuszbaran

@mateuszbaran Thanks! It's a clear example; I will play with it with another manifold (for example, torus). I just want to get used to working in charts with Manifolds.jl.

nice! This is interesting and would be cool to have. If we feel too insecure to do that fully generic we could also first model that in ExponentialFamilyManifolds and once it has “stabilized” move it over. Do you already have looked at the Atlas (with the exactly one chart) you would need?

@kellertuer From my POV, it would be easier to start the implementation in ExponentialFamilyManifolds. So if you are fine with that, I can implement @mateuszbaran's approach there, because I think I can define an atlas and the affine connection fairly generically. I can ping one of you to review my PR there. It will need some refactoring of the package on my side, but it looks like the package needs to be improved anyway.

Nimrais avatar Feb 14 '25 11:02 Nimrais

For me that sounds like a good approach. Maybe after a bit of use you will also notice small improvements you could still make, and once this has settled (and the name of the retraction is a bit more stable) we can take it over to here or even to ManifoldsBase (the name, that is).

So yes, if you feel that is easier to start in your package – I think that is a good idea, easier for you and leaves a bit of room (e.g. naming) when “importing” it back here.

kellertuer avatar Feb 14 '25 11:02 kellertuer

Hello @kellertuer and @mateuszbaran, I've implemented Fisher Information metrics with first and second-order retractions for the ExponentialFamilyManifolds package. You can check out the draft PR here: https://github.com/ReactiveBayes/ExponentialFamilyManifolds.jl/pull/33

My implementation uses the Fisher Information matrix as a Riemannian metric to define proper geodesics on manifolds for exponential family distributions. While I initially considered implementing chart-based atlases, I decided against it: natural parameters are embedded in Euclidean spaces, which makes direct computations more straightforward, and a chart-based atlas looks a bit theatrical in ExponentialFamilyManifolds.jl.

The core idea is to use Christoffel symbols derived from the Fisher Information metric to compute accurate second-order retractions. This extends my previous work on geodesics by providing a practical computational approach that approximates the exponential map.

Main Challenges

  1. Robust Computation of Affine Connections

When I can successfully compute the affine connection via Christoffel symbols, the method works beautifully. However, I'm uncertain about the most robust approach to obtain these connections when I have access to the partial derivatives of the metric only through automatic differentiation.

  2. Testing Numerical Geodesics

Without closed-form solutions for geodesics in most spaces, I'm looking for good verification strategies. How do you typically test numerically computed geodesics when analytical solutions aren't available? I'm particularly interested in any invariants or properties I could verify.

  3. Basis Implementation for ManifoldDiff Compatibility

I'm still learning how to properly implement tangent space bases that are fully compatible with ManifoldDiff. I've encountered issues with manifolds that wrap SymmetricPositiveDefinite - the automatic differentiation doesn't always produce the expected result.

Specific Questions

  1. Product Manifold Representation Size: I noticed that representation_size returns nothing for product manifolds. This seems counterintuitive to me. Is this the expected behavior, or could it be an issue in the implementation?
M = ProductManifold(Euclidean(3), SymmetricMatrices(3)) 
@show Manifolds.representation_size(M) == nothing # true
  2. Dimension Mismatch with Hessian Computation: I'm seeing dimension inconsistencies between Fisher Information (which is the Hessian of the log-partition function) and the Hessian computed via ManifoldDiff. The gradient dimensions match perfectly (both are 12-dimensional vectors), but the Fisher Information is (12,12) while ManifoldDiff's Hessian is (9,9). Theoretically, these should be the same matrix, so this dimensional mismatch suggests something about the manifold structure isn't being properly captured.
using ExponentialFamily

using DifferentiationInterface
using ManifoldDiff
import ManifoldDiff: TangentDiffBackend
using Manifolds, FiniteDifferences, ForwardDiff
using LinearAlgebra
using RecursiveArrayTools
using Test

M = ProductManifold(Euclidean(3), SymmetricMatrices(3))

pos_def_matrix = rand(3, 3) + diagm(ones(3))
pos_def_matrix = pos_def_matrix * pos_def_matrix'
neg_pos_def_matrix = -pos_def_matrix

p = ArrayPartition(rand(3), neg_pos_def_matrix)

function create_ef(p)
    natural_parameters = ExponentialFamily.pack_parameters((Manifolds.submanifold_component(M, p, 1), Manifolds.submanifold_component(M, p, 2)))
    return ExponentialFamily.ExponentialFamilyDistribution(MvNormalMeanCovariance, natural_parameters)
end

ef = create_ef(p)

function matrix_logpartition(p)
    ef = create_ef(p)
    return logpartition(ef)
end

rb_onb_fwdd = TangentDiffBackend(AutoForwardDiff())
fd_grad = ManifoldDiff.gradient(M, matrix_logpartition, p, rb_onb_fwdd)

@test fd_grad ≈ gradlogpartition(ef) # pass 12 dim vs 12 dim vectors
@test fisherinformation(ef) ≈ ManifoldDiff.hessian(M, matrix_logpartition, p, rb_onb_fwdd) # fails (12, 12) (it seems correct hessian) vs (9, 9) computed from the ManifoldDiff hessian

Packages that I am using in my environment

Status `~/repos/JuliaManifolds/Project.toml`
⌃ [a0c0ee7d] DifferentiationInterface v0.6.43
  [62312e5e] ExponentialFamily v2.0.3
  [5c9727c4] ExponentialFamilyManifolds v2.0.0 `~/repos/JuliaManifolds/ExponentialFamilyManifolds.jl`
  [17f509fa] ExponentialFamilyProjection v2.0.2 `~/repos/JuliaManifolds/ExponentialFamilyProjection.jl`
  [26cc04aa] FiniteDifferences v0.12.32
  [f6369f11] ForwardDiff v0.10.38
  [af67fdf4] ManifoldDiff v0.4.2
  [1cead3c2] Manifolds v0.10.14
  [3362f125] ManifoldsBase v1.0.1

Conclusion

I'd love to have your feedback on:

  • My overall implementation approach
  • Strategies for robust and fast affine connection computation
  • Testing numerical geodesics without analytical solutions
  • The representation size and dimension mismatch issues
  • Potential to generalize this work for future inclusion in Manifolds.jl

Nimrais avatar Mar 18 '25 13:03 Nimrais

Thanks for that detailed response. I am a bit busy this week, but concerning question one: for now that is expected behaviour, cf. https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions/#ManifoldsBase.representation_size-Tuple{AbstractManifold}, where it is documented that it returns the array dimensions when the point is represented by an array, and nothing when it is some other data structure. Product manifolds use an ArrayPartition data type, hence it returns nothing. Currently Tuple return types are used to create the corresponding arrays. We would have to return something else for non-arrays.

kellertuer avatar Mar 18 '25 13:03 kellertuer

For Point 2, my first guess is that the Hessian is with respect to some tangent space basis, since manifold_dimension(M) = 9 but the product manifold can easily be embedded into $\mathbb R^3 \times \mathbb R^{3\times3}$ which has your dimension 12.
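The two sizes in this guess can be checked by a short count, using only the standard fact that symmetric $n \times n$ matrices form a vector space of dimension $n(n+1)/2$ (the notation $\mathrm{Sym}(3)$ is used here for the second factor of the product):

$$
\dim M = \dim \mathbb R^3 + \dim \mathrm{Sym}(3) = 3 + \frac{3 \cdot 4}{2} = 9,
\qquad
\dim\bigl(\mathbb R^3 \times \mathbb R^{3\times 3}\bigr) = 3 + 9 = 12 .
$$

So a 9×9 Hessian is consistent with coordinates in a basis of the tangent space, while 12 entries are consistent with a representation in the embedding that stores every matrix entry separately.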

Is the gradient the finite-diff gradient in the embedding with a projection? Is the gradient conversion done (see riemannian_gradient)? For both we also have check_gradient/check_hessian methods, though over in Manopt.jl: https://manoptjl.org/stable/helpers/checks/

kellertuer avatar Mar 18 '25 15:03 kellertuer

This is nice progress!

  1. Robust Computation of Affine Connections

When I can successfully compute the affine connection via Christoffel symbols, the method works beautifully. However, I'm uncertain about the most robust approach to obtain these connections when I have access to the partial derivatives of the metric only through automatic differentiation.

I've looked into it and I couldn't find any significant simplification over just following the general formula using automatic differentiation. It seems to me that there should be a better approach, but it's not obvious to me how to do it. Maybe a slow generic implementation plus symbolically derived special cases for the ones you need to be fast would work for you?
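One special case that may be worth checking, sketched under the assumption that the metric is exactly the Hessian of the log-partition function $A$ in natural parameters $\theta$ (as the Fisher information was described earlier in the thread): the partial derivatives of $g_{ij} = \partial_i \partial_j A$ are then third derivatives of $A$ and symmetric in all three indices, so the general Levi-Civita formula collapses to a single term:

$$
\Gamma^k_{ij}
= \tfrac12\, g^{kl}\bigl(\partial_i g_{jl} + \partial_j g_{il} - \partial_l g_{ij}\bigr)
= \tfrac12\, g^{kl}\,\partial_i \partial_j \partial_l A ,
\qquad
\Gamma^k_{ij} X^i X^j
= \tfrac12\, g^{kl}\,\frac{d^2}{dt^2}\,\partial_l A(\theta + tX)\Big|_{t=0} .
$$

If that assumption holds, the contraction needed for the second-order retraction would require only a second directional derivative of the gradient of $A$ along $X$ and one linear solve with the Fisher matrix, without forming the full three-index array.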

2. Testing Numerical Geodesics

Without closed-form solutions for geodesics in most spaces, I'm looking for good verification strategies. How do you typically test numerically computed geodesics when analytical solutions aren't available? I'm particularly interested in any invariants or properties I could verify.

The logic behind geodesic calculation is fairly simple and you can just try it for sphere or some other simple manifold with exp and see if it works using https://juliamanifolds.github.io/ManifoldsBase.jl/stable/numerical_verification/#ManifoldsBase.check_retraction .

I'd expect the computation of the Christoffel symbols to be more prone to bugs, and here checking it against some reference or standard invariants would be a good idea.
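A dependency-free sketch of such a check on the 2-sphere, comparing chart-based first- and second-order updates against the closed-form exponential map. It is not the `check_retraction` helper linked above and does not use Manifolds.jl. All function names are hypothetical, and it assumes a stereographic chart projecting from `-e1` with the conformally flat metric `4/(1+|a|²)² δ`, for which the Christoffel contraction is written out by hand.

```julia
using LinearAlgebra

# Stereographic chart of the unit sphere, projecting from -e1 (assumed convention).
to_chart(p) = p[2:end] ./ (1 + p[1])

# Inverse of `to_chart`: chart parameters back to a point on the sphere.
function from_chart(a)
    s = dot(a, a)
    return vcat((1 - s) / (1 + s), (2 / (1 + s)) .* a)
end

# Differential of `to_chart` at p applied to a tangent vector X,
# i.e. the coordinates of X in the basis induced by the chart.
to_chart_coords(p, X) = X[2:end] ./ (1 + p[1]) .- (X[1] / (1 + p[1])^2) .* p[2:end]

# Christoffel contraction Γ(Xc, Yc) for the metric 4/(1+|a|²)² δ (assumed).
function christoffel(a, Xc, Yc)
    num = -2 * dot(a, Yc) .* Xc .- 2 * dot(a, Xc) .* Yc .+ 2 * dot(Xc, Yc) .* a
    return num ./ (1 + dot(a, a))
end

# Closed-form exponential map on the unit sphere, used as the reference.
function exp_sphere(p, X)
    n = norm(X)
    return cos(n) .* p .+ (sin(n) / n) .* X
end

# First-order update in the chart: parameters plus coordinates.
retract_first_order(p, X) = from_chart(to_chart(p) .+ to_chart_coords(p, X))

# Second-order update in the chart: subtract half the Christoffel contraction.
function retract_second_order(p, X)
    a = to_chart(p)
    Xc = to_chart_coords(p, X)
    return from_chart(a .+ Xc .- christoffel(a, Xc, Xc) ./ 2)
end

# Errors of both updates against the reference for step size t.
function errors(p, X, t)
    q_ref = exp_sphere(p, t .* X)
    err1 = norm(retract_first_order(p, t .* X) - q_ref)
    err2 = norm(retract_second_order(p, t .* X) - q_ref)
    return err1, err2
end

p = normalize([0.7, 0.1, -0.7])
X = [0.3, -0.5, 0.3]
X = X - dot(p, X) * p  # project onto the tangent space at p

for t in (1e-1, 1e-2, 1e-3)
    err1, err2 = errors(p, X, t)
    println("t = $t: first order error $err1, second order error $err2")
end
```

If the Christoffel contraction is right, shrinking `t` by a factor of 10 should shrink the second-order error by roughly a factor of 1000, one power of `t` faster than the first-order error; a wrong connection term would typically reduce both to the same rate.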

2. Dimension Mismatch with Hessian Computation: I'm seeing dimension inconsistencies between Fisher Information (which is the Hessian of the log-partition function) and the Hessian computed via ManifoldDiff. The gradient dimensions match perfectly (both are 12-dimensional vectors), but the Fisher Information is (12,12) while ManifoldDiff's Hessian is (9,9). Theoretically, these should be the same matrix, so this dimensional mismatch suggests something about the manifold structure isn't being properly captured.

How do you define Fisher information to make it (12,12) when the manifold is 9-dimensional? I'd expect the Fisher matrix to be 9x9 also.

mateuszbaran avatar Mar 20 '25 14:03 mateuszbaran