
MMvec refactor

Open mortonjt opened this issue 2 years ago • 6 comments

We're going to go with either PyTorch or numpyro. The framework will have the following skeleton.

model.py (mmvec.py)

import torch
import torch.nn as nn
from torch.distributions import Multinomial

class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            nn.Softmax(dim=-1))
        # TODO : may want to have a better softmax

    def forward(self, X, Y):
        """ X is a batch of microbe indices (B,) for the embedding lookup
        (for one-hot inputs of shape B x num_microbes, use X @ self.encoder.weight instead).
        Y is metabolite abundances (B x num_metabolites). B is the batch size. """
        z = self.encoder(X)
        pred_y = self.decoder(z)
        # per the torch docs, total_count need not be specified when only log_prob is used
        lp = Multinomial(probs=pred_y).log_prob(Y).mean()
        return lp
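
A minimal usage sketch, just to exercise the forward pass (the dimensions and tensors below are made up for illustration and are not part of the spec):

model = MMvec(num_microbes=100, num_metabolites=50, latent_dim=3)
X = torch.randint(0, 100, (32,))  # a batch of microbe indices
# one observed metabolite per read, so the counts match the default total_count of 1
Y = torch.nn.functional.one_hot(torch.randint(0, 50, (32,)), num_classes=50).float()
loss = -model(X, Y)               # negative mean log-likelihood
loss.backward()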

train.py (could use PyTorch Lightning)

The wishlist

  • Early stopping (see video for example; a rough Lightning sketch follows this list)
  • ArviZ for diagnostics
  • Typing would be great. See torchtyping
  • Torchtests could be cool also. See torchtest
  • Being Bayesian would be nice. SWAG is the laziest approach
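
As a rough sketch of what train.py could look like with PyTorch Lightning and early stopping (this assumes the MMvec module above; MMvecLightning, train_loader, and val_loader are placeholder names, and the hyperparameters are arbitrary):

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

class MMvecLightning(pl.LightningModule):
    def __init__(self, num_microbes, num_metabolites, latent_dim, lr=1e-3):
        super().__init__()
        self.model = MMvec(num_microbes, num_metabolites, latent_dim)
        self.lr = lr

    def training_step(self, batch, batch_idx):
        X, Y = batch
        loss = -self.model(X, Y)   # minimize the negative log-likelihood
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        X, Y = batch
        self.log('val_loss', -self.model(X, Y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# stop training once the validation loss stops improving
trainer = pl.Trainer(max_epochs=1000,
                     callbacks=[EarlyStopping(monitor='val_loss', patience=10)])
# trainer.fit(MMvecLightning(...), train_loader, val_loader)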

mortonjt avatar Mar 22 '22 20:03 mortonjt

First pass:

https://github.com/Keegan-Evans/mmvec/blob/pytorch-refactor/examples/refactor/041222pytorchdraft.ipynb

Keegan-Evans avatar Apr 12 '22 22:04 Keegan-Evans

Hi @Keegan-Evans, this is a great first pass. The basic architecture is there, and it looks like the gradient descent is working.

There are a couple of things that we'll want to try out

  1. Getting the unittests to pass at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L18
  2. Double-checking the soils experiment at https://github.com/biocore/mmvec/blob/master/mmvec/tests/test_multimodal.py#L76

We may want to revisit the decoder architecture -- the softmax identifiability may end up biting us when interpreting the decoder parameters (which is probably also why you see word2vec applications interpreting only the encoder parameters rather than the decoder parameters). In MMvec, we used the inverse ALR transform, which would look something like

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearALR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(output_dim - 1, input_dim))
        self.b = nn.Parameter(torch.randn(output_dim - 1))

    def forward(self, x):
        y = x @ self.W.t() + self.b        # (B, output_dim - 1) ALR coordinates
        z = torch.zeros((y.shape[0], 1))   # zero for the reference component
        y = torch.cat((z, y), dim=1)       # (B, output_dim) logits
        return F.softmax(y, dim=1)

ALR does have some issues (the factorization isn't going to be super accurate). We had to do some redundant computation in the original MMvec code to do SVD afterwards -- but it's most definitely going to lose some information (i.e. if we have k principal components, we'll recover fewer than k PCs due to the ALR). Using the ILR transform can help with this (see our preprint here). That code would look something like the following

from gneiss.cluster import random_linkage
from gneiss.balances import sparse_balance_basis

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearILR(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        tree = random_linkage(output_dim)  # pick a random tree, it doesn't really matter tbh
        # sparse_balance_basis returns (basis, nodes); the basis is a sparse COO matrix
        basis = sparse_balance_basis(tree)[0]
        indices = np.vstack((basis.row, basis.col))
        Psi = torch.sparse_coo_tensor(
            indices.copy(), basis.data.astype(np.float32).copy(),
            requires_grad=False).coalesce()
        self.linear = nn.Linear(input_dim, output_dim - 1)
        self.register_buffer('Psi', Psi)

    def forward(self, x):
        y = self.linear(x)
        logy = (self.Psi.t() @ y.t()).t()
        return F.softmax(logy, dim=1)

We may want to run some small benchmarks with the unit tests across all three of these approaches (softmax, ALR, and ILR) -- based on what I've seen since the MMvec paper, there are going to be differences, and my hunch is that ILR is going to be more convenient. Of course, we can talk offline about this.
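
A hypothetical sketch of how the three decoders could be swapped in for such a benchmark (make_decoder is a made-up helper; the softmax branch reuses the Linear + Softmax decoder from the first draft):

import torch.nn as nn

def make_decoder(kind, latent_dim, num_metabolites):
    if kind == 'softmax':
        return nn.Sequential(nn.Linear(latent_dim, num_metabolites),
                             nn.Softmax(dim=-1))
    elif kind == 'alr':
        return LinearALR(latent_dim, num_metabolites)
    elif kind == 'ilr':
        return LinearILR(latent_dim, num_metabolites)
    raise ValueError(kind)

# e.g. run each variant through the soils unittest
# for kind in ('softmax', 'alr', 'ilr'):
#     model = MMvec(num_microbes, num_metabolites, latent_dim)
#     model.decoder = make_decoder(kind, latent_dim, num_metabolites)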

mortonjt avatar Apr 13 '22 02:04 mortonjt

Here is a statistical significance test for differential abundance

https://github.com/flatironinstitute/q2-matchmaker/blob/main/q2_matchmaker/_stats.py#L46

mortonjt avatar Apr 19 '22 18:04 mortonjt

@mortonjt,

In your original model, when constructing the ranks, you stack zeros onto the product of U and V here. Is that necessary, or just expedient for how you performed the ALR?

Keegan-Evans avatar Apr 21 '22 20:04 Keegan-Evans

Hi @Keegan-Evans, this is necessary if the V matrix is in ALR coordinates. If we're using ILR, then it would be U @ V @ Psi.T (if Psi is D x (D - 1)).
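
In other words, something like the following (a sketch only; alr_ranks and ilr_ranks are hypothetical helper names, with U as the N x k factors and V as the k x (D - 1) loadings):

import numpy as np

def alr_ranks(U, V):
    # V is in ALR coordinates: prepend a zero column for the reference metabolite
    logits = U @ V                                              # N x (D - 1)
    return np.hstack([np.zeros((logits.shape[0], 1)), logits])  # N x D

def ilr_ranks(U, V, Psi):
    # Psi is the D x (D - 1) ILR basis, mapping back to all D metabolites
    return U @ V @ Psi.T                                        # N x D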

mortonjt avatar Apr 22 '22 14:04 mortonjt

@mortonjt, I was working on it this morning and realized that might be the case, thanks for the reply!

Keegan-Evans avatar Apr 22 '22 15:04 Keegan-Evans