MMvec refactor
We're going to go with PyTorch OR NumPyro. The framework will have the following skeleton:
model.py (mmvec.py)
```python
import torch
import torch.nn as nn
from torch.distributions import Multinomial


class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            nn.Softmax(dim=-1))
        # TODO : may want to have a better softmax

    def forward(self, X, Y):
        """X holds microbe indices (B,), i.e. the index form of the one-hot
        encodings over num_microbes. Y holds metabolite abundances
        (B x num_metabolites). B is the batch size."""
        z = self.encoder(X)
        pred_y = self.decoder(z)
        # total_count is not needed for log_prob, so validation is turned off
        lp = Multinomial(probs=pred_y, validate_args=False).log_prob(Y).mean()
        return lp
```
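For reference, a quick smoke test of the module above could look like the following (the dimensions and random data are arbitrary placeholders):

```python
model = MMvec(num_microbes=100, num_metabolites=50, latent_dim=3)
X = torch.randint(0, 100, (32,))              # a batch of 32 microbe indices
Y = torch.randint(0, 10, (32, 50)).float()    # a batch of 32 metabolite count vectors
print(model(X, Y))                            # mean multinomial log-likelihood
```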
train.py
(could use PyTorch Lightning)
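A bare-bones plain-PyTorch version might look like the sketch below; the `fit` name, `train_loader` (yielding `(X, Y)` batches), the epoch count, and the learning rate are all placeholders, and PyTorch Lightning could replace most of this boilerplate.

```python
import torch

def fit(model, train_loader, num_epochs=100, learning_rate=1e-3):
    """Minimize the negative multinomial log-likelihood returned by MMvec.forward."""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for X, Y in train_loader:
            optimizer.zero_grad()
            loss = -model(X, Y)   # forward() returns the mean log-likelihood
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f'epoch {epoch}: loss = {epoch_loss / len(train_loader):.4f}')
    return model
```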
The wishlist
- Early stopping (see video for example; a minimal sketch follows this list)
- ArviZ for diagnostics
- Typing would be great. See torchtyping
- Torch tests could be cool also. See torchtest
- Being Bayesian would be nice. SWAG is the laziest approach
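For the early-stopping item, a minimal patience-based check on a held-out validation loss could look like this sketch (the `val_loader` and optimizer settings are assumptions; PyTorch Lightning's `EarlyStopping` callback would be the less manual route):

```python
import torch

def fit_with_early_stopping(model, train_loader, val_loader,
                            num_epochs=1000, patience=10, learning_rate=1e-3):
    """Stop once the validation loss has not improved for `patience` epochs."""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val, stalled = float('inf'), 0
    for epoch in range(num_epochs):
        model.train()
        for X, Y in train_loader:
            optimizer.zero_grad()
            loss = -model(X, Y)
            loss.backward()
            optimizer.step()

        # Evaluate the held-out loss once per epoch.
        model.eval()
        with torch.no_grad():
            val_loss = sum(-model(X, Y).item() for X, Y in val_loader) / len(val_loader)

        if val_loss < best_val:
            best_val, stalled = val_loss, 0
        else:
            stalled += 1
            if stalled >= patience:
                print(f'early stopping at epoch {epoch}')
                break
    return model
```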
First pass:
https://github.com/Keegan-Evans/mmvec/blob/pytorch-refactor/examples/refactor/041222pytorchdraft.ipynb
Hi @Keegan-Evans, this is a great first pass. The basic architecture is there, and it looks like the gradient descent is working.
There are a couple of things that we'll want to try out
- Getting the unittests to pass at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L18
- Double checking the soils experiment at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L76
We may want to revisit the decoder architecture -- the softmax identifiability may end up biting us when interpreting the decoder parameters (probably also why you see word2vec applications interpreting only the encoder parameters, rather than the decoder parameters). In MMvec, we used the inverse ALR transform, which would look something like
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearALR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(output_dim - 1, input_dim))
        self.b = nn.Parameter(torch.randn(output_dim - 1))

    def forward(self, x):
        y = F.linear(x, self.W, self.b)               # (B, output_dim - 1) ALR coordinates
        z = torch.zeros((y.shape[0], 1), device=y.device)
        y = torch.cat((z, y), dim=1)                  # prepend the zero reference column
        return F.softmax(y, dim=1)                    # inverse ALR -> probabilities
```
ALR does have some issues (the factorization isn't going to be super accurate). We had to do some redundant computation in the original MMvec code to do an SVD afterwards, but it's most definitely going to lose some information (i.e. if we have k principal components, we'll recover fewer than k PCs due to the ALR). Using the ILR transform can help with this (see our preprint here). That code would look something like the following
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gneiss.cluster import random_linkage
from gneiss.balances import sparse_balance_basis


class LinearILR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        tree = random_linkage(output_dim)  # pick a random tree, it doesn't really matter tbh
        basis = sparse_balance_basis(tree)[0].copy()
        indices = np.vstack((basis.row, basis.col))
        Psi = torch.sparse_coo_tensor(
            indices.copy(), basis.data.astype(np.float32).copy(),
            requires_grad=False).coalesce()
        self.linear = nn.Linear(input_dim, output_dim - 1)
        self.register_buffer('Psi', Psi)

    def forward(self, x):
        y = self.linear(x)
        logy = (self.Psi.t() @ y.t()).t()  # map D - 1 ILR coordinates back to D log-probabilities
        return F.softmax(logy, dim=1)
```
We may want to have some small benchmarks alongside the unit tests for all three of these approaches; based on what I've seen since the MMvec paper, there are going to be differences, and my hunch tells me that ILR is going to be more convenient. Of course, we can talk offline about this.
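To make that concrete, a rough sketch of such a comparison might look like the following; the toy dimensions, the synthetic data, the `fit` loop sketched earlier, and swapping the `decoder` attribute on `MMvec` are all assumptions for illustration, not the final benchmark design.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy dimensions and synthetic data, purely for illustration.
num_microbes, num_metabolites, latent_dim = 100, 50, 3
X = torch.randint(0, num_microbes, (512,))
Y = torch.randint(0, 10, (512, num_metabolites)).float()
train_loader = DataLoader(TensorDataset(X[:400], Y[:400]), batch_size=50)
X_test, Y_test = X[400:], Y[400:]

decoders = {
    'softmax': nn.Sequential(nn.Linear(latent_dim, num_metabolites), nn.Softmax(dim=-1)),
    'alr': LinearALR(latent_dim, num_metabolites),
    'ilr': LinearILR(latent_dim, num_metabolites),
}

results = {}
for name, decoder in decoders.items():
    model = MMvec(num_microbes, num_metabolites, latent_dim)
    model.decoder = decoder                    # swap in the candidate decoder
    fit(model, train_loader, num_epochs=50)    # training-loop sketch from above
    model.eval()
    with torch.no_grad():
        results[name] = model(X_test, Y_test).item()  # mean held-out log-likelihood
print(results)
```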
Here is a statistical significance test for differential abundance
https://github.com/flatironinstitute/q2-matchmaker/blob/main/q2_matchmaker/_stats.py#L46
@mortonjt,
In your original model, when constructing the ranks, you stack zeros onto the product of U and V here. Is that necessary, or just expedient for how you performed the ALR?
Hi @Keegan-Evans, this is necessary if the V matrix is in ALR coordinates. If we're using ILR, then it would be U @ V @ Psi.T (if Psi is D x (D - 1)).
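Concretely, and assuming U is the N x k microbe factor matrix, V the k x (D - 1) metabolite factor matrix, and Psi the ILR basis, the two constructions would look roughly like this sketch with stand-in values:

```python
import torch

N, k, D = 100, 3, 50          # toy sizes: microbes, latent dimension, metabolites
U = torch.randn(N, k)         # microbe factors
V = torch.randn(k, D - 1)     # metabolite factors in ALR (or ILR) coordinates
Psi = torch.randn(D, D - 1)   # stand-in for the ILR basis

# ALR: prepend a zero column for the reference metabolite.
ranks_alr = torch.cat((torch.zeros(N, 1), U @ V), dim=1)

# ILR: project the D - 1 coordinates back through the basis.
ranks_ilr = U @ V @ Psi.T
```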
@mortonjt, I was working on it this morning and realized that might be the case. Thanks for the reply!