Cannot batch `ot.emd2` via `torch.vmap`
Describe the bug
Since my data points are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape `(n_batch, n_points, dimension)`. The standard way to make a function accept batched input is `torch.vmap`, yet I get the error described below.
To Reproduce
```python
import torch
import ot

def wasserstein2_loss(X, Y):
    # Uniform weights over the two empirical distributions
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n
    b = torch.ones(m) / m
    M = ot.dist(X, Y, metric="sqeuclidean")
    return ot.emd2(a, b, M) ** 0.5

X = torch.rand(8, 100, 2)  # (n_batch, n_points, dimension)
Y = torch.rand(8, 100, 2)
wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape (n_batch,)
```
Error
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
    430     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    431                          args_spec, out_dims, randomness, **kwargs)
    433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
    435     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    436 )
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     36 @functools.wraps(f)
     37 def fn(*args, **kwargs):
     38     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39         return f(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    617 try:
    618     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619     batched_outputs = func(*batched_inputs, **kwargs)
    620     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    621 finally:
Cell In[4], line 13, in wasserstein2_loss(X, Y)
     11 b = torch.ones(m) / m
     12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5
File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
    485 nx = get_backend(M0, a0, b0)
    487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
    490 a = np.asarray(a, dtype=np.float64)
    491 b = np.asarray(b, dtype=np.float64)
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
   1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
   1762     return np.array(a)
-> 1763 return a.cpu().detach().numpy()
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
```
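The failure is not specific to the solver itself: `ot.emd2` converts its inputs to NumPy via `nx.to_numpy`, and inside `torch.vmap` every tensor is a BatchedTensor wrapper without real storage, so any `.numpy()` call raises this error. A minimal torch-only sketch of the root cause (no POT involved):

```python
import torch

# Inside vmap, inputs are BatchedTensor wrappers without storage, so
# converting to NumPy (as POT's TorchBackend._to_numpy does) fails.
def round_trip_through_numpy(t):
    return torch.from_numpy(t.detach().cpu().numpy())

try:
    torch.vmap(round_trip_through_numpy)(torch.rand(4, 3))
    raised = False
except RuntimeError as e:
    raised = True  # "Cannot access data pointer of Tensor that doesn't have storage"

assert raised
```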
Expected behavior
POT distance functions should be batchable via `torch.vmap`. The Sinkhorn distance code appears to have the same problem.
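Until the backends support vmapped tensors, a plain Python loop over the batch works, since each call then receives an ordinary tensor with storage. A sketch of the pattern, using a hypothetical NumPy-backed stand-in loss (`numpy_based_w2`, not part of POT) so it runs without POT installed; for 1D empirical measures with uniform weights, matching sorted samples does give the exact 2-Wasserstein distance:

```python
import numpy as np
import torch

# Hypothetical stand-in for a POT-style loss: like ot.emd2, it converts
# its inputs to NumPy internally. For 1D uniform empirical measures,
# matching sorted samples yields the exact 2-Wasserstein distance.
def numpy_based_w2(x, y):
    d = np.sort(x.numpy()) - np.sort(y.numpy())
    return torch.tensor(float(np.mean(d ** 2)) ** 0.5)

X = torch.rand(4, 16)  # (n_batch, n_points)
Y = torch.rand(4, 16)

# Workaround: loop over the batch instead of torch.vmap, so every call
# sees a plain tensor and the internal .numpy() conversion succeeds.
W2 = torch.stack([numpy_based_w2(x, y) for x, y in zip(X, Y)])
assert W2.shape == (4,)
```

The loop loses vmap's vectorization but keeps the code differentiable per sample and sidesteps the storage error entirely.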