
new pip install; ESM3 bfloat16 / float32 compatibility error

Open • mzio opened this issue 11 months ago • 1 comment

Hi all; thanks for all the updates!

I recently did a fresh install with pip install esm, but now get a dtype compatibility error during a forward pass. The model weights are (desirably) loaded in bf16, but a manual cast to fp32 somewhere breaks inference.

I was wondering whether the following is reproducible, and how to fix it. Thanks!

In particular, when I do something like:

import torch
from esm.models.esm3 import ESM3
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

# By default this loads the model weights in bf16.
model = ESM3.from_pretrained(
    "esm3_sm_open_v1",
    device=torch.device("cuda"),
)

tokenizer = EsmSequenceTokenizer()
text_sequence = "MSGIRELCRSRRGLLRHRRPGTRGGPQDGGPFRGQDPGRGGCAQL"
input_ids = tokenizer(text_sequence, return_tensors="pt")["input_ids"]

# pLDDT conditioning is passed in bf16, matching the model weights.
model.forward(
    sequence_tokens=input_ids.to(model.device),
    average_plddt=torch.tensor(1.0, dtype=torch.bfloat16, device=model.device),
    per_res_plddt=torch.tensor(0.0, dtype=torch.bfloat16, device=model.device),
)

We get a RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
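The error itself is not ESM-specific: any nn.Linear whose weights are bf16 rejects a float32 input. A minimal CPU-only sketch with a toy layer (not ESM3), assuming a PyTorch build with bf16 CPU matmul support:

```python
import torch
import torch.nn as nn

# Toy stand-in for a projection layer whose weights were loaded in bf16.
proj = nn.Linear(16, 8).to(dtype=torch.bfloat16)

# Float32 features, analogous to an fp32-upcast pLDDT encoding.
features = torch.rand(2, 16, dtype=torch.float32)

try:
    proj(features)
    mismatch_raised = False
except RuntimeError as err:
    # The exact wording varies across PyTorch versions,
    # e.g. "mat1 and mat2 must have the same dtype ...".
    mismatch_raised = True
    print(f"RuntimeError: {err}")

# Casting the input to the layer's weight dtype resolves the mismatch.
out = proj(features.to(dtype=proj.weight.dtype))
print(out.dtype)  # torch.bfloat16
```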

I think this comes from a manual cast to float here: https://github.com/evolutionaryscale/esm/blob/5a1f4573bc910023775cb5ec575e12e4bd04ebb6/esm/models/esm3.py#L332-L333

I imagine this is intentional, so my first question is simply: why do we upcast to fp32 here?

Regardless, when I comment this out, I hit another fp32 / bf16 mismatch, this time in geom_attn (esm/layers/blocks.py:147 in my installed version; I think the line numbers differ on the main branch):

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

[146] if self.use_geom_attn:
[147]     r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)

https://github.com/evolutionaryscale/esm/blob/5a1f4573bc910023775cb5ec575e12e4bd04ebb6/esm/layers/blocks.py#L155-L157
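One generic workaround that avoids editing the installed package is to run the forward pass under torch.autocast, which casts the inputs of eligible ops (such as linear and matmul) to a common lower-precision dtype. A toy CPU-only sketch, not verified against ESM3 itself:

```python
import torch
import torch.nn as nn

# Toy stand-in: a bf16 projection fed fp32 features (not the actual ESM3 code).
proj = nn.Linear(16, 8).to(dtype=torch.bfloat16)
features = torch.rand(2, 16, dtype=torch.float32)

# Under autocast, eligible ops such as linear cast their inputs to the
# autocast dtype, so the fp32 activation and bf16 weights no longer clash.
# On a GPU this would be device_type="cuda"; "cpu" keeps the sketch runnable anywhere.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = proj(features)

print(out.dtype)  # torch.bfloat16
```

For the reproduction above, this would mean wrapping the model.forward(...) call in with torch.autocast(device_type="cuda", dtype=torch.bfloat16):. Whether this covers every mixed-dtype op inside ESM3 is untested here.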

mzio • Jan 04 '25 19:01

Keeping this open, but sharing a manual patch in the meantime (a bit too hacky to PR, and I haven't gone through the whole codebase, but it seems to generate structures fine).


In the installed esm3.py (e.g., ~/miniconda3/envs/align-bio/lib/python3.12/site-packages/esm/models/esm3.py), we manually cast the RBF features to the weight dtype:

# around esm/models/esm3.py:L115 - L124
rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
# the `masked_fill(padding_mask.unsqueeze(2), 0)` for the two below is unnecessary
# as pad tokens never even interact with the "real" tokens (due to sequence_id)
dtype = self.sequence_embed.weight.dtype  # added: dtype of the loaded weights (e.g. bf16)
plddt_embed = self.plddt_projection(
    rbf_16_fn(average_plddt).to(dtype=dtype)  # added: cast fp32 RBF features to the weight dtype
)
structure_per_res_plddt = self.structure_per_res_plddt_projection(
    rbf_16_fn(per_res_plddt).to(dtype=dtype)  # added: same cast for the per-residue pLDDT
)
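The patch above can be exercised in isolation. This sketch uses a hypothetical rbf_sketch (my own Gaussian RBF encoding, not esm's rbf function) and a toy bf16 nn.Linear standing in for plddt_projection. It reads the dtype from the projection's own weight rather than from sequence_embed, which is equivalent when all weights share one dtype:

```python
from functools import partial

import torch
import torch.nn as nn


def rbf_sketch(values: torch.Tensor, v_min: float, v_max: float, n_bins: int = 16) -> torch.Tensor:
    """Hypothetical Gaussian RBF encoding: each scalar -> n_bins soft bin activations.

    Illustrative only; this is not esm's own `rbf` implementation.
    """
    centers = torch.linspace(v_min, v_max, n_bins, device=values.device, dtype=values.dtype)
    width = (v_max - v_min) / n_bins
    z = (values.unsqueeze(-1) - centers) / width  # broadcasts to (..., n_bins)
    return torch.exp(-(z ** 2))


rbf_16_fn = partial(rbf_sketch, v_min=0.0, v_max=1.0, n_bins=16)

# Toy projection with bf16 weights, standing in for plddt_projection.
plddt_projection = nn.Linear(16, 32).to(dtype=torch.bfloat16)

# Per-residue pLDDT already upcast to fp32, as described in the issue.
per_res_plddt = torch.zeros(1, 45).float()

features = rbf_16_fn(per_res_plddt)        # shape (1, 45, 16), dtype float32
dtype = plddt_projection.weight.dtype      # same idea as the patch's weight.dtype lookup
embed = plddt_projection(features.to(dtype=dtype))

print(features.dtype, embed.dtype, tuple(embed.shape))
```

Casting just before the projection keeps the RBF arithmetic itself in fp32 while letting the bf16 layer run on matching inputs.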

And also here (~/miniconda3/envs/align-bio/lib/python3.12/site-packages/esm/layers/geom_attention.py):

# around esm/layers/geom_attention.py:151
# added: cast the attention output to out_proj's weight dtype before projecting it
attn_out = attn_out.to(dtype=self.out_proj.weight.dtype)
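Instead of editing files in site-packages, a similar cast could be applied from outside the library with forward pre-hooks on nn.Linear modules. A minimal sketch on a toy model; cast_inputs_to_weight_dtype and add_dtype_hooks are hypothetical helpers, and it assumes the offending layers (like out_proj and the pLDDT projections) are nn.Linear modules called with positional tensor inputs:

```python
import torch
import torch.nn as nn


def cast_inputs_to_weight_dtype(module: nn.Module, args: tuple) -> tuple:
    """Forward pre-hook: cast floating-point positional inputs to the module's weight dtype."""
    dtype = module.weight.dtype
    return tuple(
        a.to(dtype=dtype) if torch.is_tensor(a) and a.is_floating_point() else a
        for a in args
    )


def add_dtype_hooks(model: nn.Module) -> list:
    """Hypothetical helper: hook every nn.Linear; returns handles so the hooks can be removed."""
    handles = []
    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            handles.append(submodule.register_forward_pre_hook(cast_inputs_to_weight_dtype))
    return handles


# Toy model standing in for ESM3: bf16 weights, fed an fp32 activation.
toy = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8)).to(dtype=torch.bfloat16)
handles = add_dtype_hooks(toy)

out = toy(torch.rand(4, 16, dtype=torch.float32))
print(out.dtype)  # torch.bfloat16

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

This only covers nn.Linear layers; a mismatch in any other op (for example a bare torch.matmul) would still surface.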

mzio • Jan 09 '25 18:01