ott
application of `low-rank` to single-cell data
TL;DR: in this Colab we provide an example of our failure to obtain a valid mapping using `low-rank`.
Problem setup: in this example data set we map spatial transcriptomics at single-cell resolution from mouse embryonic tissues across two time-points. The quadratic term therefore accounts for distances in spatial coordinates, while the linear term captures distances in gene expression.
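For reference, here is a minimal NumPy sketch of a fused Gromov-Wasserstein objective of the kind described above, evaluated for a fixed coupling. It assumes squared-Euclidean costs, a trade-off weight `alpha` between the linear and quadratic terms, and random stand-in data; `sq_dists`, `fgw_objective` and `alpha` are hypothetical names, and ott's own weighting of the two terms may differ.

```python
import numpy as np

def sq_dists(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)

def fgw_objective(P, Dx, Dy, C, alpha):
    """(1 - alpha) * <C, P> + alpha * sum_ijkl (Dx[i, k] - Dy[j, l])**2 * P[i, j] * P[k, l].

    The (1 - alpha) / alpha weighting is an assumption for illustration.
    """
    linear = np.sum(C * P)
    p, q = P.sum(axis=1), P.sum(axis=0)  # marginals of the coupling
    # Expanding the square avoids materialising the (n, m, n, m) tensor.
    quad = p @ (Dx ** 2) @ p + q @ (Dy ** 2) @ q - 2.0 * np.sum(P * (Dx @ P @ Dy.T))
    return (1.0 - alpha) * linear + alpha * quad

# Hypothetical stand-in data, not the embryo data set.
rng = np.random.default_rng(0)
xy_0, xy_1 = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))    # spatial coordinates
g_0, g_1 = rng.normal(size=(5, 10)), rng.normal(size=(7, 10))    # gene expression
Dx, Dy = sq_dists(xy_0, xy_0), sq_dists(xy_1, xy_1)  # quadratic term: spatial distances
C = sq_dists(g_0, g_1)                                # linear term: expression distances
P = rng.random((5, 7))
P /= P.sum()  # nonnegative matrix with unit mass, standing in for a coupling
val = fgw_objective(P, Dx, Dy, C, alpha=0.5)
print(val)
```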
Evaluation of the mapping: as an initial sanity check we look at the cell-transition table, that is, the transition matrix with entries grouped by cell type (we assess the row-stochastic, forward setting). The naive assumption is that cells of the same type, e.g. brain, will mainly be mapped to themselves. For the regular FGW and FGW (unbalanced), this is indeed what we observe. However, for low-rank we get a matrix with constant columns. We observed a similar phenomenon at different time-points. Comparing the results, we can see a hint about the constant columns: they correspond to cell types that are also favored in the regular regime.
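The cell-transition table can be sketched as follows, assuming the coupling is available as a dense NumPy array and cell types as per-cell label arrays. `cell_transition_table` and `has_constant_columns` are hypothetical helpers (not from ott or our notebook), and the toy couplings are made up for illustration.

```python
import numpy as np

def cell_transition_table(P, types_src, types_tgt):
    """Aggregate a coupling P (n_src x n_tgt) by cell type and row-normalize.

    Entry [a, b] is the fraction of the mass leaving source type a that lands
    on target type b (row-stochastic, forward setting).
    """
    src_labels, tgt_labels = np.unique(types_src), np.unique(types_tgt)
    # One-hot membership matrices: (n_src x A) and (n_tgt x B).
    S = (types_src[:, None] == src_labels[None, :]).astype(float)
    T = (types_tgt[:, None] == tgt_labels[None, :]).astype(float)
    table = S.T @ P @ T  # total mass moved from type a to type b
    return src_labels, tgt_labels, table / table.sum(axis=1, keepdims=True)

def has_constant_columns(table, atol=1e-6):
    """True if every column is numerically constant across rows."""
    return bool(np.all(np.ptp(table, axis=0) < atol))

types_src = np.array(["brain", "brain", "heart", "heart"])
types_tgt = np.array(["brain", "heart", "heart"])
# Hypothetical coupling that mostly keeps cell types together.
P_good = np.array([[0.20, 0.05, 0.00],
                   [0.20, 0.00, 0.05],
                   [0.00, 0.10, 0.15],
                   [0.05, 0.10, 0.10]])
# Hypothetical rank-1 coupling: outer product of the marginals.
P_rank1 = np.outer(np.full(4, 0.25), np.array([0.5, 0.25, 0.25]))

_, _, tab_good = cell_transition_table(P_good, types_src, types_tgt)
_, _, tab_r1 = cell_transition_table(P_rank1, types_src, types_tgt)
print(tab_good, has_constant_columns(tab_good))
print(tab_r1, has_constant_columns(tab_r1))
```

A rank-1 coupling such as an outer product of marginals sends every source type to the same target distribution, which shows up exactly as a table with constant columns.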
@LaetitiaPapaxanthos here you can also observe the performance of unbalanced with $\tau_a = \tau_b$.
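To see what `tau_a`/`tau_b` control, here is a minimal NumPy sketch of the scaling form of unbalanced Sinkhorn (KL-relaxed marginals), assuming the parametrization tau = rho / (rho + eps) that ott documents for its linear problems, where rho is the KL marginal penalty. `unbalanced_sinkhorn` and `rho_from_tau` are hypothetical helpers; this is the plain linear solver, not ott's GW loop or its `gw_unbalanced_correction` term.

```python
import numpy as np

def unbalanced_sinkhorn(a, b, C, eps, tau_a=1.0, tau_b=1.0, n_iter=1000):
    """Scaling (non-log-domain) Sinkhorn for KL-relaxed entropic OT.

    Assumes tau = rho / (rho + eps) as the exponent on each scaling update;
    tau = 1 recovers the balanced problem. Returns P = diag(u) K diag(v).
    """
    K = np.exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]

def rho_from_tau(tau, eps):
    """Invert tau = rho / (rho + eps) for the KL marginal penalty rho (needs tau < 1)."""
    return eps * tau / (1.0 - tau)

# Hypothetical toy problem.
rng = np.random.default_rng(0)
x, y = rng.normal(size=(6, 3)), rng.normal(size=(8, 3))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C /= C.max()  # rescale costs to [0, 1], in the spirit of scale_cost='max_cost'
a, b = np.full(6, 1 / 6), np.full(8, 1 / 8)
eps = 0.05

masses = {}
for taus in [(1.0, 1.0), (0.9, 1.0), (0.9, 0.9)]:
    masses[taus] = unbalanced_sinkhorn(a, b, C, eps, *taus).sum()
    print(taus, masses[taus])  # total transported mass
print(rho_from_tau(0.9, eps))  # KL penalty implied by tau = 0.9 at this eps
```

In this sketch, when only one side is relaxed (`tau_b = 1`), the last v-update forces the column sums to equal `b`, so the total mass is exactly `sum(b)`; only when both sides are relaxed is the mass free to drift.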
As I show, the problem is indeed `gw_unbalanced_correction=True` (with `False` it works).
You can obviously play around with everything there :)
For `tau_a = tau_b = 0.9`, I noticed that the total transported mass is very low (1e-5), whereas if only one side is unbalanced, it's fairly high (0.9). Code to reproduce:
```python
from jax.config import config
config.update("jax_enable_x64", True)

import ott
import jax
import numpy as np
from ott.geometry.pointcloud import PointCloud

np.random.seed(0)
# Random stand-in data: (x, y) for the quadratic term,
# (xx, yy) for the linear (fused) term.
x = np.random.normal(size=(64, 3))
y = np.random.normal(size=(128, 3))
xx = np.random.normal(size=(64, 3))
yy = np.random.normal(size=(128, 3))

o, scale_cost = True, 'max_cost'
geom_x = PointCloud(x, online=o, scale_cost=scale_cost)
geom_y = PointCloud(y, online=o, scale_cost=scale_cost)
geom_xy = PointCloud(xx, yy, online=o, scale_cost=scale_cost)

solver = ott.core.gromov_wasserstein.GromovWasserstein(jit=False, epsilon=1e-2, lse_mode=False)
prob = ott.core.quad_problems.QuadraticProblem(geom_x, geom_y,
                                               geom_xy,
                                               tau_a=0.8, tau_b=0.8,
                                               gw_unbalanced_correction=True)

# Run the first outer GW iteration by hand: linearize, solve, update the state.
iteration = 0
state = solver.init_state(prob, -1)
linear_pb = prob.update_linearization(state.linear_state, solver.epsilon, state.old_transport_mass)
out = solver.linear_ot_solver(linear_pb)
old_transport_mass = jax.lax.stop_gradient(
    state.linear_state.transport_mass()
)
state = state.update(
    iteration, out, linear_pb, solver.store_inner_errors, old_transport_mass
)
print(state.linear_state.marginal(0).sum())  # 1.883714535238546e-05
```
In the next iteration, the solution to the linearized problem contains `inf`s; this also causes the total transport mass to be 0 (and makes the scaling between the old and the new transport mass NaN). This only happens when `gw_unbalanced_correction=True`.
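When reproducing the loop by hand, as in the snippet above, it can help to validate each linearized solution before the mass ratio is formed, so the failure is reported at the iteration where it first appears rather than as a NaN later. A minimal NumPy sketch; `mass_ratio` and `check_coupling` are hypothetical helpers, not part of ott, and ott's internal rescaling need not be a plain ratio of sums.

```python
import numpy as np

def mass_ratio(P_old, P_new):
    """Ratio of old to new total transported mass (inf or NaN once the new mass vanishes)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.float64(P_old.sum()) / np.float64(P_new.sum())

def check_coupling(P, name="coupling"):
    """Hypothetical guard: raise before non-finite or zero-mass couplings propagate."""
    if not np.all(np.isfinite(P)):
        raise FloatingPointError(f"{name} contains non-finite entries")
    if P.sum() <= 0.0:
        raise FloatingPointError(f"{name} has zero total mass")
    return P

# Toy couplings illustrating the three cases.
P_ok = np.full((2, 3), 1 / 6)
P_inf = P_ok.copy()
P_inf[0, 0] = np.inf
P_zero = np.zeros((2, 3))

for name, P in [("ok", P_ok), ("inf", P_inf), ("zero", P_zero)]:
    try:
        check_coupling(P, name)
        print(name, "passes")
    except FloatingPointError as err:
        print(err)
```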
maybe we can close now?
completed via #128