
Inconsistent results between batched (multi-MSA) and single-MSA inputs

Open • huangtinglin opened this issue on Jul 04 '22 • 2 comments

Bug description

I am running the pretrained MSA Transformer (esm_msa1b_t12_100M_UR50S) on several MSAs of different depths (numbers of sequences) and lengths to generate representations. Following the example in the README, I use batch_converter to process the MSAs into a single padded token tensor. However, the representations the transformer produces for a batch of MSAs don't match the ones I get when each MSA is fed into the model on its own.

Reproduction steps

Here is a minimal example:

import esm
import torch

# Load the pretrained MSA Transformer and its alphabet.
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disable dropout

# MSA with 3 sequences of length 30.
MSAs1 = [
    ("protein1 test", "MKTVRQERLKSIVRILERSKEPVSGAQLAE"),
    ("protein2 test", "KALTARQQEVFDLIRDHISQTGMPPTRAEI"),
    ("protein3 test", "KALTARQQEVFDLIRDBISQTGMPPTRAEI"),
]
# MSA with 2 sequences of length 28.
MSAs2 = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAAA"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPCDC"),
]

# Group 1 batches both MSAs, so MSAs2 is padded to 3 rows and 31 columns
# (30 residues + BOS token); group 2 contains MSAs2 alone, with no padding.
MSAs_group1 = [MSAs1, MSAs2]
MSAs_group2 = [MSAs2]
_, _, batch_tokens1 = batch_converter(MSAs_group1)
_, _, batch_tokens2 = batch_converter(MSAs_group2)

with torch.no_grad():
    all_info = model(batch_tokens1, repr_layers=[12], need_head_weights=True)
    all_info1 = model(batch_tokens2, repr_layers=[12], need_head_weights=True)

repres1 = all_info["representations"][12]  # [2, 3, 31, 768]
repres2 = all_info1["representations"][12]  # [1, 2, 29, 768]

# Pull MSAs2 out of the padded batch, dropping the padding row and columns.
repres2_shape = repres2.shape
repres1 = repres1[1:, :repres2_shape[1], :repres2_shape[2]]  # [1, 2, 29, 768]
print("the difference between representations of MSAs2 generated with MSAs_group1 and MSAs_group2: ", (repres1 - repres2).sum())

Expected behavior

The representation of a given MSA should be identical whether it is fed into the MSA Transformer on its own or batched together with other MSAs.

Logs

the difference between representations of MSAs2 generated with MSAs_group1 and MSAs_group2:  tensor(3.8772)


huangtinglin, Jul 04 '22 15:07

@huangtinglin, while testing some code I've been writing recently with esm1_t6, I've noticed that different batch sizes can sometimes give slightly different results. I'm extracting embeddings on the CPU, and I see this behavior whether I run the extract.py script from ESM or my own code that wraps the model. My best guess right now is that it has something to do with which algorithms PyTorch selects based on the workload (see here). What are some summary statistics for your differences other than the sum (mean, median, min, max)? If they're small, I'd be curious whether your observations are due to floating-point error from different algorithms being used for the different workloads. The differences I typically observe are around 1e-5 or smaller. MSA1b is much larger than esm1_t6, though, so I wouldn't be surprised if slightly larger differences could occur with MSA1b, since it performs more operations to compute the representations.

brucejwittmann, Dec 01 '22 01:12
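One way to answer the question about summary statistics is to look at the absolute element-wise difference instead of its signed sum, since a signed sum lets positive and negative deviations cancel (or add up from many tiny ones). The sketch below assumes both tensors have already been sliced to the same shape, as `repres1` and `repres2` are at the end of the reproduction script; `diff_stats` is a hypothetical helper name, not part of ESM.

```python
import torch


def diff_stats(a: torch.Tensor, b: torch.Tensor) -> dict:
    """Summary statistics of the element-wise absolute difference |a - b|.

    Hypothetical helper (not part of ESM); a and b must have the same shape.
    """
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    d = (a - b).abs().flatten()
    return {
        "mean": d.mean().item(),
        # torch.median returns the lower of the two middle values for even counts.
        "median": d.median().item(),
        "min": d.min().item(),
        "max": d.max().item(),
    }


# A signed sum can hide a real discrepancy: here it is exactly zero.
a = torch.tensor([1.0, -1.0])
b = torch.zeros(2)
print((a - b).sum().item())  # 0.0
print(diff_stats(a, b))      # {'mean': 1.0, 'median': 1.0, 'min': 1.0, 'max': 1.0}
```

With the reproduction script, `diff_stats(repres1, repres2)` would show whether the reported sum of 3.8772 comes from uniformly tiny deviations (on the order of the 1e-5 mentioned above) or from a few large ones. Once a tolerance is chosen, `torch.allclose(repres1, repres2, atol=...)` gives a compact pass/fail check.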

Thanks, @brucejwittmann. Actually, I found that the discrepancy is caused by the scaling factor, which depends on the number of rows. I have opened a new issue about this: https://github.com/facebookresearch/esm/issues/491.

huangtinglin, Feb 27 '23 00:02
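The reply above attributes the discrepancy to a scaling factor that depends on the number of rows. Judging from ESM's `RowSelfAttention` (in `esm/axial_attention.py`), the tied row attention appears to divide its logits by the square root of the number of rows in the input tensor, and that count includes the padding rows `batch_converter` adds when MSAs of different depths share a batch. The sketch below is a simplified, single-head illustration of that effect under this assumption, not ESM's actual code: `tied_row_logits`, its `[B, R, C, D]` layout, and the `count_padded_rows` switch are all hypothetical. Padded query rows are zeroed so they add nothing to the sum over rows, yet counting them in the denominator still rescales the logits by sqrt(2/3) for a 2-row MSA padded to 3 rows, as `MSAs2` is in the reproduction.

```python
import math

import torch


def tied_row_logits(q, k, row_mask, count_padded_rows=True):
    """Simplified, single-head tied row-attention logits (hypothetical sketch).

    q, k:     [B, R, C, D] queries and keys for B MSAs padded to R rows, C columns
    row_mask: [B, R] bool, True for real rows and False for padding rows
    Returns:  [B, C, C] logits shared ("tied") across all rows of each MSA
    """
    batch, rows, _, dim = q.shape
    # Zero the padded query rows so they contribute nothing to the row sum.
    q = q * row_mask[:, :, None, None].to(q.dtype)
    # Sum the per-row dot products over rows (the "tied" part).
    logits = torch.einsum("brid,brjd->bij", q, k)
    if count_padded_rows:
        # One scale for the whole batch, based on the padded row count R.
        num_rows = torch.full((batch,), float(rows), dtype=q.dtype)
    else:
        # Per-MSA scale, based on each MSA's own number of real rows.
        num_rows = row_mask.sum(dim=1).to(q.dtype)
    scaling = dim ** -0.5 / num_rows.sqrt()
    return logits * scaling[:, None, None]


torch.manual_seed(0)
C, D = 4, 8
qa, ka = torch.randn(3, C, D), torch.randn(3, C, D)  # MSA with 3 rows
qb, kb = torch.randn(2, C, D), torch.randn(2, C, D)  # MSA with 2 rows
pad = torch.randn(1, C, D)  # arbitrary values in the padding row
q = torch.stack([qa, torch.cat([qb, pad])])  # [2, 3, C, D]
k = torch.stack([ka, torch.cat([kb, pad])])
row_mask = torch.tensor([[True, True, True], [True, True, False]])

alone = tied_row_logits(qb[None], kb[None], torch.ones(1, 2, dtype=torch.bool))[0]
padded = tied_row_logits(q, k, row_mask, count_padded_rows=True)[1]
per_msa = tied_row_logits(q, k, row_mask, count_padded_rows=False)[1]

print(torch.allclose(padded, alone, atol=1e-6))                     # False
print(torch.allclose(padded, alone * math.sqrt(2 / 3), atol=1e-6))  # True
print(torch.allclose(per_msa, alone, atol=1e-6))                    # True
```

For simplicity the sketch keeps the column count equal across MSAs; in the reproduction, column padding also differs, but padded columns are normally handled by a key padding mask rather than by the row-count scale. Scaling each MSA by its own count of real rows (the `count_padded_rows=False` branch) removes the mismatch in this toy setting. Under the same assumption, a practical workaround without modifying the model is to batch only MSAs of equal depth, or to run each MSA on its own.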