ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.

Results: 113 ott issues

```
import jax
import jax.numpy as jnp

from ott.tools import soft_sort

key = jax.random.PRNGKey(0)
x_test = jax.random.normal(key, shape=(20,))
levels = jnp.array([0.4, 0.6])

jax.hessian(soft_sort.quantile, 0)(x_test, q=levels)
```

Computing the Hessian this way works fine.

```
jax.jacobian(jax.jacobian(soft_sort.quantile, 0), 0)(x_test, q=levels)...
```

question
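For context on why the two calls differ: `jax.hessian` is documented as `jax.jacfwd(jax.jacrev(f))`, whereas `jax.jacobian` is an alias of the reverse-mode `jax.jacrev`, so nesting it differentiates reverse-over-reverse. A minimal sketch of the two compositions, reusing the setup from the excerpt:

```
import jax
import jax.numpy as jnp

from ott.tools import soft_sort

key = jax.random.PRNGKey(0)
x_test = jax.random.normal(key, shape=(20,))
levels = jnp.array([0.4, 0.6])
quantile = lambda x: soft_sort.quantile(x, q=levels)

# jax.hessian(f) is forward-over-reverse differentiation ...
hess_fwd_rev = jax.jacfwd(jax.jacrev(quantile))(x_test)

# ... whereas jax.jacobian is jax.jacrev, so nesting it is reverse-over-reverse,
# which goes through a different differentiation path (the call reported to fail).
hess_rev_rev = jax.jacrev(jax.jacrev(quantile))(x_test)
```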

In `neuraldual.py`, when `pos_weights=False`, the weights of network `f` should be clipped and the weights of network `g` should be penalized through a loss term. Clipping behaves correctly, but...

bug
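As a generic illustration of the two mechanisms the excerpt contrasts (not OTT's actual `neuraldual.py` code), clipping projects a parameter pytree onto non-negative values, while penalization adds a term to the loss that pushes negative weights towards zero:

```
import jax
import jax.numpy as jnp

def clip_nonneg(params):
    # Hard constraint: project every weight onto the non-negative orthant.
    return jax.tree_util.tree_map(lambda w: jnp.maximum(w, 0.0), params)

def neg_weights_penalty(params):
    # Soft constraint: penalize negative entries so training discourages them.
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(jnp.maximum(-w, 0.0) ** 2) for w in leaves)

params = {"kernel": jnp.array([[1.0, -0.5], [0.3, -2.0]])}
params_clipped = clip_nonneg(params)      # the treatment applied to f's weights
penalty = neg_weights_penalty(params)     # the term added to g's training loss
```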

**Describe the bug** When `use_bias` is set to `False` in the `PosDefPotentials` class, the dimensions of `y` and `kernel` do not allow `jax.lax.dot_general` to be applied, which leads to an...

bug
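For background on the error, `jax.lax.dot_general` contracts user-specified dimensions and requires their sizes to agree; a standalone sketch with hypothetical shapes (not the actual `PosDefPotentials` internals) of the kind of contraction involved:

```
import jax
import jax.numpy as jnp

y = jnp.ones((7, 3))          # 7 input points of dimension 3 (hypothetical)
kernel = jnp.ones((2, 3, 3))  # one 3x3 matrix per potential (hypothetical)

# Contract the feature axis of y (size 3) with the last axis of kernel (size 3).
out = jax.lax.dot_general(y, kernel, dimension_numbers=(((1,), (2,)), ((), ())))
print(out.shape)  # (7, 2, 3)

# If the arrays reaching dot_general lose or change a dimension (as reported
# when use_bias=False), the contracting sizes no longer match and JAX raises
# an error about incompatible dimension numbers.
```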

**Is your feature request related to a problem? Please describe.** When implementing new features, we provide tests of the gradients to demonstrate that the new feature is differentiable by autograd....

enhancement
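One common way to write such tests is `jax.test_util.check_grads`, which compares autodiff derivatives against numerical finite differences; a minimal sketch, with a made-up smooth function standing in for a new feature:

```
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def new_feature(x):
    # Hypothetical differentiable quantity standing in for a new feature.
    return jnp.sum(jnp.tanh(x) * x)

x = jax.random.normal(jax.random.PRNGKey(0), (16,))

# Checks first- and second-order derivatives, in forward and reverse mode,
# against finite-difference approximations.
check_grads(new_feature, (x,), order=2)
```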

Use `jax.scipy.cho_{factor,solve}`

enhancement
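For reference, those routines live in `jax.scipy.linalg`; a minimal sketch of solving a symmetric positive-definite system with a reusable Cholesky factorization instead of an explicit inverse:

```
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])   # symmetric positive-definite
b = jnp.array([1.0, 2.0])

c, lower = cho_factor(A)      # factorization, reusable across right-hand sides
x = cho_solve((c, lower), b)  # two triangular solves instead of forming A^{-1}
```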

Hi! Following up on the solutions kindly provided, I'm experimenting with the `batch-size` method. It is very helpful! Now we are able to compute OT problems at least 10 times...

question

Hi, `jax` seems to reserve all of the GPU memory at import, so we cannot see from the NVIDIA panels exactly how much memory the ott package itself uses. Right...

question
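For context, XLA preallocates most of the GPU memory when the JAX backend initializes, which is why `nvidia-smi` shows the card as nearly full regardless of what ott actually uses. JAX's documented GPU memory flags can change that; a small sketch, with the flags set before the first `jax` import:

```
import os

# Disable XLA's upfront preallocation so allocations happen on demand and
# nvidia-smi reflects actual usage (may fragment memory and slow things down).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, cap the fraction of GPU memory JAX may preallocate.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # must come after the environment variables are set
```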

Hi! I'm very interested in computing unbalanced OT problems, where the total mass of one distribution can be significantly smaller than the other. In this case, I wish to get...

question
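For what it's worth, OTT's Sinkhorn solver exposes `tau_a`/`tau_b` parameters that relax the two marginal constraints (values below 1.0 give an unbalanced problem). A rough sketch, assuming the functional `sinkhorn` interface and module layout of OTT around the time this issue was filed; exact paths differ in later releases:

```
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.core import sinkhorn

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (50, 2))
y = jax.random.normal(key_y, (120, 2))

# Unnormalized weights: the first measure carries much less total mass.
a = 0.2 * jnp.ones(50) / 50
b = jnp.ones(120) / 120

geom = pointcloud.PointCloud(x, y, epsilon=0.05)
# tau_a, tau_b < 1.0 relax the corresponding marginal constraints (unbalanced OT).
out = sinkhorn.sinkhorn(geom, a=a, b=b, tau_a=0.8, tau_b=0.8)
```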

Hi, In our ML tasks, the problem scale is often num_of_training_samples by num_of_validation_samples. Our GPUs currently have 40–80 GB of memory per card, which can handle problems of...

question
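As a back-of-the-envelope check on what one card can hold, a dense float32 cost matrix for an `n_train`-by-`n_valid` problem needs `n_train * n_valid * 4` bytes; a tiny sketch with hypothetical sizes:

```
n_train, n_valid = 100_000, 100_000   # hypothetical problem size

bytes_needed = n_train * n_valid * 4  # dense float32 cost matrix
print(bytes_needed / 1e9)             # 40.0 GB, i.e. an entire 40 GB card
```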

Hi, It seems jax supports multiple GPUs and allows automatically parallelizing computation over multiple devices via `pmap`. In our test cases, when there are multiple GPUs available...

enhancement
question
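As a generic illustration of the `pmap` pattern mentioned above (not OTT-specific), the leading axis of an input array is split across local devices and each device computes its block of a pairwise-cost computation independently:

```
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Hypothetical point clouds; the leading axis of x is the device axis.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 128, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (256, 2))

@jax.pmap
def block_sq_dists(x_shard):
    # Each device computes squared distances between its slice of x and all of y.
    return jnp.sum((x_shard[:, None, :] - y[None, :, :]) ** 2, axis=-1)

dists = block_sq_dists(x)  # shape (n_dev, 128, 256), one block per device
```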