
What is the per-residue representation in ESM-C?

Open · peiyaoli opened this issue 1 year ago

Hello,

In the old ESM2, we used this code to get per-residue embeddings from layer 33:


import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (two sequences from ESMStructuralSplitDataset, plus masked variants)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

In ESM-C, is the embedding output this same per-residue representation?

peiyaoli avatar Dec 28 '24 07:12 peiyaoli

My group wrote a simple wrapper for ESMC if you'd like to interface with it like ESM2 Hugging Face models: https://huggingface.co/Synthyra/ESMplusplus_small

lhallee avatar Jan 02 '25 17:01 lhallee
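
For orientation, a minimal usage sketch of such a Hub wrapper, assuming it follows the pattern on the linked model card (loading with trust_remote_code and exposing its tokenizer as a model attribute); the exact attribute and output field names should be checked against the card:

import torch
from transformers import AutoModelForMaskedLM

# Load the custom ESM++ wrapper from the Hub (trust_remote_code pulls in its modeling code).
model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
tokenizer = model.tokenizer  # assumption: the wrapper bundles its tokenizer as an attribute
model.eval()

sequences = ["MKTVRQERLKSIVRILERSK", "KALTARQQEVFDLIRDHISQ"]
batch = tokenizer(sequences, padding=True, return_tensors="pt")

with torch.no_grad():
    output = model(**batch)

# Per-residue outputs; exact fields (logits, hidden states) are wrapper-specific, see the model card.
print(output.logits.shape)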

Embeddings → Transformer blocks → (x) [returned as embeddings] → LayerNorm → sequence_head → logits

Yes, the returned embedding is this representation (the x in the flow above). I don't remember which layer is returned in ESM2. We do return all hidden states, so you can use hidden_states[i] to pick the layer you want.

ebetica avatar Sep 19 '25 20:09 ebetica
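
For completeness, a minimal sketch of pulling those per-residue embeddings out of ESM-C with the esm SDK, following the repo quickstart; the return_hidden_states flag and the hidden_states field are assumptions about the current SDK version and may not exist in older releases:

from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

# Load ESM-C 300M; use .to("cuda") for GPU.
client = ESMC.from_pretrained("esmc_300m").to("cpu")

protein = ESMProtein(sequence="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")
protein_tensor = client.encode(protein)

output = client.logits(
    protein_tensor,
    LogitsConfig(
        sequence=True,
        return_embeddings=True,     # per-residue embeddings: the x in the flow above
        return_hidden_states=True,  # assumed flag: also return per-layer hidden states
    ),
)

# Per-residue representation, shape [1, sequence length + special tokens, hidden dim]
print(output.embeddings.shape)
# Pick a specific layer, analogous to repr_layers=[33] in ESM-2 (assumed field):
# print(output.hidden_states[-1].shape)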