Refactoring of neural-related modules in OTT-JAX
Given the newfound prominence of several neural approaches that go beyond the W2 / ICNN case, we are refactoring the neural part of OTT-JAX. This refactoring will break a few of the notebooks in the main branch.
TODO includes:
- [x] Move the dataset-loading parts of the Monge gap and GW notebooks there, to keep the notebooks simpler.
- [x] Add back the quadratic layer earlier in the computational graph for ICNN.
- [ ] Add PICNN (partially input convex neural networks).
- [x] Switch to an iterator that returns 4 dictionaries rather than 4 data matrices, to accommodate conditional FGW settings that involve several feature spaces.