ott
Bug in the way GW handles a KL loss
At the moment, the KL loss in the GW solver applies directly to the cost matrices.
This is not how we envisioned it in the original paper http://proceedings.mlr.press/v48/peyre16.pdf, where KL was applied to the kernel matrices, not to the cost matrices.
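For reference, here is a minimal NumPy sketch of that construction. It is not ott code. It assumes the generalized KL loss L(a, b) = a log(a/b) - a + b and Gibbs kernels K = exp(-C/eps) built from squared Euclidean distances, with an arbitrary eps = 1. The loss splits as f1(a) + f2(b) - h1(a) h2(b), with f1(a) = a log a - a, f2(b) = b, h1(a) = a and h2(b) = log b. This split lets the tensor product L(K1, K2) ⊗ T be computed with matrix products instead of a 4-way sum. Because log b appears, every entry must be strictly positive. Kernel matrices satisfy this, whereas a squared Euclidean cost matrix has a zero diagonal.

```python
import numpy as np

# Generalized KL loss between positive scalars: L(a, b) = a log(a/b) - a + b,
# decomposed as f1(a) + f2(b) - h1(a) * h2(b).
def f1(a):
    return a * np.log(a) - a

def f2(b):
    return b

def h1(a):
    return a

def h2(b):
    return np.log(b)

def kl(a, b):
    return a * np.log(a / b) - a + b

def tensor_product(A, B, T):
    """(L(A, B) ⊗ T)_ij = sum_{k,l} L(A_ik, B_jl) T_kl, via matrix products."""
    p = T.sum(axis=1)  # first marginal of the coupling T (length n)
    q = T.sum(axis=0)  # second marginal of the coupling T (length m)
    return (f1(A) @ p)[:, None] + (f2(B) @ q)[None, :] - h1(A) @ T @ h2(B).T

def tensor_product_naive(A, B, T):
    """Brute-force O(n^2 m^2) version, used only as a check."""
    L = kl(A[:, :, None, None], B[None, None, :, :])  # L[i, k, j, l] = L(A_ik, B_jl)
    return np.einsum("ikjl,kl->ij", L, T)

def sq_dist(z):
    """Pairwise squared Euclidean distances within one point cloud."""
    return ((z[:, None, :] - z[None, :, :]) ** 2).sum(axis=-1)

rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 3))
eps = 1.0  # assumed entropic scale for the Gibbs kernels
K1, K2 = np.exp(-sq_dist(x) / eps), np.exp(-sq_dist(y) / eps)  # strictly positive
T = np.outer(np.full(5, 1 / 5), np.full(4, 1 / 4))  # independent coupling

print(np.allclose(tensor_product(K1, K2, T), tensor_product_naive(K1, K2, T)))
```

The final line checks the factorized product against the brute-force sum on two small random point clouds of different dimensions.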
At the moment, the only way around this is clumsy: it involves instantiating a `Geometry` whose `cost_matrix` is the `kernel_matrix` of another geometry, e.g., with the old API:
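```python
from ott.geometry import geometry, pointcloud
from ott.core import gromov_wasserstein  # old API; module path may differ across ott versions
geom_1 = pointcloud.PointCloud(x)  # x: point cloud of shape (n, d)
geom_xx = geometry.Geometry(geom_1.kernel_matrix)  # kernel matrix passed in as a cost matrix
_ = gromov_wasserstein.gromov_wasserstein(geom_xx, ...)  # remaining arguments elided
```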
I think that, by default, using the KL loss on geometries should mean that the kernel matrices are used directly (not the costs). This may also affect other areas, e.g., (F)GW barycenters, which IIUC currently take cost matrices as input rather than geometries.
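A minimal sketch of the proposed default, under loudly stated assumptions. `Geom` is a hypothetical stand-in for a geometry, not ott's `Geometry` class. `matrix_for_loss` and the loss names `"kl"` and `"sqeucl"` are also hypothetical. The idea is that the solver picks which matrix to compare based on the loss, so users no longer wrap a kernel matrix in a second geometry.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class Geom:
    """Hypothetical stand-in for a geometry: a cost matrix plus an entropic scale."""
    cost_matrix: np.ndarray
    epsilon: float = 1.0

    @property
    def kernel_matrix(self) -> np.ndarray:
        # Gibbs kernel K = exp(-C / epsilon); all entries are strictly positive.
        return np.exp(-self.cost_matrix / self.epsilon)

def matrix_for_loss(geom: Geom, loss: str) -> np.ndarray:
    """Hypothetical default: KL compares kernel matrices, squared loss compares costs."""
    if loss == "kl":
        return geom.kernel_matrix
    if loss == "sqeucl":
        return geom.cost_matrix
    raise ValueError(f"unknown loss: {loss!r}")

C = np.array([[0.0, 1.0], [1.0, 0.0]])
g = Geom(C, epsilon=0.5)
print(matrix_for_loss(g, "kl"))      # the kernel exp(-C / 0.5)
print(matrix_for_loss(g, "sqeucl"))  # the cost matrix C itself
```

Dispatching on the loss keeps a single geometry object as the input, which would also let (F)GW barycenters take geometries instead of raw cost matrices.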