RAM crashes when generating representations with a pretrained ESM model on a large dataset
Hello, I have a very large dataset of protein sequences (10,000 sequences) whose representations I want to generate using a pretrained ESM model. The sequences are also quite long. I have used the code below:
```
pip install fair-esm
```

```python
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

# data is a list of (label, sequence) tuples
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]
```
The code above generates per-residue representations on the CPU. Is there any way to run the pretrained ESM model on CUDA? My RAM crashes when I run it on the CPU; I only have 12 GB of RAM, and I want to generate representations for 10,000 proteins. Is there any way to tackle this issue?
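One way this could be handled is to load the model once, move it to the GPU, and feed the sequences in small mini-batches, keeping the resulting representations on the CPU. A rough sketch (the batch size of 4 and the placeholder `data` list are illustrative, and `return_contacts` is turned off to save memory):

```python
import torch
import esm

# Placeholder; replace with the full list of (label, sequence) pairs for all 10,000 proteins
data = [
    ("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQ"),
]

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model = model.eval().cuda()  # move the model to the GPU

batch_size = 4  # illustrative; tune so one mini-batch fits in GPU memory
per_batch_reps = []
for i in range(0, len(data), batch_size):
    _, _, tokens = batch_converter(data[i:i + batch_size])
    with torch.no_grad():
        results = model(tokens.cuda(), repr_layers=[33], return_contacts=False)
    # Move each batch back to the CPU right away so GPU memory is freed;
    # batches are kept in a list because padded lengths differ between batches.
    per_batch_reps.append(results["representations"][33].cpu())
```

If only a per-protein vector is needed, averaging each sequence's residue representations inside the loop keeps the memory footprint much smaller than storing full per-residue tensors.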
Hello, related to the question above: is there a way to change the batch size of the model to generate the embeddings using the snippet of code above or the one in the README file? Thanks
```python
import torch
from transformers import AutoTokenizer, EsmModel


def get_ESM2_embedding(model_path, seqs, pooling):
    '''
    :param model_path: path to an ESM-2 checkpoint in HuggingFace format
    :param seqs: list of protein sequence strings
    :param pooling: 'cls', 'pooler_output', 'mean', or 'max'
    :return: per-protein embedding tensor
    '''
    # seqs = ["QERLKSIVRILE"]  # length 12
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = EsmModel.from_pretrained(model_path)
    inputs = tokenizer(seqs, return_tensors="pt", padding=True, truncation=True)
    # truncation to max model input length
    attention_mask = inputs['attention_mask']
    with torch.no_grad():
        outputs = model(**inputs)
    # print(outputs.keys())  # (['last_hidden_state', 'pooler_output'])
    print(outputs.last_hidden_state.shape)  # torch.Size([2, 44, 480])

    # pooling = 'mean'
    output = ''
    if pooling == 'cls':
        # the CLS vector represents the whole protein: a 1D vector of length 480
        output = outputs.last_hidden_state[:, 0, :]
        print("cls::", output.shape)  # torch.Size([1, 480])
    elif pooling == "pooler_output":
        output = outputs.pooler_output
        print("pooler_output::", output.shape)  # torch.Size([1, 480])
    elif pooling == 'mean':
        token_embeddings = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).float()  # cast the attention mask to float
        valid_token_count = attention_mask.sum(1).clamp(min=1e-9).unsqueeze(1)  # number of valid tokens
        output = torch.sum(token_embeddings * input_mask_expanded, dim=1) / valid_token_count
        print("mean_output::", output.shape)  # torch.Size([1, 480]) torch.Size([2, 480])
    elif pooling == 'max':
        token_embeddings = outputs.last_hidden_state
        masked_embeddings = token_embeddings + (1 - attention_mask.unsqueeze(-1).float()) * (-1e9)
        output, _ = torch.max(masked_embeddings, 1)
        print("max_output::", output.shape)  # torch.Size([1, 480]) torch.Size([2, 480])
    return output


def main():
    model_path = "./ESM_huggingFace/ESM2weight/esm2_t12_35M_UR50D"
    seqs = ["QERLKSIVRILE"]
    pooling = 'mean'  # pooler_output, cls, mean, max; other methods to develop
    output = get_ESM2_embedding(model_path, seqs, pooling)
    print(output.shape)


if __name__ == "__main__":
    main()
```
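To apply this to the original 10,000-sequence dataset, the tokenizer and model could be loaded once, moved to the GPU, and the sequences embedded in mini-batches, with each result moved back to the CPU. A rough sketch along the same lines (the `embed_in_batches` helper, the batch size, and the device are illustrative):

```python
import torch
from transformers import AutoTokenizer, EsmModel


def embed_in_batches(model_path, seqs, batch_size=8, device="cuda"):
    """Mean-pooled ESM-2 embeddings for a large list of sequences, one mini-batch at a time."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = EsmModel.from_pretrained(model_path).to(device).eval()
    pooled_embeddings = []
    for i in range(0, len(seqs), batch_size):
        chunk = seqs[i:i + batch_size]
        inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state          # [batch, length, dim]
        mask = inputs["attention_mask"].unsqueeze(-1).float()   # [batch, length, 1]
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pooled_embeddings.append(pooled.cpu())                  # keep results off the GPU
    return torch.cat(pooled_embeddings, dim=0)
```

Because only the mean-pooled vectors are kept, the output for 10,000 sequences is a single [10000, hidden_dim] tensor, which fits comfortably in 12 GB of RAM.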
Hope it helps!
Yes, thanks a lot!