
`scale_cost` does not work for kernels when it's not a `float`.

Open MUCDK opened this issue 2 years ago • 2 comments

Whenever a `Geometry` is built from a kernel matrix, the `scale_cost` argument fails unless it is a float, because `Geometry._cost_matrix` is `None`.

Example that throws an error:

kernel = rng.uniform(1, 10, size=(20,20))
lp = LinearProblem(geom=Geometry(kernel_matrix=kernel, scale_cost="mean"))
Sinkhorn()(lp)

As mentioned, this is because `self._cost_matrix` is `None` here.

I think we could replace this with the following code:

  @property
  def inv_scale_cost(self) -> float:
    """Compute and return inverse of scaling factor for cost matrix or kernel matrix."""
    if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
      return 1.0 / self._scale_cost
    self = self._masked_geom(mask_value=jnp.nan)
    # Select the matrix after masking; otherwise the NaN mask never reaches `instance`.
    instance = self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
    if self._scale_cost == 'max_cost':
      return 1.0 / jnp.nanmax(instance)
    if self._scale_cost == 'mean':
      return 1.0 / jnp.nanmean(instance)
    if self._scale_cost == 'median':
      return 1.0 / jnp.nanmedian(instance)
    raise ValueError(f'Scaling {self._scale_cost} not implemented.')
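
To make the fallback logic easy to try outside of `Geometry`, here is a minimal standalone sketch of the same idea. It is not OTT code: the function name `inv_scale_factor` and its arguments are hypothetical, plain nested lists stand in for arrays, and NaN entries play the role of masked values.

```python
import math
import statistics


def inv_scale_factor(scale_cost, cost_matrix=None, kernel_matrix=None):
    """Return 1 / scale, falling back to the kernel when no cost matrix is given.

    Hypothetical helper for illustration only; matrices are nested lists and
    NaN entries are treated as masked, i.e. ignored by every statistic.
    """
    # A numeric scale needs no matrix at all.
    if isinstance(scale_cost, (int, float)):
        return 1.0 / scale_cost

    # Use the cost matrix if present, otherwise the kernel matrix.
    matrix = kernel_matrix if cost_matrix is None else cost_matrix
    values = [v for row in matrix for v in row if not math.isnan(v)]

    reducers = {
        "max_cost": max,
        "mean": statistics.fmean,
        "median": statistics.median,
    }
    if scale_cost not in reducers:
        raise ValueError(f"Scaling {scale_cost} not implemented.")
    return 1.0 / reducers[scale_cost](values)


kernel = [[1.0, 2.0], [float("nan"), 6.0]]
inv_mean = inv_scale_factor("mean", kernel_matrix=kernel)      # 1 / mean(1, 2, 6)
inv_median = inv_scale_factor("median", kernel_matrix=kernel)  # 1 / median(1, 2, 6)
inv_max = inv_scale_factor("max_cost", kernel_matrix=kernel)   # 1 / max(1, 2, 6)
inv_float = inv_scale_factor(4.0)                              # numeric scale
```

With only a kernel supplied, the string options no longer hit a `None` matrix, which is the failure described above.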

The question is whether scaling the kernel makes as much sense as scaling the cost matrix.
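
One way to look at this question, as a sketch under an assumption: if the kernel is related to a cost by the usual Gibbs form `K = exp(-C / epsilon)` (whether `Geometry` uses exactly this relation for user-supplied kernels is not confirmed here), then multiplying the cost by a factor `s` corresponds to raising the kernel to the power `s` elementwise, not to multiplying the kernel by `s`. The values of `epsilon`, `cost` and `s` below are made up for illustration.

```python
import math

epsilon = 0.5  # assumed regularization strength
cost = [[0.0, 1.0], [2.0, 4.0]]

# Gibbs kernel associated with the cost (assumption: K = exp(-C / epsilon)).
kernel = [[math.exp(-c / epsilon) for c in row] for row in cost]

s = 0.25  # inverse scale factor applied to the cost

# Kernel obtained from the rescaled cost s * C.
kernel_of_scaled_cost = [[math.exp(-s * c / epsilon) for c in row] for row in cost]

# Raising the original kernel to the power s elementwise gives the same matrix ...
kernel_power = [[k ** s for k in row] for row in kernel]

# ... whereas multiplying the kernel by s gives a different one.
kernel_times_s = [[s * k for k in row] for row in kernel]
```

Under that assumption, a factor such as `1.0 / jnp.nanmean(kernel)` applied multiplicatively to the kernel would not be equivalent to rescaling the underlying cost.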

MUCDK, Nov 28 '22 11:11