
Cannot batch `ot.emd2` via `torch.vmap`

oleg-kachan opened this issue on Oct 11, 2023 · 1 comment

Describe the bug

Since my data points are empirical distributions, I want to use the Wasserstein distance as a loss function over batches of shape `(n_batch, n_points, dimension)`. The standard way to make a function accept batched input in PyTorch is `torch.vmap`, yet I get the error described below.

To Reproduce

import torch
import ot

def wasserstein2_loss(X, Y):
    # Uniform weights over the two empirical distributions
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n
    b = torch.ones(m) / m
    # Pairwise squared-Euclidean cost matrix
    M = ot.dist(X, Y, metric="sqeuclidean")
    return ot.emd2(a, b, M) ** 0.5

# Toy batch of empirical distributions: (n_batch, n_points, dimension)
X = torch.randn(4, 100, 2)
Y = torch.randn(4, 100, 2)

wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape `n_batch`

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
    430     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    431                          args_spec, out_dims, randomness, **kwargs)
    433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
    435     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    436 )

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     36 @functools.wraps(f)
     37 def fn(*args, **kwargs):
     38     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39         return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    617 try:
    618     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619     batched_outputs = func(*batched_inputs, **kwargs)
    620     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    621 finally:

Cell In[4], line 13, in wasserstein2_loss(X, Y)
     11 b = torch.ones(m) / m
     12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5

File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
    485 nx = get_backend(M0, a0, b0)
    487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
    490 a = np.asarray(a, dtype=np.float64)
    491 b = np.asarray(b, dtype=np.float64)

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
   1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
   1762     return np.array(a)
-> 1763 return a.cpu().detach().numpy()

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
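
The root cause is visible in the traceback: `ot.emd2` converts its inputs to NumPy via `nx.to_numpy`, but under `torch.vmap` the function receives `BatchedTensor` wrappers that have no real storage, so the final `.numpy()` call fails. A minimal sketch isolating just this incompatibility, independent of POT (the helper name is only for illustration):

import torch

# Under vmap the input is a BatchedTensor without real storage,
# so exporting it with .numpy() raises the same RuntimeError as above.
def roundtrip_through_numpy(x):
    return torch.as_tensor(x.cpu().detach().numpy())

torch.vmap(roundtrip_through_numpy)(torch.randn(4, 3))
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage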

Expected behavior

POT distance functions should be batchable via `torch.vmap`; the Sinkhorn distance code appears to have the same problem.
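
Until native support lands, a possible workaround is an explicit Python loop over the batch dimension, so each `ot.emd2` call sees ordinary tensors. This is only a sketch reusing the `wasserstein2_loss` defined above; `wasserstein2_loss_looped` is an illustrative name, not a POT API:

# Workaround sketch: batch by looping instead of vmap.
# Each per-sample call receives a regular tensor, so emd2's
# NumPy conversion works; gradients should still flow through
# POT's torch backend.
def wasserstein2_loss_looped(X, Y):
    return torch.stack([wasserstein2_loss(x, y) for x, y in zip(X, Y)])

W2 = wasserstein2_loss_looped(X, Y)  # shape (n_batch,)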
