
ModernBertModel works on the CPU but fails on the GPU

Open rudigunn opened this issue 11 months ago • 8 comments

Hello everyone,

My problem is that ModernBertModel returns an invalid output when I run it on the GPU, while the same code works fine on the CPU. The following code returns a valid output:

import torch
from transformers import AutoTokenizer, ModernBertModel

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)

device = torch.device("cpu")
model.to(device)

texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]

inputs = tokenizer(
    text=texts,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=768,
    return_attention_mask=True,
    return_tensors='pt' 
)

input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.last_hidden_state)

Output:

tensor([[[ 0.2510, -0.8900, -0.7447,  ..., -0.6569,  0.2809, -0.5663],
         [ 0.1292,  0.0536,  0.2478,  ...,  0.1400, -0.1059,  0.0981],
         [-0.0945, -1.2089, -0.5087,  ..., -0.0810,  1.4614, -0.1214],
         ...,
         [ 1.5802, -0.2266,  0.8008,  ..., -0.8563, -0.0378, -0.6842],
         [ 1.6365, -0.2077,  0.7667,  ..., -0.8660, -0.0537, -0.6460],
         [ 1.6404, -0.1780,  0.7846,  ..., -0.8497, -0.0268, -0.6155]],

        [[ 0.3872, -0.9977, -0.8920,  ..., -0.7293,  0.5094, -0.5080],
         [-0.1917, -0.8092, -0.3774,  ..., -1.0475, -0.4196,  0.1802],
         [-0.0937, -1.1293, -0.8068,  ...,  0.4551,  1.5275, -0.0922],
         ...,
         [ 1.7813,  0.2581,  0.6624,  ..., -1.0199, -0.1711, -1.1627],
         [ 1.8317,  0.3041,  0.6434,  ..., -1.0328, -0.1824, -1.1392],
         [ 1.8517,  0.3272,  0.6883,  ..., -0.9966, -0.1606, -1.1124]]],
       grad_fn=<NativeLayerNormBackward0>)

But when I switch to the GPU, I get a tensor full of NaNs:

import torch
from transformers import AutoTokenizer, ModernBertModel

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]

inputs = tokenizer(
    text=texts,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=768,
    return_attention_mask=True,
    return_tensors='pt' 
)

input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.last_hidden_state)

Output:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>)

Do you have an idea what the problem might be?

rudigunn avatar Jan 08 '25 11:01 rudigunn

I am having this exact issue. Fails on CUDA, works on MPS:

Input IDs: tensor([[50281,   510, 28983,  ..., 50283, 50283, 50283],
        [50281,   688, 13524,  ..., 50283, 50283, 50283],
        [50281,  5707, 25280,  ..., 50283, 50283, 50283],
        ...,
        [50281, 20411, 10833,  ..., 50283, 50283, 50283],
        [50281, 10772, 41316,  ..., 50283, 50283, 50283],
        [50281,  4749, 35982,  ..., 50283, 50283, 50283]], device='cuda:0')
Attention mask: tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')
Sliding window mask: None
Position IDs: None
Inputs embeds: None
Outputs:
          ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], 

KendallPark avatar Jan 27 '25 17:01 KendallPark

Update: I suspected that this had to do with the attention implementation (attn_implementation). By default it was running as sdpa. I decided to run with flash_attention_2, which required installing some additional dependencies (flash-attn and triton). Then I downgraded to Python 3.11 and reinstalled my env because I was getting "Dynamo is not supported on Python 3.12+."

Now it works with all attn_implementation values (sdpa, flash_attention_2, and eager). There's probably some package difference between 3.11 and 3.12 causing the issue.
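For reference, here is a minimal sketch of pinning the attention implementation explicitly at load time via the attn_implementation argument to from_pretrained (eager needs no extra dependencies; flash_attention_2 additionally requires the flash-attn package):

import torch
from transformers import AutoTokenizer, ModernBertModel

model_id = "answerdotai/ModernBERT-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Choose the attention backend explicitly instead of relying on the default (sdpa).
model = ModernBertModel.from_pretrained(
    model_id,
    attn_implementation="eager",  # or "sdpa" / "flash_attention_2"
).to(device).eval()

texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]
inputs = tokenizer(texts, padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

print(torch.isnan(outputs.last_hidden_state).any())  # expect tensor(False)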

KendallPark avatar Jan 27 '25 18:01 KendallPark

I am facing the same problem. It works with batch size = 1 on CUDA, but anything larger returns all zeros. Any solution for this one?
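A minimal sketch to check that batch-size dependency (assuming the same ModernBERT-base checkpoint as in the original report):

import torch
from transformers import AutoTokenizer, ModernBertModel

model_id = "answerdotai/ModernBERT-base"
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id).to(device).eval()

texts = ["The capital of France is Paris.", "The capital of Germany is Berlin."]

for batch in ([texts[0]], texts):
    # Padding only comes into play for the multi-example batch,
    # which is where the NaN / all-zero outputs are reported.
    inputs = tokenizer(batch, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**inputs).last_hidden_state
    print(len(batch), torch.isnan(out).any().item(), out.abs().sum().item())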

farrokhsiar avatar Feb 13 '25 21:02 farrokhsiar

@KendallPark could you post versions of torch, flash-attn and triton you used successfully?

DanielWeitzenfeld avatar May 08 '25 21:05 DanielWeitzenfeld

same issue

eanson023 avatar Aug 21 '25 17:08 eanson023

torch 2.2 without flash-attn

eanson023 avatar Aug 21 '25 17:08 eanson023

absl-py 2.1.0 aiofiles 23.2.1 aiohappyeyeballs 2.6.1 aiohttp 3.12.15 aiosignal 1.4.0 annotated-types 0.7.0 anyio 4.9.0 async-timeout 5.0.1 attrs 25.3.0 beautifulsoup4 4.13.3 blinker 1.9.0 blis 1.3.0 blobfile 3.0.0 catalogue 2.0.10 certifi 2025.1.31 charset-normalizer 3.4.1 chumpy 0.70 click 8.1.8 clip 1.0 cloudpathlib 0.21.1 cmake 3.31.4 confection 0.1.5 cycler 0.12.1 cymem 2.0.11 Cython 3.0.12 datasets 4.0.0 dill 0.3.8 einops 0.6.1 en-core-web-sm 3.2.0 exceptiongroup 1.2.2 fastapi 0.115.12 ffmpy 0.3.1 filelock 3.17.0 Flask 3.1.0 frozenlist 1.7.0 fsspec 2025.3.0 ftfy 6.1.1 gdown 5.2.0 gradio 5.23.1 gradio_client 1.8.0 groovy 0.1.2 grpcio 1.54.2 h11 0.14.0 hf-xet 1.1.7 hjson 3.1.0 httpcore 1.0.7 httpx 0.28.1 huggingface-hub 0.34.4 idna 3.10 imageio 2.37.0 imageio-ffmpeg 0.6.0 importlib-metadata 5.0.0 importlib-resources 5.12.0 itsdangerous 2.2.0 Jinja2 3.1.5 joblib 1.4.2 kiwisolver 1.4.8 langcodes 3.5.0 language_data 1.3.0 lit 18.1.8 lxml 6.0.0 marisa-trie 1.2.1 Markdown 3.7 markdown-it-py 3.0.0 MarkupSafe 3.0.2 matplotlib 3.3.4 mdurl 0.1.2 mpi4py 4.1.0 mpmath 1.3.0 msgpack 1.1.1 multidict 6.6.4 multiprocess 0.70.16 murmurhash 1.0.12 narwhals 2.1.2 natsort 8.4.0 networkx 3.4.2 ninja 1.13.0 numpy 1.21.5 nvidia-cublas-cu11 11.11.3.6 nvidia-cuda-cupti-cu11 11.8.87 nvidia-cuda-cupti-cu12 12.6.80 nvidia-cuda-nvrtc-cu11 11.8.89 nvidia-cuda-nvrtc-cu12 12.6.77 nvidia-cuda-runtime-cu11 11.8.89 nvidia-cuda-runtime-cu12 12.6.77 nvidia-cudnn-cu11 8.7.0.84 nvidia-cufft-cu11 10.9.0.58 nvidia-cufile-cu12 1.11.1.6 nvidia-curand-cu11 10.3.0.86 nvidia-curand-cu12 10.3.7.77 nvidia-cusolver-cu11 11.4.1.48 nvidia-cusparse-cu11 11.7.5.86 nvidia-cusparselt-cu12 0.6.3 nvidia-ml-py 13.580.65 nvidia-nccl-cu11 2.19.3 nvidia-nccl-cu12 2.26.2 nvidia-nvjitlink-cu12 12.6.85 nvidia-nvtx-cu11 11.8.86 nvidia-nvtx-cu12 12.6.77 orjson 3.10.16 packaging 24.2 pandas 2.0.3 pathlib_abc 0.1.1 pathy 0.11.0 Pillow 9.2.0 pip 25.0 plotly 6.3.0 preshed 3.0.9 propcache 0.3.2 protobuf 5.29.3 psutil 7.0.0 py-cpuinfo 9.0.0 pyarrow 21.0.0 pycryptodomex 3.23.0 pydantic 2.11.7 pydantic_core 2.33.2 pydub 0.25.1 Pygments 2.19.1 pyparsing 3.2.1 PySocks 1.7.1 python-dateutil 2.9.0.post0 python-multipart 0.0.20 pytz 2025.1 PyYAML 6.0 regex 2024.11.6 requests 2.32.3 rich 13.9.4 ruff 0.11.2 safehttpx 0.1.6 safetensors 0.6.2 scikit-learn 1.6.1 scipy 1.10.1 semantic-version 2.10.0 sentencepiece 0.2.0 setuptools 75.8.0 shellingham 1.5.4 six 1.17.0 smart-open 6.4.0 smplx 0.1.28 sniffio 1.3.0 socksio 1.0.0 soupsieve 2.6 spacy 3.8.7 spacy-legacy 3.0.12 spacy-loggers 1.0.5 srsly 2.5.1 starlette 0.46.1 sympy 1.13.3 tensorboard 2.18.0 tensorboard-data-server 0.7.2 thinc 8.3.6 threadpoolctl 3.5.0 tiktoken 0.11.0 tokenizers 0.21.4 tomlkit 0.13.2 torch 2.2.0+cu118 torch-tb-profiler 0.4.3 torchaudio 2.2.0+cu118 torchdata 0.7.1 torchdiffeq 0.2.5 torchtext 0.17.0 torchvision 0.17.0+cu118 tornado 6.4.2 tqdm 4.67.1 transformers 4.55.3 trimesh 4.6.1 triton 2.2.0 typer 0.15.2 typing_extensions 4.14.1 typing-inspection 0.4.0 tzdata 2025.1 urllib3 2.3.0 uvicorn 0.34.0 wasabi 0.10.1 wcwidth 0.2.13 weasel 0.4.1 websockets 15.0.1 Werkzeug 3.1.3 wheel 0.45.1 xxhash 3.5.0 yarl 1.20.1 zipp 3.21.0 zstandard 0.23.0

eanson023 avatar Aug 21 '25 17:08 eanson023

Oh, maybe it's because I didn't install flash-attn. I used the following code to avoid the NaN tensors.

import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, ModernBertModel


class ModernBertModelWrapper(nn.Module):
    def __init__(self, bert_version, max_length=120):
        super().__init__()

        # Disable the flash / memory-efficient SDPA kernels and force the
        # math implementation, which avoids the NaN outputs on CUDA.
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

        from_pretrained = os.path.join('./deps', bert_version)
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            local_files_only=True,
            legacy=False
        )

        # Frozen encoder: eval mode, no gradients.
        self.model = ModernBertModel.from_pretrained(
            from_pretrained,
            local_files_only=True
        ).eval()
        for p in self.model.parameters():
            p.requires_grad = False

        self.max_length = max_length
        self.embed_dim = self.model.config.hidden_size

    @torch.no_grad()
    def forward(self, raw_text):
        device = next(self.parameters()).device
        enc = self.tokenizer(
            raw_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        ).to(device)
        attention_mask = enc["attention_mask"].bool()

        # Forward pass through the encoder only; return token embeddings.
        encoded = self.model(**enc).last_hidden_state.detach()  # (B, Nt, D)

        return encoded, attention_mask
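
A hypothetical usage sketch of this wrapper (the ./deps/<bert_version> directory layout is an assumption of the snippet above):

# Expects a local copy of the checkpoint under ./deps/ModernBERT-base.
wrapper = ModernBertModelWrapper("ModernBERT-base", max_length=120).to("cuda")
embeddings, mask = wrapper(["The capital of France is Paris."])
print(embeddings.shape)  # (1, 120, wrapper.embed_dim)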

eanson023 avatar Aug 21 '25 17:08 eanson023