What is the per-residue representation in ESM-C?
Hello,
In the old ESM-2, we used this code to get per-residue embeddings from layer 33:
```python
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3", "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]
```
In ESM-C, the embedding output is this same per-residue representation, am I right?
My group wrote a simple wrapper for ESM-C if you'd like to interface with it like the ESM-2 Hugging Face models: https://huggingface.co/Synthyra/ESMplusplus_small
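For example, here is a minimal sketch of pulling per-residue embeddings through it. This assumes the wrapper follows the usual Hugging Face conventions (`trust_remote_code=True`, a tokenizer attached to the model, and `hidden_states` on the output); check the model card for the exact attribute names.

```python
import torch
from transformers import AutoModelForMaskedLM

# Load the ESM++ wrapper (assumes trust_remote_code and a bundled tokenizer).
model = AutoModelForMaskedLM.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)
model.eval()  # disables dropout for deterministic results
tokenizer = model.tokenizer  # assumption: the wrapper attaches its tokenizer to the model

sequences = ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"]
batch = tokenizer(sequences, padding=True, return_tensors="pt")

with torch.no_grad():
    # output_hidden_states follows the usual transformers convention;
    # the wrapper may also return all hidden states by default.
    out = model(**batch, output_hidden_states=True)

# Last entry = the per-residue embedding described below; earlier entries = intermediate layers.
per_residue = out.hidden_states[-1]  # (batch, seq_len, hidden)
print(per_residue.shape)
```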
Embeddings → Transformer blocks → (x) [returned as embeddings] → LayerNorm → sequence_head → logits
The returned embedding is this representation, i.e. the output (x) of the final transformer block before the LayerNorm. I don't remember which one is returned in ESM-2. We do return all hidden states, so you can use hidden_states[i] to pick the layer you want.
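If you'd rather go through the official ESM-C client instead of the wrapper, here is a minimal sketch along the lines of the quick start in the evolutionaryscale/esm package; the model name `esmc_300m` and the `LogitsConfig` fields are taken from that quick start, so double-check them against the version you have installed.

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

# Load ESM-C and encode a single sequence.
client = ESMC.from_pretrained("esmc_300m").to("cpu")  # or "cuda"
protein = ESMProtein(sequence="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")
protein_tensor = client.encode(protein)

# return_embeddings=True gives the per-residue representation alongside the logits.
output = client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(output.embeddings.shape)  # (1, sequence length + special tokens, hidden)
```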