
Enable running ESM on Mac silicon using MPS

Open · imranq2 opened this issue 1 year ago · 3 comments

This PR enables ESM to run on Apple silicon Macs (M1, M2, M3) using PyTorch's Metal Performance Shaders (MPS) backend for GPU acceleration.

PyTorch already supports the MPS backend on macOS: https://pytorch.org/docs/stable/notes/mps.html

Note that MPS does not support some embedding operations (for example aten::_embedding_bag, as the warning below shows), so the following environment variable has to be set to let PyTorch fall back to the CPU for those operations:

export PYTORCH_ENABLE_MPS_FALLBACK=1
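If you would rather not depend on the shell, the flag can also be set from inside the script. The sketch below is hypothetical (enable_mps_fallback is not part of ESM or PyTorch) and assumes, as is commonly reported, that PyTorch only honors the flag if it is already in the environment when torch is first imported, so it must run before importing torch or esm.

```python
import os
import sys

FALLBACK_FLAG = "PYTORCH_ENABLE_MPS_FALLBACK"


def enable_mps_fallback() -> bool:
    """Hypothetical helper: let PyTorch run MPS-unsupported ops on the CPU.

    Assumption: PyTorch reads this flag when torch is first imported, so call
    this before any `import torch` (direct, or indirect via esm).
    Returns False if torch was already imported, i.e. the flag may be ignored.
    """
    # setdefault keeps any value the user already exported in the shell
    os.environ.setdefault(FALLBACK_FLAG, "1")
    return "torch" not in sys.modules


# Intended order at the top of the test script:
# enable_mps_fallback()
# import torch
# from esm.models.esm3 import ESM3
```

Using setdefault means an explicit `export PYTORCH_ENABLE_MPS_FALLBACK=...` in the shell still wins over the in-script default.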

The repository doesn't have any tests, so there was no existing pattern to follow for adding a unit test. However, I tested the change on my MacBook with the following script, which works with this PR applied:

import os
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
import torch

# Check if the Hugging Face API token is available in the environment
token = os.getenv("HF_API_TOKEN")

if token:
    # Log in with the existing token; creating an HfApi object alone does not
    # authenticate the model download done by ESM3.from_pretrained
    login(token=token)
    print("Using existing Hugging Face token.")
else:
    # Prompt the user to log in if no token is found
    login()

# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

# Set the device to MPS (for Mac M1/M2) or CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = "cpu"
print(f"device: {device}")

# Load the ESM 3.0.4 model
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

# Check if the model is on MPS
model_device = next(model.parameters()).device
print(f"Model is running on device: {model_device}")

# Example protein sequence
# sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAN___"
# sequence = "___________________________________________________DQATSLRILNNGHAFGSLTTPP___________________________________________________________"
sequence = "___DQA___"

# Create an ESMProtein object with the sequence
protein = ESMProtein(sequence=sequence)

# Generate the sequence prediction (optional, if needed)
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# Print out the predicted sequence
predicted_sequence = protein.sequence
print("Predicted Sequence:")
print(predicted_sequence)

# Generate the 3D structure prediction
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))

# Save the predicted structure to a PDB file
protein.to_pdb("./predicted_structure.pdb")

# Optionally, perform a round-trip design by inverse folding the sequence and recomputing the structure
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.coordinates = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped_structure.pdb")

print("Structure prediction complete. PDB files saved.")

This prints "Model is running on device: mps:0", confirming that MPS is being used.
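The check in the script, next(model.parameters()).device, only reports where the first parameter lives. A slightly stricter check is sketched below; module_devices is a hypothetical helper that collects the device of every parameter and buffer, assuming the model is a torch.nn.Module (which the script's call to model.parameters() implies).

```python
from typing import Set


def module_devices(module) -> Set[str]:
    """Return the set of devices holding any parameter or buffer of `module`.

    Duck-typed: works with any torch.nn.Module, or anything exposing
    .parameters() and .buffers() that yield objects with a .device attribute.
    """
    devices = {str(p.device) for p in module.parameters()}
    devices |= {str(b.device) for b in module.buffers()}
    return devices


# Usage with the script above (requires torch and esm):
# devices = module_devices(model)
# assert devices == {"mps:0"}, f"model is spread over {devices}"
```

A single-element result of {"mps:0"} means no weights or buffers were left behind on the CPU.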

The script also prints the following warning, showing that the embedding operation falls back to the CPU:

UserWarning: The operator 'aten::_embedding_bag' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.
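To list which operators fall back to the CPU during a run, the warnings can be captured with Python's standard warnings module. This is a minimal sketch: ops_falling_back is a hypothetical helper, and it assumes the warning text keeps the format quoted above. PyTorch may emit each fallback warning only once per process, so wrap the first call that exercises the model.

```python
import re
import warnings
from typing import Callable, List

# Extracts the operator name from PyTorch's MPS fallback warning, e.g.
# "The operator 'aten::_embedding_bag' is not currently supported on the MPS backend ..."
_FALLBACK_RE = re.compile(
    r"The operator '([^']+)' is not currently supported on the MPS backend"
)


def ops_falling_back(fn: Callable[[], object]) -> List[str]:
    """Run fn() and return, in first-seen order, the operators reported as falling back."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, including repeats
        fn()
    ops: List[str] = []
    for w in caught:
        match = _FALLBACK_RE.search(str(w.message))
        if match and match.group(1) not in ops:
            ops.append(match.group(1))
    return ops


# Usage with the script above (requires torch and esm):
# ops = ops_falling_back(
#     lambda: model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
# )
# print(ops)
```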

imranq2, Sep 02 '24

Ha, this is very cool! Thanks a lot for the PR. Can you revert the .gitignore changes? As a Python package, we don't need to include dev-environment-related files.

Removed .gitignore file

imranq2, Oct 09 '24

We have to set up a CLA before merging this PR 🙈 Sorry for not merging it for so long, I'll get to it hopefully soon.

ebetica, Oct 11 '24

Any updates?

imranq2, Jan 28 '25