Enable running ESM on Mac silicon using MPS
This PR enables ESM to run on Apple silicon Macs (M1, M2, M3) using PyTorch's Metal Performance Shaders (MPS) backend for GPU acceleration.
PyTorch already supports the MPS backend on macOS: https://pytorch.org/docs/stable/notes/mps.html
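Selecting the device comes down to checking what the local PyTorch build reports. Below is a minimal sketch of that choice. It assumes a hypothetical helper, `pick_device_type`, which is not part of ESM or PyTorch and prefers CUDA, then MPS, then CPU. The helper takes plain booleans, so you can check the ordering without any GPU present:

```python
def pick_device_type(cuda_available: bool, mps_available: bool) -> str:
    """Return a PyTorch device string, preferring CUDA, then MPS, then CPU.

    Hypothetical helper for illustration; it is not part of ESM or PyTorch.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Typical usage with a real PyTorch install:
#   import torch
#   device = torch.device(
#       pick_device_type(torch.cuda.is_available(), torch.backends.mps.is_available())
#   )
print(pick_device_type(cuda_available=False, mps_available=True))  # mps
```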
Note that MPS does not support some embedding operations, so the following environment variable must be set to let PyTorch fall back to the CPU for those operations:
export PYTORCH_ENABLE_MPS_FALLBACK=1
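If exporting the variable in the shell is inconvenient, you can also set it from Python. The sketch below assumes that PyTorch reads the variable once, when it loads its MPS backend. That is my understanding of how the fallback is registered, so the variable has to be set before the first `import torch`. The helper name `enable_mps_fallback` is hypothetical:

```python
import os
import sys


def enable_mps_fallback() -> None:
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1 so unsupported MPS ops run on the CPU.

    Hypothetical helper. Assumes PyTorch reads the variable when it is first
    imported, so it refuses to run once torch is already loaded.
    """
    if "torch" in sys.modules:
        raise RuntimeError(
            "torch is already imported; set PYTORCH_ENABLE_MPS_FALLBACK before importing it"
        )
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


enable_mps_fallback()
# Import PyTorch (and ESM, which imports it) only after this point:
# import torch
```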
There aren't any tests in the repository, so I couldn't add a unit test for this by following an existing pattern. However, I tested it on my MacBook with the following script, which works with the changes in this PR:
import os
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
import torch
# Check if the Hugging Face API token is available in the environment
token = os.getenv("HF_API_TOKEN")
if token:
    # Log in non-interactively with the existing token so model downloads can use it
    login(token=token)
    print("Using existing Hugging Face token.")
else:
    # Prompt the user to log in if no token is found
    login()
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current macOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
# Set the device to MPS (for Mac M1/M2) or CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = "cpu"
print(f"device: {device}")
# Load the ESM 3.0.4 model
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)
# Check if the model is on MPS
model_device = next(model.parameters()).device
print(f"Model is running on device: {model_device}")
# Example protein sequence
# sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAN___"
# sequence = "___________________________________________________DQATSLRILNNGHAFGSLTTPP___________________________________________________________"
sequence = "___DQA___"
# Create an ESMProtein object with the sequence
protein = ESMProtein(sequence=sequence)
# Generate the sequence prediction (optional, if needed)
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# Print out the predicted sequence
predicted_sequence = protein.sequence
print("Predicted Sequence:")
print(predicted_sequence)
# Generate the 3D structure prediction
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
# Save the predicted structure to a PDB file
protein.to_pdb("./predicted_structure.pdb")
# Optionally, perform a round-trip design by inverse folding the sequence and recomputing the structure
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.coordinates = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped_structure.pdb")
print("Structure prediction complete. PDB files saved.")
This prints:
Model is running on device: mps:0
confirming that MPS is being used.
It also prints this warning, showing that the embedding operation fell back to the CPU:
UserWarning: The operator 'aten::_embedding_bag' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.
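The fallback is only reported as a warning, so it can help to collect which operators fell back during a run. The sketch below does this with the standard `warnings` module. The regex is an assumption based on the message text above, and `fallback_ops` is a hypothetical helper. PyTorch may emit this warning only once per process, so wrap the first `model.generate(...)` call. Here a simulated warning stands in for the real one:

```python
import re
import warnings

# Assumed pattern, based on the warning text shown above:
# "The operator 'aten::_embedding_bag' is not currently supported on the MPS backend ..."
_FALLBACK_RE = re.compile(r"The operator '([^']+)' is not currently supported on the MPS backend")


def fallback_ops(caught: list) -> list[str]:
    """Return the unique operator names named in recorded MPS fallback warnings."""
    ops = []
    for w in caught:
        match = _FALLBACK_RE.search(str(w.message))
        if match and match.group(1) not in ops:
            ops.append(match.group(1))
    return ops


# Simulate the warning from the run above; in practice, wrap model.generate(...) instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(
        "The operator 'aten::_embedding_bag' is not currently supported on the MPS "
        "backend and will fall back to run on the CPU. This may have performance implications.",
        UserWarning,
    )

print(fallback_ops(caught))  # ['aten::_embedding_bag']
```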
Ha, this is very cool! Thanks a lot for the PR. Can you revert the .gitignore changes? As a Python package, we don't need to include dev-environment-related things.
Removed .gitignore file
We have to set up a CLA before merging this PR 🙈 Sorry for not merging it for so long, I'll get to it hopefully soon.
Any updates?