
RAM crashes when fine-tuning with a pretrained ESM model on a large dataset.

amalislam675 opened this issue 2 years ago • 3 comments

Hello, I have a very large dataset of protein sequences (10,000 sequences) whose representations I want to generate with a pretrained ESM model. The sequences are also quite long. I have used the code below:

```
pip install fair-esm
```

```python
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

# data: list of (label, sequence) pairs, as in the README example
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]
```

The code above generates per-residue representations on the CPU. Is there any way to run the pretrained ESM model on CUDA? My RAM crashes when I run it on the CPU; I only have 12 GB of RAM, and I want to generate representations for 10,000 proteins. Is there any way to tackle this issue?

amalislam675 · May 10 '23 03:05
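To run on GPU and keep memory bounded, the usual approach is to move the model to CUDA once and feed the sequences in small chunks, keeping only a pooled embedding per protein on the CPU. Below is a minimal sketch, assuming the 10,000 sequences are available as `data`, a list of `(label, sequence)` pairs; the chunk size of 8 and the 1022-residue truncation are arbitrary choices, not part of the original snippet:

```python
import torch
import esm

# Load ESM-2 once and move it to the GPU if one is available
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

chunk_size = 8   # sequences per forward pass; tune to fit GPU memory
max_len = 1022   # truncate very long sequences to bound memory

mean_embeddings = []
for start in range(0, len(data), chunk_size):
    chunk = [(label, seq[:max_len]) for label, seq in data[start:start + chunk_size]]
    batch_labels, batch_strs, batch_tokens = batch_converter(chunk)
    batch_tokens = batch_tokens.to(device)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    with torch.no_grad():
        # return_contacts=False avoids keeping all attention maps, which saves a lot of memory
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]

    # keep only a fixed-size mean embedding per sequence, moved back to CPU
    for i, tokens_len in enumerate(batch_lens):
        mean_embeddings.append(token_representations[i, 1:tokens_len - 1].mean(0).cpu())
```

Storing only a fixed-size embedding per protein (rather than the full per-residue tensor) keeps host RAM well within 12 GB; if per-residue representations are needed, saving each chunk to disk with `torch.save` avoids holding everything in memory at once.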

Hello, related to the question above: is there a way to change the batch size of the model when generating the embeddings with the snippet of code above or the one in the README file? Thanks

Amels404 · Sep 15 '23 11:09

> Hello, related to the question above: is there a way to change the batch size of the model when generating the embeddings with the snippet of code above or the one in the README file? Thanks

```python
import torch
from transformers import AutoTokenizer, EsmModel


def get_ESM2_embedding(model_path, seqs, pooling):
    '''
    :param model_path: path to an ESM-2 checkpoint in HuggingFace format
    :param seqs: list of protein sequences
    :param pooling: one of 'cls', 'pooler_output', 'mean', 'max'
    :return: per-sequence embedding tensor
    '''
    # seqs = ["QERLKSIVRILE"]  # length 12
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = EsmModel.from_pretrained(model_path)
    inputs = tokenizer(seqs, return_tensors="pt", padding=True, truncation=True)
    # truncation to max model input length
    attention_mask = inputs['attention_mask']
    with torch.no_grad():
        outputs = model(**inputs)
    # print(outputs.keys())  # (['last_hidden_state', 'pooler_output'])
    print(outputs.last_hidden_state.shape)  # e.g. torch.Size([2, 44, 480])

    # pooling = 'mean'
    output = None
    if pooling == 'cls':
        # CLS token embedding, representing the whole protein (a 480-dim vector)
        output = outputs.last_hidden_state[:, 0, :]
        print("cls::", output.shape)  # torch.Size([1, 480])

    elif pooling == "pooler_output":
        output = outputs.pooler_output
        print("pooler_output::", output.shape)  # torch.Size([1, 480])

    elif pooling == 'mean':
        token_embeddings = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).float()  # cast the attention mask to float
        valid_token_count = input_mask_expanded.sum(1).clamp(min=1e-9)  # number of valid tokens per sequence
        output = torch.sum(token_embeddings * input_mask_expanded, dim=1) / valid_token_count
        print("mean_output::", output.shape)  # torch.Size([1, 480]) / torch.Size([2, 480])

    elif pooling == 'max':
        token_embeddings = outputs.last_hidden_state
        masked_embeddings = token_embeddings + (1 - attention_mask.unsqueeze(-1).float()) * (-1e9)
        output, _ = torch.max(masked_embeddings, 1)
        print("max_output::", output.shape)  # torch.Size([1, 480]) / torch.Size([2, 480])

    return output


def main():
    model_path = "./ESM_huggingFace/ESM2weight/esm2_t12_35M_UR50D"
    seqs = ["QERLKSIVRILE"]
    pooling = 'mean'  # pooler_output, cls, mean, max; other methods to develop
    output = get_ESM2_embedding(model_path, seqs, pooling)
    print(output.shape)


if __name__ == "__main__":
    main()
```
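On the batch-size part of the question: with the function above you can simply pass the sequences in slices. A minimal sketch follows (the name `embed_in_batches` and the batch size of 16 are illustrative; for many batches you would also want to hoist the `from_pretrained` calls out of `get_ESM2_embedding` so the model is loaded only once):

```python
def embed_in_batches(model_path, seqs, pooling='mean', batch_size=16):
    # process `seqs` in slices of `batch_size` and stack the pooled embeddings
    outputs = []
    for start in range(0, len(seqs), batch_size):
        batch = seqs[start:start + batch_size]
        outputs.append(get_ESM2_embedding(model_path, batch, pooling))
    return torch.cat(outputs, dim=0)
```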

Hope it helps.

ZJL0111 · Jan 11 '24 06:01

Yes, thanks a lot!

Amels404 · Jan 22 '24 10:01