scalable-pytorch-sinkhorn
Fast, Memory-Efficient Approximate Wasserstein Distances
This repository contains PyTorch code to compute fast p-Wasserstein distances between d-dimensional point clouds using the Sinkhorn algorithm.
This implementation uses linear memory overhead, is stable in float32, runs on the GPU, and is fully differentiable.
Figure: example correspondences between two shapes found by computing the Sinkhorn distance on 200k input points.
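To see why linear memory overhead matters at the 200k-point scale mentioned above, here is a back-of-the-envelope comparison, assuming float32 storage and 3-D points. A naive implementation that materializes the full n x m cost matrix needs memory proportional to n*m, while the point clouds themselves grow linearly in n. The install step suggests the repository relies on KeOps (pykeops) to avoid storing that matrix; the numbers below are plain arithmetic, not measurements of this code.

```python
def dense_cost_matrix_bytes(n: int, m: int, bytes_per_entry: int = 4) -> int:
    """Bytes for a fully materialized [n, m] cost matrix (float32 by default)."""
    return n * m * bytes_per_entry


def point_cloud_bytes(n: int, d: int, bytes_per_entry: int = 4) -> int:
    """Bytes for an [n, d] point cloud, which grows linearly in n."""
    return n * d * bytes_per_entry


n = 200_000  # point count from the figure; d = 3 assumes 3-D shapes
print(dense_cost_matrix_bytes(n, n) / 1e9, "GB for a dense n x n cost matrix")  # 160.0
print(point_cloud_bytes(n, 3) / 1e6, "MB per point cloud")  # 2.4
```

A dense cost matrix at this size would not fit on a typical GPU, whereas the inputs themselves are only a few megabytes.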
How to use:
- Copy the sinkhorn.py file in this repository to your PyTorch codebase.
- Install the dependencies: pip install pykeops tqdm
- Import the function with from sinkhorn import sinkhorn and use it!
Running the example code
See example_basic.py for a basic example and example_optimize.py for an example of how to use the Sinkhorn distance as part of an optimization.
NOTE: To run the examples, you need to first run
pip install pykeops tqdm numpy scipy polyscope point-cloud-utils
sinkhorn function documentation
sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2,
w_x: Union[torch.Tensor, None] = None,
w_y: Union[torch.Tensor, None] = None,
eps: float = 1e-3,
max_iters: int = 100, stop_thresh: float = 1e-5,
verbose=False)
Computes the entropy-regularized p-Wasserstein distance between two d-dimensional point clouds using the Sinkhorn scaling algorithm. The computation runs on the GPU if you pass in GPU tensors. The result can be backpropagated through, though this may be slow when using many iterations.
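To make the description concrete, here is a minimal sketch of log-domain Sinkhorn scaling in plain PyTorch. It is not the repository's implementation: the name sinkhorn_dense is hypothetical, it stores the full [n, m] cost matrix (quadratic rather than linear memory), it omits the verbose option, and it returns the transport cost sum_ij P_ij * M_ij without taking a p-th root. Writing the Sinkhorn kernel as exp(-lambda * M), eps here plays the role of 1/lambda and divides the cost; that this matches sinkhorn.py exactly is an assumption.

```python
import torch


def sinkhorn_dense(x, y, p=2, w_x=None, w_y=None, eps=1e-3,
                   max_iters=100, stop_thresh=1e-5):
    """Log-domain Sinkhorn with a dense [n, m] cost matrix (hypothetical sketch).

    Returns (cost, corrs_x_to_y, corrs_y_to_x), where cost = sum_ij P_ij * M_ij.
    """
    n, m = x.shape[0], y.shape[0]
    # Uniform weights when none are given; both then sum to 1.
    if w_x is None:
        w_x = torch.full((n,), 1.0 / n, dtype=x.dtype, device=x.device)
    if w_y is None:
        w_y = torch.full((m,), 1.0 / m, dtype=y.dtype, device=y.device)
    log_a, log_b = torch.log(w_x), torch.log(w_y)

    # Ground cost M[i, j] = sum_k |x[i, k] - y[j, k]|^p (quadratic memory).
    M = (x[:, None, :] - y[None, :, :]).abs().pow(p).sum(dim=-1)

    # Dual potentials; using logsumexp keeps float32 from under/overflowing.
    u = torch.zeros(n, dtype=x.dtype, device=x.device)
    v = torch.zeros(m, dtype=y.dtype, device=y.device)
    for _ in range(max_iters):
        u_prev, v_prev = u, v
        # Each update makes the rows (then columns) of P sum to w_x (then w_y).
        u = eps * (log_a - torch.logsumexp((v[None, :] - M) / eps, dim=1))
        v = eps * (log_b - torch.logsumexp((u[:, None] - M) / eps, dim=0))
        # Stop once neither potential moves by more than stop_thresh.
        change = max((u - u_prev).abs().max().item(),
                     (v - v_prev).abs().max().item())
        if change < stop_thresh:
            break

    # Approximate transport plan P[i, j] = exp((u[i] + v[j] - M[i, j]) / eps).
    P = torch.exp((u[:, None] + v[None, :] - M) / eps)
    cost = (P * M).sum()
    # Row-wise and column-wise argmax of P give approximate correspondences.
    return cost, P.argmax(dim=1), P.argmax(dim=0)


# Usage: gradients flow back to the input points through every iteration.
x = torch.rand(64, 3, requires_grad=True)
y = torch.rand(80, 3)
cost, corrs_x_to_y, corrs_y_to_x = sinkhorn_dense(x, y, eps=1e-2)
cost.backward()
print(x.grad.shape)  # torch.Size([64, 3])
```

The backward pass unrolls every iteration of the loop, which is one reason differentiating through many iterations is slow.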
Arguments:
- x: A [n, d]-shaped tensor representing a d-dimensional point cloud with n points (one per row).
- y: A [m, d]-shaped tensor representing a d-dimensional point cloud with m points (one per row).
- p: Which norm to use. Must be an integer greater than 0. Default is 2.
- w_x: A [n,]-shaped tensor of optional weights for the points x (None for uniform weights). These must sum to the same value as w_y. Default is None.
- w_y: A [m,]-shaped tensor of optional weights for the points y (None for uniform weights). These must sum to the same value as w_x. Default is None.
- eps: The reciprocal of the Sinkhorn entropy regularization parameter. Default is 1e-3.
- max_iters: The maximum number of Sinkhorn iterations to perform. Default is 100.
- stop_thresh: Stop if the maximum change in the parameters is below this amount. Default is 1e-5.
- verbose: If set, print a progress bar. Default is False.
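Since w_x and w_y must carry the same total mass, one way to prepare non-uniform weights (for example, per-point areas) is to rescale each vector to sum to 1 before calling sinkhorn. The helper normalize_weights below is hypothetical and not part of sinkhorn.py.

```python
import torch


def normalize_weights(w_x: torch.Tensor, w_y: torch.Tensor):
    """Rescale nonnegative weights so that w_x and w_y each sum to 1 (hypothetical helper)."""
    if (w_x < 0).any() or (w_y < 0).any():
        raise ValueError("weights must be nonnegative")
    return w_x / w_x.sum(), w_y / w_y.sum()


# Two clouds with 2 and 3 points whose raw masses differ (4 vs 8).
w_x, w_y = normalize_weights(torch.tensor([1.0, 3.0]), torch.tensor([2.0, 2.0, 4.0]))
# w_x = [0.25, 0.75], w_y = [0.25, 0.25, 0.5]; both now sum to 1.
```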
Returns:
A triple (d, corrs_x_to_y, corrs_y_to_x) where:
- d is the approximate p-Wasserstein distance between point clouds x and y.
- corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair).
- corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[j] is the index of the approximate correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a corresponding pair).
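The two correspondence tensors are returned separately, so a point's match in y need not map back to it. A common post-processing step, sketched below with a hypothetical helper mutual_matches, keeps only the cycle-consistent pairs where corrs_y_to_x[corrs_x_to_y[i]] == i. The tensors here are hand-made stand-ins for sinkhorn's outputs.

```python
import torch


def mutual_matches(x, y, corrs_x_to_y, corrs_y_to_x):
    """Keep pairs (x[i], y[corrs_x_to_y[i]]) whose correspondence maps back to i."""
    idx = torch.arange(x.shape[0], device=x.device)
    # Round trip x -> y -> x; True where it returns to the starting index.
    mutual = corrs_y_to_x[corrs_x_to_y] == idx
    return x[mutual], y[corrs_x_to_y[mutual]], mutual


# Hand-made stand-ins for the correspondences returned by sinkhorn(x, y).
x = torch.tensor([[0.0], [1.0], [2.0]])
y = torch.tensor([[0.1], [2.1]])
corrs_x_to_y = torch.tensor([0, 0, 1])  # x[1] is also matched to y[0]
corrs_y_to_x = torch.tensor([0, 2])     # but y[0] maps back to x[0]
x_kept, y_kept, mutual = mutual_matches(x, y, corrs_x_to_y, corrs_y_to_x)
# mutual is [True, False, True]: the pair (x[1], y[0]) is dropped.
```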