ott
Optimal transport tools implemented with the JAX framework, providing differentiable, parallel, and JIT-able computations.
```
import jax
import jax.numpy as jnp

from ott.tools import soft_sort

key = jax.random.PRNGKey(0)
x_test = jax.random.normal(key, shape=(20,))
levels = jnp.array([0.4, 0.6])
jax.hessian(soft_sort.quantile, 0)(x_test, q=levels)
```
When calculating the Hessian this way, it works fine.
```
jax.jacobian(jax.jacobian(soft_sort.quantile, 0), 0)(x_test, q=levels)...
```
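For background, `jax.hessian` is built from a composition of forward- and reverse-mode Jacobians, so for a twice-differentiable function it should agree with nesting `jax.jacobian` by hand. A minimal sketch with a plain function (not ott, so any discrepancy in the issue above would be specific to `soft_sort.quantile`):

```python
import jax
import jax.numpy as jnp


def f(x):
    # simple smooth scalar function of a vector input
    return jnp.sum(jnp.sin(x) * x)


x = jnp.arange(3.0)

# jax.hessian composes jacfwd with jacrev under the hood
h1 = jax.hessian(f)(x)
# nesting jax.jacobian twice computes the same second derivatives
h2 = jax.jacobian(jax.jacobian(f))(x)

assert jnp.allclose(h1, h2)
```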
In ``neuraldual.py``, when ``pos_weights=False`` the weights of network ``f`` should be clipped and the weights of network ``g`` should be penalized in a loss. Clipping behaves correctly, but...
**Describe the bug** When `use_bias` is set to `False` in the `PosDefPotentials` class, the dimensions of `y` and `kernel` are incompatible with `jax.lax.dot_general`, which leads to an...
**Is your feature request related to a problem? Please describe.** When implementing new features, we provide tests of the gradients to demonstrate that the new feature is differentiable by autograd....
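One common way to test that a feature is differentiable is to compare `jax.grad` against central finite differences. A minimal sketch with a toy function standing in for a new feature (the function and tolerances here are illustrative assumptions, not ott's actual test setup):

```python
import jax
import jax.numpy as jnp


def f(x):
    # toy differentiable function standing in for a new feature
    return jnp.sum(jnp.tanh(x) ** 2)


def finite_difference_grad(fn, x, eps=1e-3):
    # central differences, one coordinate at a time
    grads = []
    for i in range(x.shape[0]):
        e = jnp.zeros_like(x).at[i].set(eps)
        grads.append((fn(x + e) - fn(x - e)) / (2 * eps))
    return jnp.stack(grads)


x = jnp.linspace(-1.0, 1.0, 5)
auto = jax.grad(f)(x)
numeric = finite_difference_grad(f, x)
assert jnp.allclose(auto, numeric, atol=1e-3)
```

JAX also ships a helper for exactly this purpose, `jax.test_util.check_grads`, which checks both forward- and reverse-mode derivatives numerically.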
Hi! Following up on the solutions kindly provided, I'm experimenting with the `batch-size` method. It is very helpful! Now we are able to compute OT problems at least 10 times...
Hi, `jax` seems to preallocate all of the GPU memory at import, so we cannot see from the NVIDIA panels exactly how much memory the ott package uses. Right...
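For reference, this preallocation behavior is controlled by XLA client environment variables, which must be set before `jax` is first imported. A sketch of disabling it so that `nvidia-smi` reflects actual usage:

```python
# These flags must be set before jax is imported.
import os

# disable the up-front GPU memory grab entirely
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# alternatively, cap the fraction of GPU memory jax may preallocate:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # noqa: E402  (imported after setting the flags)
```

Note that disabling preallocation can increase memory fragmentation, which is why JAX preallocates by default.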
Hi! I'm very interested in computing unbalanced OT problems, where the total mass of one distribution can be significantly smaller than that of the other. In this case, I wish to get...
Hi, In our ML tasks, the problem size is often defined by num_of_training_samples by num_of_validation_samples. Our GPUs currently have 40~80 GB of memory per card, which can handle problems of...
Hi, It seems jax supports multiple GPUs and can automatically parallelize the computation over multiple devices via `pmap`. In our test cases, when there are multiple GPUs available...
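A minimal sketch of `pmap`-style data parallelism, independent of ott. The XLA flag below fakes multiple devices on a CPU-only machine so the example runs anywhere; on a real multi-GPU host `jax.devices()` lists the GPUs directly and the flag is unnecessary:

```python
import os

# fake 4 XLA devices on CPU; must be set before importing jax
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

n_dev = jax.local_device_count()


@jax.pmap
def shard_sum_of_squares(x):
    # each device reduces its own shard of the batch
    return jnp.sum(x ** 2)


# the leading axis must equal the number of devices
batch = jnp.arange(n_dev * 8, dtype=jnp.float32).reshape(n_dev, 8)
per_device = shard_sum_of_squares(batch)  # one partial result per device
total = per_device.sum()
```

The function name and sharding scheme here are illustrative; the key constraint is that `pmap` splits its inputs along the leading axis, one slice per device.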