moscot
LRGW solvers do not converge with the new ottjax commit unless inner_iterations=10 is set
I am working on this and fixing the problems caused by the new ottjax version. The failing tests can also be seen in the CI run: https://github.com/theislab/moscot/actions/runs/8361086332/job/22888306575#step:5:1934.
The failing code looks roughly like this:
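```python
from moscot.problems.space import AlignmentProblem

# adata_space_rotate, epsilon, alpha, rank and initializer are defined by the
# test's fixtures and parametrization (see the linked test)
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer)
)
for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged  # fails with the new ottjax version
```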
However, I noticed that if I set `inner_iterations=10`, it converges:
```python
from moscot.problems.space import AlignmentProblem

# same setup as above; the only change is inner_iterations=10
ap = (
    AlignmentProblem(adata=adata_space_rotate)
    .prepare(batch_key="batch")
    .solve(epsilon=epsilon, alpha=alpha, rank=rank, initializer=initializer, inner_iterations=10)
)
for prob_key in ap:
    assert ap[prob_key].solution.rank == rank
    assert ap[prob_key].solution.converged  # passes
```
The new versions of the LRGW solvers in ottjax handle `inner_iterations` differently. One example of a failing test: https://github.com/theislab/moscot/blob/a5187c037f7137de7324348de886dc4ee7234000/tests/problems/space/test_alignment_problem.py#L81
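One possible (unverified) reason a `converged` flag can flip with `inner_iterations` alone: assuming, as in ott-jax's Sinkhorn-type solvers, that `inner_iterations` sets how many iterations pass between convergence checks, a criterion based on the change between checks can see different things at different check intervals. The toy below is a hypothetical `toy_solver`, not ott-jax code. Its iterate cycles with period 2. Checked every step, it never looks converged. Checked every 10 steps, it looks converged at the first check.

```python
def toy_solver(max_iterations: int, inner_iterations: int, threshold: float = 1e-3):
    """Hypothetical toy solver, NOT ott-jax's implementation.

    Runs the fixed-point map x -> 2 - x from x0 = 0.5 and, every
    `inner_iterations` steps, compares the iterate with the one seen at the
    previous check. Returns (converged, iterations_run).
    """
    x = 0.5
    last_checked = x
    for t in range(1, max_iterations + 1):
        # cycles 0.5 -> 1.5 -> 0.5 -> ... and never reaches the fixed point 1.0
        x = 2.0 - x
        if t % inner_iterations == 0:
            error = abs(x - last_checked)  # change since the previous check
            if error < threshold:
                return True, t
            last_checked = x
    return False, max_iterations


print(toy_solver(max_iterations=100, inner_iterations=1))   # (False, 100)
print(toy_solver(max_iterations=100, inner_iterations=10))  # (True, 10)
```

This does not diagnose the LRGW change in ottjax. It only suggests comparing the solver's error history between the two settings before relying on the `converged` flag.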