ModernBertModel works on the CPU but fails on the GPU
Hello everyone,
My problem is that ModernBertModel fails to return a valid output when I use the GPU instead of the CPU. The following code returns a valid output:
import torch
from transformers import AutoTokenizer, ModernBertModel
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)
device = torch.device("cpu")
model.to(device)
texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]
inputs = tokenizer(
    text=texts,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=768,
    return_attention_mask=True,
    return_tensors='pt'
)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state)
Output:
tensor([[[ 0.2510, -0.8900, -0.7447, ..., -0.6569, 0.2809, -0.5663],
[ 0.1292, 0.0536, 0.2478, ..., 0.1400, -0.1059, 0.0981],
[-0.0945, -1.2089, -0.5087, ..., -0.0810, 1.4614, -0.1214],
...,
[ 1.5802, -0.2266, 0.8008, ..., -0.8563, -0.0378, -0.6842],
[ 1.6365, -0.2077, 0.7667, ..., -0.8660, -0.0537, -0.6460],
[ 1.6404, -0.1780, 0.7846, ..., -0.8497, -0.0268, -0.6155]],
[[ 0.3872, -0.9977, -0.8920, ..., -0.7293, 0.5094, -0.5080],
[-0.1917, -0.8092, -0.3774, ..., -1.0475, -0.4196, 0.1802],
[-0.0937, -1.1293, -0.8068, ..., 0.4551, 1.5275, -0.0922],
...,
[ 1.7813, 0.2581, 0.6624, ..., -1.0199, -0.1711, -1.1627],
[ 1.8317, 0.3041, 0.6434, ..., -1.0328, -0.1824, -1.1392],
[ 1.8517, 0.3272, 0.6883, ..., -0.9966, -0.1606, -1.1124]]],
grad_fn=<NativeLayerNormBackward0>)
But when I switch to the GPU I get a tensor with NaNs:
import torch
from transformers import AutoTokenizer, ModernBertModel
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]
inputs = tokenizer(
    text=texts,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=768,
    return_attention_mask=True,
    return_tensors='pt'
)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state)
Output:
tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
grad_fn=<NativeLayerNormBackward0>)
Do you have an idea what the problem might be?
I am having this exact issue. Fails on CUDA, works on MPS:
Input IDs: tensor([[50281, 510, 28983, ..., 50283, 50283, 50283],
[50281, 688, 13524, ..., 50283, 50283, 50283],
[50281, 5707, 25280, ..., 50283, 50283, 50283],
...,
[50281, 20411, 10833, ..., 50283, 50283, 50283],
[50281, 10772, 41316, ..., 50283, 50283, 50283],
[50281, 4749, 35982, ..., 50283, 50283, 50283]], device='cuda:0')
Attention mask: tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]], device='cuda:0')
Sliding window mask: None
Position IDs: None
Inputs embeds: None
Outputs:
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
...,
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]],
Update: I suspected that this had to do with the attention implementation. By default it was running as sdpa. I decided to run with flash_attention_2, which triggered some additional installs to resolve dependencies (flash-attn and triton). Then I downgraded to Python 3.11 and rebuilt my environment because I was getting "Dynamo is not supported on Python 3.12+".
Now it works with all attn_implementation values (sdpa, flash_attention_2, and eager). There's probably some package difference between 3.11 and 3.12 causing the issue.
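For anyone else hitting this, the attention backend can be selected explicitly when the model is loaded, which makes it easy to compare the three implementations. A minimal sketch, assuming a recent transformers release with ModernBERT support; flash_attention_2 additionally requires the flash-attn package and a half-precision dtype:

import torch
from transformers import AutoTokenizer, ModernBertModel

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pick the attention backend explicitly instead of relying on the default.
# "eager" is the slowest but is a useful NaN-free baseline for debugging;
# swap in "sdpa" or "flash_attention_2" to compare.
model = ModernBertModel.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype=torch.float32,
).to("cuda").eval()

inputs = tokenizer(["The capital of France is Paris."], return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model(**inputs).last_hidden_state
print(torch.isnan(out).any())  # should be False if the chosen backend is healthy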
I am facing the same problem. It works with batch size 1 on CUDA, but anything larger returns zeros. Any solution for this one?
@KendallPark could you post the versions of torch, flash-attn, and triton you used successfully?
Same issue here.
torch 2.2 without flash-attn:
absl-py 2.1.0 aiofiles 23.2.1 aiohappyeyeballs 2.6.1 aiohttp 3.12.15 aiosignal 1.4.0 annotated-types 0.7.0 anyio 4.9.0 async-timeout 5.0.1 attrs 25.3.0 beautifulsoup4 4.13.3 blinker 1.9.0 blis 1.3.0 blobfile 3.0.0 catalogue 2.0.10 certifi 2025.1.31 charset-normalizer 3.4.1 chumpy 0.70 click 8.1.8 clip 1.0 cloudpathlib 0.21.1 cmake 3.31.4 confection 0.1.5 cycler 0.12.1 cymem 2.0.11 Cython 3.0.12 datasets 4.0.0 dill 0.3.8 einops 0.6.1 en-core-web-sm 3.2.0 exceptiongroup 1.2.2 fastapi 0.115.12 ffmpy 0.3.1 filelock 3.17.0 Flask 3.1.0 frozenlist 1.7.0 fsspec 2025.3.0 ftfy 6.1.1 gdown 5.2.0 gradio 5.23.1 gradio_client 1.8.0 groovy 0.1.2 grpcio 1.54.2 h11 0.14.0 hf-xet 1.1.7 hjson 3.1.0 httpcore 1.0.7 httpx 0.28.1 huggingface-hub 0.34.4 idna 3.10 imageio 2.37.0 imageio-ffmpeg 0.6.0 importlib-metadata 5.0.0 importlib-resources 5.12.0 itsdangerous 2.2.0 Jinja2 3.1.5 joblib 1.4.2 kiwisolver 1.4.8 langcodes 3.5.0 language_data 1.3.0 lit 18.1.8 lxml 6.0.0 marisa-trie 1.2.1 Markdown 3.7 markdown-it-py 3.0.0 MarkupSafe 3.0.2 matplotlib 3.3.4 mdurl 0.1.2 mpi4py 4.1.0 mpmath 1.3.0 msgpack 1.1.1 multidict 6.6.4 multiprocess 0.70.16 murmurhash 1.0.12 narwhals 2.1.2 natsort 8.4.0 networkx 3.4.2 ninja 1.13.0 numpy 1.21.5 nvidia-cublas-cu11 11.11.3.6 nvidia-cuda-cupti-cu11 11.8.87 nvidia-cuda-cupti-cu12 12.6.80 nvidia-cuda-nvrtc-cu11 11.8.89 nvidia-cuda-nvrtc-cu12 12.6.77 nvidia-cuda-runtime-cu11 11.8.89 nvidia-cuda-runtime-cu12 12.6.77 nvidia-cudnn-cu11 8.7.0.84 nvidia-cufft-cu11 10.9.0.58 nvidia-cufile-cu12 1.11.1.6 nvidia-curand-cu11 10.3.0.86 nvidia-curand-cu12 10.3.7.77 nvidia-cusolver-cu11 11.4.1.48 nvidia-cusparse-cu11 11.7.5.86 nvidia-cusparselt-cu12 0.6.3 nvidia-ml-py 13.580.65 nvidia-nccl-cu11 2.19.3 nvidia-nccl-cu12 2.26.2 nvidia-nvjitlink-cu12 12.6.85 nvidia-nvtx-cu11 11.8.86 nvidia-nvtx-cu12 12.6.77 orjson 3.10.16 packaging 24.2 pandas 2.0.3 pathlib_abc 0.1.1 pathy 0.11.0 Pillow 9.2.0 pip 25.0 plotly 6.3.0 preshed 3.0.9 propcache 0.3.2 protobuf 5.29.3 psutil 7.0.0 py-cpuinfo 9.0.0 pyarrow 21.0.0 pycryptodomex 3.23.0 pydantic 2.11.7 pydantic_core 2.33.2 pydub 0.25.1 Pygments 2.19.1 pyparsing 3.2.1 PySocks 1.7.1 python-dateutil 2.9.0.post0 python-multipart 0.0.20 pytz 2025.1 PyYAML 6.0 regex 2024.11.6 requests 2.32.3 rich 13.9.4 ruff 0.11.2 safehttpx 0.1.6 safetensors 0.6.2 scikit-learn 1.6.1 scipy 1.10.1 semantic-version 2.10.0 sentencepiece 0.2.0 setuptools 75.8.0 shellingham 1.5.4 six 1.17.0 smart-open 6.4.0 smplx 0.1.28 sniffio 1.3.0 socksio 1.0.0 soupsieve 2.6 spacy 3.8.7 spacy-legacy 3.0.12 spacy-loggers 1.0.5 srsly 2.5.1 starlette 0.46.1 sympy 1.13.3 tensorboard 2.18.0 tensorboard-data-server 0.7.2 thinc 8.3.6 threadpoolctl 3.5.0 tiktoken 0.11.0 tokenizers 0.21.4 tomlkit 0.13.2 torch 2.2.0+cu118 torch-tb-profiler 0.4.3 torchaudio 2.2.0+cu118 torchdata 0.7.1 torchdiffeq 0.2.5 torchtext 0.17.0 torchvision 0.17.0+cu118 tornado 6.4.2 tqdm 4.67.1 transformers 4.55.3 trimesh 4.6.1 triton 2.2.0 typer 0.15.2 typing_extensions 4.14.1 typing-inspection 0.4.0 tzdata 2025.1 urllib3 2.3.0 uvicorn 0.34.0 wasabi 0.10.1 wcwidth 0.2.13 weasel 0.4.1 websockets 15.0.1 Werkzeug 3.1.3 wheel 0.45.1 xxhash 3.5.0 yarl 1.20.1 zipp 3.21.0 zstandard 0.23.0
Oh, maybe it's because I didn't install flash-attn. I used the following code to avoid the NaN tensors:
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, ModernBertModel


class ModernBertModelWrapper(nn.Module):
    def __init__(self, bert_version, max_length=120):
        super().__init__()
        # Work around the NaNs from the fused SDPA kernels on CUDA by
        # forcing the math implementation of scaled_dot_product_attention.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

        from_pretrained = os.path.join('./deps', bert_version)
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            local_files_only=True,
            legacy=False
        )
        self.model = ModernBertModel.from_pretrained(
            from_pretrained,
            local_files_only=True
        ).eval()
        # Freeze the encoder; the wrapper is used as a fixed text embedder.
        for p in self.model.parameters():
            p.requires_grad = False
        self.max_length = max_length
        self.embed_dim = self.model.config.hidden_size

    @torch.no_grad()
    def forward(self, raw_text):
        device = next(self.parameters()).device
        enc = self.tokenizer(
            raw_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        ).to(device)
        attention_mask = enc["attention_mask"].bool()
        # Forward pass through the encoder only.
        encoded = self.model(**enc).last_hidden_state  # (B, Nt, D)
        return encoded, attention_mask
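A quick usage sketch (hypothetical: it assumes the checkpoint has already been downloaded to ./deps/ModernBERT-base, since the wrapper loads with local_files_only=True):

wrapper = ModernBertModelWrapper("ModernBERT-base", max_length=120).to("cuda")
embeddings, mask = wrapper(["The capital of France is Paris.",
                            "The capital of Germany is Berlin."])
print(embeddings.shape)                # (batch, max_length, hidden_size)
print(torch.isnan(embeddings).any())   # expect False with the math SDP backend forced

Note that enable_flash_sdp(False) and enable_mem_efficient_sdp(False) are process-wide toggles, so every other scaled_dot_product_attention call in the same process will also fall back to the slower math kernel.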