ott
restructuring neural models + addition of OT-FM and GENOT
This PR adds:
- new base classes for neural solvers and models (i.e. neural networks);
- support for unbalancedness, including learning the rescaling factors, for any neural OT model;
- `OTFlowMatching` and, related to this, classes for flows and time samplers;
- `GENOT` (with an extension to conditional GENOT);
- drafts of data loaders.
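To illustrate what the flow and time-sampler classes compute, here is a minimal sketch in plain JAX, assuming a straight-line (constant-speed) path and uniformly sampled times. All function names (`uniform_time_sampler`, `straight_path`, `otfm_loss`, `genot_loss`) are hypothetical and are not the API introduced in this PR. In OT flow matching the pairs `(x0, x1)` would be drawn from an OT coupling of the batch; that pairing step is omitted here. The GENOT-style loss assumes the conditional variant flows from Gaussian noise to a target sample while conditioning the velocity field on the paired source sample.

```python
import jax
import jax.numpy as jnp


def uniform_time_sampler(rng: jax.Array, n: int) -> jax.Array:
    """Draw n times uniformly from [0, 1), shaped (n, 1) to broadcast over features."""
    return jax.random.uniform(rng, (n, 1))


def straight_path(x0: jax.Array, x1: jax.Array, t: jax.Array) -> jax.Array:
    """Point at time t on the straight line from x0 to x1."""
    return (1.0 - t) * x0 + t * x1


def straight_velocity(x0: jax.Array, x1: jax.Array) -> jax.Array:
    """Time derivative of straight_path; it does not depend on t."""
    return x1 - x0


def otfm_loss(velocity_fn, params, x0, x1, rng):
    """Flow matching loss for pairs (x0, x1), e.g. drawn from an OT coupling of the batch."""
    t = uniform_time_sampler(rng, x0.shape[0])
    xt = straight_path(x0, x1, t)
    pred = velocity_fn(params, t, xt)
    return jnp.mean(jnp.sum((pred - straight_velocity(x0, x1)) ** 2, axis=-1))


def genot_loss(velocity_fn, params, x_src, y_tgt, rng):
    """GENOT-style loss: flow from Gaussian noise to y_tgt, conditioned on x_src."""
    rng_t, rng_z = jax.random.split(rng)
    t = uniform_time_sampler(rng_t, y_tgt.shape[0])
    z = jax.random.normal(rng_z, y_tgt.shape)
    yt = straight_path(z, y_tgt, t)
    # The velocity field additionally receives the source sample as a condition.
    pred = velocity_fn(params, t, yt, x_src)
    return jnp.mean(jnp.sum((pred - straight_velocity(z, y_tgt)) ** 2, axis=-1))
```

Separating the time sampler from the path keeps both swappable, e.g. a non-uniform time distribution or a path with added noise, without touching the loss.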
Following this PR, the implementations of ICNN-based solvers and the Monge Gap model should be adapted and extended to the unbalanced setting.
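For the unbalanced setting, one common way to obtain batch-wise targets for the rescaling factors is to solve an unbalanced entropic OT problem on the batch and divide the marginals of the resulting coupling by the input marginals; a neural network can then regress these targets. The sketch below assumes this approach, uniform batch weights, a squared-Euclidean cost, and KL-relaxed marginals controlled by `tau = lambda / (lambda + epsilon)` (with `tau = 1` recovering balanced Sinkhorn). The helpers `unbalanced_sinkhorn` and `rescaling_factors` are hypothetical and not taken from this PR.

```python
import jax
import jax.numpy as jnp


def sq_euclidean(x: jax.Array, y: jax.Array) -> jax.Array:
    """Pairwise squared Euclidean cost matrix."""
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def unbalanced_sinkhorn(x, y, epsilon=0.5, tau=0.9, n_iters=500):
    """Scaling iterations for entropic OT with KL-relaxed marginals.

    tau = lambda / (lambda + epsilon); tau = 1.0 gives balanced Sinkhorn.
    """
    n, m = x.shape[0], y.shape[0]
    a = jnp.full(n, 1.0 / n)  # uniform source weights (assumption)
    b = jnp.full(m, 1.0 / m)  # uniform target weights (assumption)
    kernel = jnp.exp(-sq_euclidean(x, y) / epsilon)

    def body(_, uv):
        u, v = uv
        u = (a / (kernel @ v)) ** tau
        v = (b / (kernel.T @ u)) ** tau
        return u, v

    u, v = jax.lax.fori_loop(0, n_iters, body, (jnp.ones(n), jnp.ones(m)))
    coupling = u[:, None] * kernel * v[None, :]
    return a, b, coupling


def rescaling_factors(a, b, coupling):
    """eta_i = (P 1)_i / a_i on the source, xi_j = (P^T 1)_j / b_j on the target."""
    eta = coupling.sum(axis=1) / a
    xi = coupling.sum(axis=0) / b
    return eta, xi
```

In the balanced case both factors are (close to) 1; with `tau < 1` they deviate from 1 where mass is created or destroyed, which is what an unbalanced neural map would learn to reproduce out of sample.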
~~Moreover, with respect to typing, I replaced `jnp.ndarray` by `jax.Array`.~~
What remains to be done, which I would prefer to do in a separate PR:
- add graph costs to `OTFM` and `GENOT`, i.e. functions that compute batch-wise graphs and derive costs from them, e.g. geodesic Sinkhorn or convolutional Wasserstein;
- adapt the implementations of the ICNN-based solvers and the Monge Gap model and extend them to the unbalanced setting.
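As a rough illustration of a batch-wise graph cost, the sketch below builds a k-nearest-neighbour graph on a batch, forms its Laplacian `L`, and turns the heat kernel `H_t = exp(-t L)` into a cost via Varadhan's formula, `-4 t log H_t`, which approximates squared geodesic distances for small `t`. This is an assumption-laden stand-in, not the planned implementation: geodesic Sinkhorn and convolutional Wasserstein use the heat kernel directly as the Gibbs kernel and typically approximate it (e.g. with Chebyshev polynomials) rather than calling a dense `expm`. The names `knn_adjacency` and `heat_geodesic_cost`, the value of `k`, and the `floor` clipping are hypothetical choices.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def knn_adjacency(x: jax.Array, k: int) -> jax.Array:
    """Symmetric 0/1 adjacency matrix of the k-nearest-neighbour graph of a batch."""
    n = x.shape[0]
    d2 = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    d2 = jnp.where(jnp.eye(n, dtype=bool), jnp.inf, d2)  # exclude self-loops
    nbrs = jnp.argsort(d2, axis=1)[:, :k]
    adj = jnp.zeros((n, n)).at[jnp.arange(n)[:, None], nbrs].set(1.0)
    return jnp.maximum(adj, adj.T)  # symmetrize the directed kNN relation


def heat_geodesic_cost(x: jax.Array, k: int = 5, t: float = 1.0, floor: float = 1e-30) -> jax.Array:
    """Cost -4 t log H_t, where H_t = exp(-t L) and L is the graph Laplacian."""
    adj = knn_adjacency(x, k)
    lap = jnp.diag(adj.sum(axis=1)) - adj
    heat = jax.scipy.linalg.expm(-t * lap)
    heat = 0.5 * (heat + heat.T)  # remove round-off asymmetry
    # Clip tiny entries (e.g. between disconnected components) before the log.
    return -4.0 * t * jnp.log(jnp.maximum(heat, floor))
```

The dense `expm` is cubic in the batch size, which is acceptable for small batches but is the main reason kernel-based methods prefer polynomial approximations of `H_t`.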