ott
`scale_cost` does not work for kernels when it's not a `float`.
Whenever a geometry is defined through a kernel, the `scale_cost` argument fails for any value that is not a `float` (e.g. `"mean"`, `"median"` or `"max_cost"`), because `Geometry._cost_matrix` is `None`.
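Example that throws an error:
import numpy as np
from ott.geometry.geometry import Geometry
from ott.core import linear_problems, sinkhorn
kernel = np.random.default_rng(0).uniform(1, 10, size=(20, 20))
lp = linear_problems.LinearProblem(geom=Geometry(kernel_matrix=kernel, scale_cost="mean"))
sinkhorn.Sinkhorn()(lp)  # fails: `_cost_matrix` is None for a kernel-only geometry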
As mentioned, this is because `self._cost_matrix` is `None` in the `inv_scale_cost` property.
I think we could replace it with the following code:
@property
def inv_scale_cost(self) -> float:
  """Compute and return inverse of scaling factor for cost matrix or kernel matrix."""
  # Numeric scale factors are used as-is.
  if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
    return 1.0 / self._scale_cost
  self = self._masked_geom(mask_value=jnp.nan)
  # Pick the matrix only after masking, so masked entries are NaN and skipped by
  # the nan-reductions below. Fall back to the kernel when no cost was given.
  instance = self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
  if self._scale_cost == 'max_cost':
    return 1.0 / jnp.nanmax(instance)
  if self._scale_cost == 'mean':
    return 1.0 / jnp.nanmean(instance)
  if self._scale_cost == 'median':
    return 1.0 / jnp.nanmedian(instance)
  raise ValueError(f'Scaling {self._scale_cost} not implemented.')
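To check the fallback logic in isolation, here is a minimal sketch that does not touch ott internals. `inv_scale_factor` is a hypothetical free function (not part of ott) that takes the two matrices explicitly; it skips the `_masked_geom` step, so masking is not covered.

```python
from typing import Optional, Union

import jax.numpy as jnp


def inv_scale_factor(
    cost_matrix: Optional[jnp.ndarray],
    kernel_matrix: Optional[jnp.ndarray],
    scale_cost: Union[int, float, str],
) -> jnp.ndarray:
  """Hypothetical helper: 1 / scale, from the cost matrix or else the kernel."""
  # Numeric factors are used directly.
  if isinstance(scale_cost, (int, float)):
    return 1.0 / jnp.asarray(scale_cost)
  # Use the kernel only when no cost matrix is available.
  instance = kernel_matrix if cost_matrix is None else cost_matrix
  # NaN-aware reductions, matching the proposal above.
  reducers = {"max_cost": jnp.nanmax, "mean": jnp.nanmean, "median": jnp.nanmedian}
  if scale_cost not in reducers:
    raise ValueError(f"Scaling {scale_cost} not implemented.")
  return 1.0 / reducers[scale_cost](instance)


kernel = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(inv_scale_factor(None, kernel, "mean"))  # 1 / 2.5
```

With a kernel-only input the string options now return a finite factor instead of failing on `None`.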
The question is whether scaling the kernel makes as much sense as scaling the cost matrix.
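One way to frame the question: with a Gibbs kernel $K = \exp(-C/\varepsilon)$, dividing the cost by $s$ gives $\exp(-C/(s\varepsilon)) = K^{1/s}$, so a cost scaling acts on the kernel as a power, not a multiplication. The sketch below (toy values, assuming $K$ was built from a known cost $C$) checks that identity and shows that a `"mean"` factor computed on $K$ itself differs from the one computed on $C$.

```python
import jax.numpy as jnp

eps = 0.5
cost = jnp.array([[0.0, 1.0], [2.0, 3.0]])
kernel = jnp.exp(-cost / eps)  # Gibbs kernel built from the cost

# Scaling the cost by 1/s is equivalent to raising the kernel to the power 1/s.
s = jnp.mean(cost)
kernel_from_scaled_cost = jnp.exp(-(cost / s) / eps)
kernel_power = kernel ** (1.0 / s)

# The cost implied by the kernel, -eps * log(K), recovers the same mean as C.
s_implied = jnp.mean(-eps * jnp.log(kernel))

# A "mean" statistic taken directly on the kernel is a different number.
s_kernel = jnp.mean(kernel)
print(s, s_implied, s_kernel)
```

Since the entries of $K$ lie in $(0, 1]$ for a nonnegative cost, a statistic taken on $K$ lives on a different scale than one taken on $C$, which is why the two choices are not interchangeable.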