ESM models fail on CPU when flash_attn is installed
When running ESM models on CPU with flash_attn installed, inference fails with CUDA-related errors despite explicitly setting the device to CPU.
Steps to reproduce: on a CPU-only machine with flash-attention installed, run:
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
protein = ESMProtein(sequence="AAAAA")
client = ESMC.from_pretrained("esmc_300m").to("cpu")
protein_tensor = client.encode(protein)
logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
Error
RuntimeError: invalid argument to exchangeDevice
I tried setting client._use_flash_attn = False manually after loading, but inference then fails with a different error related to tensor dimensions, which suggests that some Flash Attention code paths remain active.
Suggested Fix
Add an automatic check that disables Flash Attention when the model runs on CPU, or provide a use_flash_attn parameter in from_pretrained().
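To illustrate the kind of automatic check I have in mind, here is a minimal sketch. The helper names flash_attn_available and should_use_flash_attn are hypothetical and not part of the esm package; the sketch only decides whether Flash Attention should be enabled, and assumes the library would consult such a check when the model is loaded or moved to a device.

```python
import importlib.util
from typing import Optional


def flash_attn_available() -> bool:
    """Return True if the flash_attn package is importable in this environment."""
    return importlib.util.find_spec("flash_attn") is not None


def should_use_flash_attn(device, installed: Optional[bool] = None) -> bool:
    """Hypothetical helper: enable Flash Attention only on CUDA devices.

    `device` may be a string such as "cpu" or "cuda:0", or any object whose
    str() has that form (for example a torch.device).
    `installed` overrides the import check, which is useful for testing.
    """
    if installed is None:
        installed = flash_attn_available()
    # "cuda:0" -> "cuda"; "cpu" -> "cpu"
    device_type = str(device).split(":", 1)[0].lower()
    return installed and device_type == "cuda"


# Flash Attention stays off on CPU even when the package is installed.
print(should_use_flash_attn("cpu", installed=True))
print(should_use_flash_attn("cuda:0", installed=True))
```

With a check like this, having flash_attn installed would no longer be enough on its own to select the Flash Attention code path; the target device would also have to be a CUDA device.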