
How to enable batch training?

yuliangyan0807 opened this issue 5 months ago · 0 comments

I wrote a script to use the ESMC model as follows:

import torch
import torch.nn as nn
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from typing import List

class ESMCProteinEncoder(nn.Module):
    def __init__(self, model_name="esmc_600m"):
        """
        ESMC Protein Encoder that only encodes protein sequences and applies mean pooling
        :param model_name: The pre-trained model to load from HuggingFace (default is "esmc_600m")
        """
        super().__init__()
        # Load the ESMC model, frozen by default
        self.client = ESMC.from_pretrained(model_name)
        
        # Freeze the parameters to avoid modifying during training
        for param in self.client.parameters():
            param.requires_grad = False

    def forward(self, seq: List[str]):
        """
        Encodes the input protein sequence and applies mean pooling on the embeddings
        :param seq: List of protein sequences
        :return: Mean-pooled protein embeddings from the ESMC model
        """
        # Ensure the input sequence is in the correct format
        protein = ESMProtein(sequence=seq)

        # Get the encoded tensor from the ESMC model
        tensor = self.client.encode(protein)

        # Get the embeddings (logits) from the ESMC model
        logits_output = self.client.logits(tensor, LogitsConfig(sequence=True, return_embeddings=True))
        
        # Extract the embeddings and drop the BOS/EOS tokens
        embeddings = logits_output.embeddings[:, 1:-1, :].to(torch.float32)  # Drop BOS/EOS

        # Apply mean pooling to the embeddings (along the sequence dimension)
        pooled_embeddings = torch.mean(embeddings, dim=1)
        
        return pooled_embeddings



encoder = ESMCProteinEncoder(model_name="esmc_600m")
seq = ["AAAAAA", "GGGGGG", "CCCCCC"]
pooled_embeddings = encoder(seq)
print(pooled_embeddings.shape)

When I pass a list, e.g. ["AAAAAA", "GGGGGG", "CCCCCC"], to the model, the following error occurs:

Traceback (most recent call last):
  File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 60, in <module>
    pooled_embeddings = encoder(seq)
                        ^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 43, in forward
    tensor = self.client.encode(protein)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 180, in encode
    sequence_tokens = self._tokenize([input.sequence])[0]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 105, in _tokenize
    encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/utils/encoding.py", line 53, in tokenize_sequence
    sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
               ^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'replace'

How can I fix this bug? Thanks!
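For reference, a minimal per-sequence workaround sketch (assuming, as the traceback suggests, that ESMC.encode accepts only a single ESMProtein at a time, and using only the calls from the script above) that loops over the list and stacks the pooled embeddings:

    def forward(self, seqs: List[str]):
        pooled = []
        for s in seqs:
            # Encode one sequence at a time
            protein = ESMProtein(sequence=s)
            tensor = self.client.encode(protein)
            logits_output = self.client.logits(
                tensor, LogitsConfig(sequence=True, return_embeddings=True)
            )
            # Drop BOS/EOS and mean-pool over the sequence dimension: (1, L, D) -> (1, D)
            embeddings = logits_output.embeddings[:, 1:-1, :].to(torch.float32)
            pooled.append(torch.mean(embeddings, dim=1))
        # Mean pooling removes the length dimension, so sequences of different
        # lengths can be stacked into a single (batch, hidden_dim) tensor
        return torch.cat(pooled, dim=0)

This sidesteps padding entirely because pooling collapses the length dimension before stacking, though it processes sequences one by one rather than as a true batched forward pass.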

yuliangyan0807 · Jul 30 '25