
No convergence with new commit of ottjax in LRGW solvers unless inner_iterations=10 is set

Open selmanozleyen opened this issue 11 months ago • 3 comments

I am working on this and resolving the problems that come with the new ottjax version. The failing tests can be seen in the CI run: https://github.com/theislab/moscot/actions/runs/8361086332/job/22888306575#step:5:1934.

The failing code looks roughly like this:

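from moscot.problems.space import AlignmentProblem

# adata_space_rotate, epsilon, alpha, rank and initializer are provided by the
# fixtures and parametrization of the test linked below
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer)
)

# every subproblem should have the requested rank and report convergence
for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged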

However, I noticed that if I set inner_iterations=10, it converges:

# same call as above, with inner_iterations=10 added
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer, inner_iterations=10)
)

for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged

The LRGW solvers in the new ottjax version handle inner_iterations differently. One example of a failing test: https://github.com/theislab/moscot/blob/a5187c037f7137de7324348de886dc4ee7234000/tests/problems/space/test_alignment_problem.py#L81
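For context, OTT-JAX documents that its Sinkhorn-type solvers do not recompute the error at every iteration but only every `inner_iterations` iterations. The toy loop below is a hypothetical, pure-Python sketch of that pattern; it is not OTT-JAX or moscot code, and `toy_fixpoint` and its budget semantics are assumptions made for illustration. It shows that, under a fixed `max_iterations` budget, the checking cadence alone can decide whether a run is reported as converged. It does not reproduce the specific LRGW change in the new commit.

```python
from dataclasses import dataclass


@dataclass
class ToyResult:
    converged: bool
    n_iters: int   # total update steps performed
    n_checks: int  # how many times the error was evaluated


def toy_fixpoint(x0: float, threshold: float, max_iterations: int,
                 inner_iterations: int) -> ToyResult:
    """Hypothetical solver: halves x each step, checks |x| every inner_iterations steps.

    Loosely modeled on the "recompute the error only every inner_iterations"
    pattern; not the actual OTT-JAX implementation.
    """
    x = x0
    n_iters = 0
    n_checks = 0
    # Run whole blocks of `inner_iterations` updates while the budget allows.
    while n_iters + inner_iterations <= max_iterations:
        for _ in range(inner_iterations):
            x = 0.5 * x  # one step of the toy fixed-point map
        n_iters += inner_iterations
        # The convergence criterion is only evaluated at the end of a block.
        n_checks += 1
        if abs(x) < threshold:
            return ToyResult(True, n_iters, n_checks)
    return ToyResult(False, n_iters, n_checks)
```

In this toy, `toy_fixpoint(1.0, 5e-4, 12, 1)` converges after 11 steps, while the same call with `inner_iterations=10` stops after a single check at step 10 without converging; raising `max_iterations` to 20 lets it converge at the second check.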

selmanozleyen, Mar 20 '24 16:03