keras-nlp
keras-nlp copied to clipboard
Add Esm
Follow-up to https://github.com/keras-team/keras-hub/issues/2177. This implementation achieves a smaller numerical error against the Hugging Face (hf) reference. Verification script:
# Numerics check: run the same token ids through the Hugging Face ESM-2 model
# and the KerasHub ESMBackbone loaded from the same checkpoint, then report
# whether the two outputs agree to within atol=1e-4.
import os

# Both variables are set before keras / transformers are imported, so the
# libraries see them at import time.
os.environ["KERAS_BACKEND"] = "torch"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from keras import ops
# NOTE(review): the next five imports are not referenced in this snippet;
# they are kept because other code may rely on them. `esm2.esm2_layers` looks
# like a local module -- confirm it is importable in your environment.
from transformers.models.esm.modeling_esm import EsmAttention as hf_EsmSelfAttention
from transformers import EsmConfig
from esm2.esm2_layers import EsmSelfAttention
import numpy as np
import keras
from transformers.models.esm.modeling_esm import EsmModel

weights_path = "facebook/esm2_t6_8M_UR50D"

# Reference model from Hugging Face, switched to inference mode.
hf_model = EsmModel.from_pretrained(weights_path)
hf_model.eval()
# Token dropout is switched off on the HF embeddings; presumably so that the
# embedding output is directly comparable with the Keras port -- TODO confirm.
hf_model.embeddings.token_dropout = False

from keras_hub.src.models.esm.esm_backbone import (
    ESMBackbone,
)

# Same checkpoint, converted on the fly through the `hf://` preset scheme.
keras_model = ESMBackbone.from_preset('hf://'+weights_path)
keras_model.summary()

# One sequence with token ids 2..6 (the literal ids 1..5 shifted by one).
x = ops.array([[1,2,3,4,5]])+1

# Place the HF model on whatever device the Keras backend put `x` on instead
# of hard-coding CUDA, so the script also runs on CPU-only machines.
# Assumes `x` is a torch.Tensor, which is what the torch backend selected
# above is expected to return.
hf_model.to(x.device)

# Index [0] takes the first element of the HF model output; the second
# positional argument is an all-ones mask with the same shape as `x`.
hf_out = hf_model(x,ops.ones_like(x))[0]
keras_out = keras_model({'token_ids': x})

# Prints a single boolean: True when every element matches within atol=1e-4.
print(ops.all(ops.isclose(hf_out, keras_out,atol=1e-4)))
ESM Checkpoint Conversion and Numerics Verification Demo (across multiple backends): Notebook Link
Train Demo: Notebook Link