moscot
Refactor/handle solve args
Hi @MUCDK,
So, the good news is that we already do a good job of partitioning the kwargs for `solve`. In `solve`, we pass any kwarg we don't recognize to either the `SinkhornSolver` or the `GWSolver` constructor. `SinkhornSolver` uses `Sinkhorn` or `LRSinkhorn` from ott-jax. These classes don't accept `**kwargs` in their constructors, so an unrecognized argument already raises an error, and we are fine when `SinkhornSolver` is the backend. `GWSolver` uses `GromovWasserstein` or `LRGromovWasserstein` from ott-jax. Their parent class, `WassersteinSolver`, doesn't throw an error on unrecognized args. The tests will pass once the ott-jax PR is merged.
Here is the PR in ott-jax: https://github.com/ott-jax/ott/pull/579
Other things done:
- Added tests checking that we throw appropriate errors on completely unrecognized arguments.
- Refactored where we handle the checks for `alpha`. They are now completely independent of `CompoundProblem` or any other more abstract class, and are handled in `GWSolver`, as they should be.
- Added extra tests on the errors we raise for `alpha` or for the data given.
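The sketch below illustrates the idea of the solver owning its own `alpha` validation, together with tests for the errors it raises. Everything here is hypothetical: `ToyGWSolver`, `prepare`, and `expect_value_error` are illustrative names, not moscot API. The validation rules are also assumptions, not taken from this PR:

- `alpha` must lie in (0, 1].
- `alpha == 1` (pure GW) must not receive a linear term `xy`.
- `alpha < 1` (fused GW) requires one.

```python
from typing import Any, Callable, Optional


class ToyGWSolver:
    """Hypothetical GW solver that validates `alpha` itself, with no help
    from any more abstract problem class."""

    def __init__(self, alpha: float = 1.0) -> None:
        self.alpha = alpha

    def prepare(self, x: object, y: object, xy: Optional[object] = None) -> None:
        # ASSUMPTION: valid alpha values are in (0, 1].
        if not 0.0 < self.alpha <= 1.0:
            raise ValueError(f"Expected `alpha` in (0, 1], found {self.alpha}.")
        # ASSUMPTION: alpha == 1 is pure GW, so a linear term is an error.
        if self.alpha == 1.0 and xy is not None:
            raise ValueError("`alpha=1` means pure GW, but a linear term `xy` was given.")
        # ASSUMPTION: alpha < 1 is fused GW, which needs the linear term.
        if self.alpha < 1.0 and xy is None:
            raise ValueError(f"`alpha={self.alpha}` requires a linear term `xy`.")


def expect_value_error(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> bool:
    """Return True if calling fn(*args, **kwargs) raises ValueError."""
    try:
        fn(*args, **kwargs)
    except ValueError:
        return True
    return False


if __name__ == "__main__":
    # Error cases: out-of-range alpha, and alpha/data mismatches.
    assert expect_value_error(ToyGWSolver(alpha=0.0).prepare, "x", "y")
    assert expect_value_error(ToyGWSolver(alpha=1.0).prepare, "x", "y", xy="xy")
    assert expect_value_error(ToyGWSolver(alpha=0.5).prepare, "x", "y")
    # Valid configurations pass without raising.
    ToyGWSolver(alpha=1.0).prepare("x", "y")
    ToyGWSolver(alpha=0.5).prepare("x", "y", xy="xy")
```

Keeping the checks inside the solver means every caller gets them, whichever problem class constructed the solver.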
Additionally closes:
- It turns out we can remove the skip for these tests, so I will close https://github.com/theislab/moscot/issues/678
- https://github.com/theislab/moscot/issues/703
- https://github.com/theislab/moscot/issues/720