Refactoring of neural-related modules in OTT-JAX

Open marcocuturi opened this issue 1 year ago • 0 comments

Given the growing prominence of neural approaches that go beyond the W2/ICNN case (2-Wasserstein costs with input convex neural network potentials), we are refactoring the neural part of OTT-JAX. This will break a few of the notebooks on the main branch.

The TODO list includes:

  • [x] Move the dataset-loading code of the Monge gap and Gromov-Wasserstein (GW) notebooks there, to keep the notebooks simpler.
  • [x] Add the quadratic layer back, earlier in the ICNN's computational graph.
  • [ ] Add PICNN (partially input convex neural network).
  • [x] Switch to an iterator that returns 4 dictionaries rather than 4 data matrices, to accommodate conditional FGW (fused Gromov-Wasserstein) settings in which several feature spaces are handled.
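
To illustrate why a quadratic layer can sit at the very start of an ICNN without breaking convexity, here is a minimal NumPy sketch, assuming a first layer made of features 0.5 * ||A_k (x - b_k)||^2 followed by layers with nonnegative weights on the hidden state and a convex, nondecreasing activation (softplus). The parameter layout and the names `init_icnn` and `icnn` are hypothetical; this is not OTT-JAX's implementation, which is built on Flax.

```python
import numpy as np


def softplus(x):
    # Convex and nondecreasing: applying it to a convex map keeps the result convex.
    return np.logaddexp(0.0, x)


def init_icnn(rng, dim, hidden=(16, 16), rank=2):
    """Hypothetical parameter layout for a small ICNN with a quadratic first layer."""
    k0 = hidden[0]
    params = {
        # Quadratic first layer: k0 features 0.5 * ||A_k (x - b_k)||^2, each convex in x.
        "A": rng.normal(size=(k0, rank, dim)),
        "b": rng.normal(size=(k0, dim)),
        "Wz": [], "Wx": [], "c": [],
    }
    sizes = list(hidden) + [1]
    for k_in, k_out in zip(sizes[:-1], sizes[1:]):
        params["Wz"].append(rng.normal(size=(k_out, k_in)))  # made nonnegative in the forward pass
        params["Wx"].append(rng.normal(size=(k_out, dim)))   # affine passthrough from x, unconstrained
        params["c"].append(rng.normal(size=(k_out,)))
    return params


def icnn(params, x):
    """Scalar potential f(x), convex in x by construction (hypothetical sketch)."""
    diff = x[None, :] - params["b"]                     # (k0, dim)
    proj = np.einsum("krd,kd->kr", params["A"], diff)  # (k0, rank)
    z = 0.5 * np.sum(proj ** 2, axis=1)                # (k0,) convex quadratic features
    n_layers = len(params["Wz"])
    for i in range(n_layers):
        # |Wz| >= 0: a nonnegative combination of convex features plus an affine term stays convex.
        pre = np.abs(params["Wz"][i]) @ z + params["Wx"][i] @ x + params["c"][i]
        z = softplus(pre) if i < n_layers - 1 else pre
    return float(z[0])


rng = np.random.default_rng(0)
params = init_icnn(rng, dim=3)
value = icnn(params, np.ones(3))
```

Placing the quadratic features first means every later layer only needs nonnegative hidden-to-hidden weights to preserve convexity; the direct `Wx @ x` terms are affine and can stay unconstrained.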

marcocuturi, Nov 20 '23 15:11