
Modern Bert Support

Open · ryan-mangeno opened this issue 3 months ago · 43 comments

adding support to run granite embedding small, which primarily pulls in the modern bert architecture - https://huggingface.co/ibm-granite/granite-embedding-small-english-r2. Still working on it: I haven't figured out the pre-tokenizer type or whether I need to implement it, and for the ubatch size an assert fails in llama-graph.cpp. I hacked it to accept a ubatch size of 1 for testing, but it seems to keep failing there and I'm not sure why.

if I comment out this line in llama-graph.cpp

assert(!ubatch.equal_seqs());

then it works

ryan-mangeno avatar Aug 28 '25 17:08 ryan-mangeno

@gabe-l-hart thanks in advance :)

ryan-mangeno avatar Aug 28 '25 17:08 ryan-mangeno

@gabe-l-hart thanks in advance :)

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

ryan-mangeno avatar Aug 28 '25 17:08 ryan-mangeno

You may want to check out an earlier attempt at ModernBert in #14014

CISC avatar Aug 28 '25 17:08 CISC

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

gabe-l-hart avatar Aug 28 '25 17:08 gabe-l-hart

also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

In general, we want to keep things as generic as possible, so since this uses the ModernBertModel architecture from transformers, it's best to keep the implementation here similarly robust unless there's a concrete reason to subset the transformers architecture to just work for granite (eg there's some non-trivial code path in the transformers version that would make sense as a separate architecture).

gabe-l-hart avatar Aug 28 '25 17:08 gabe-l-hart

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

will do

ryan-mangeno avatar Aug 28 '25 19:08 ryan-mangeno

@gabe-l-hart I'm looking into modern bert's research paper. I can't find a mention of symmetric sliding window attention, only local sliding window attention, so I am going to opt for LLAMA_SWA_TYPE_LOCAL over the LLAMA_SWA_TYPE_SYMMETRIC used in the previous attempt. It also uses global attention every third layer, so I am going to implement this and then it should be ready for a review :)
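
A minimal numpy sketch (not llama.cpp code) of the local sliding-window pattern described above; the window size and the symmetric half-window-per-side interpretation are assumptions for illustration only:

import numpy as np

# toy mask for a bidirectional encoder with local sliding-window attention:
# token i may attend to token j only if |i - j| <= window // 2
# (the window size and half-on-each-side split are illustrative assumptions)
def local_swa_mask(n_tokens: int, window: int) -> np.ndarray:
    idx = np.arange(n_tokens)
    return np.abs(idx[:, None] - idx[None, :]) <= window // 2

# e.g. 8 tokens with a window of 4: each token sees itself and 2 neighbours per side
print(local_swa_mask(8, 4).astype(int))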

ryan-mangeno avatar Sep 03 '25 17:09 ryan-mangeno

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

gabe-l-hart avatar Sep 03 '25 18:09 gabe-l-hart

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

ok 👍, made some changes but not sure if it's fully ready yet. I will ping you when I think it's ready, if that's ok

ryan-mangeno avatar Sep 03 '25 18:09 ryan-mangeno

status update - I found out that modern bert uses an alternating rope method, per https://arxiv.org/pdf/2412.13663:

In ModernBERT, every third layer employs global attention with a RoPE theta of 160,000 and the remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this
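
As a sanity check, here is a tiny sketch of the per-layer schedule from the paper excerpt above; the constants come from the quoted text, but the phase (global attention on layers 0, 3, 6, ...) is an assumption:

# per-layer attention/rope schedule as described in the paper excerpt;
# the starting phase of the global layers is assumed, not taken from any code
def modernbert_layer_schedule(n_layers: int, global_every: int = 3):
    schedule = []
    for il in range(n_layers):
        if il % global_every == 0:
            schedule.append({"layer": il, "attn": "global", "rope_theta": 160000.0, "window": None})
        else:
            schedule.append({"layer": il, "attn": "local", "rope_theta": 10000.0, "window": 128})
    return schedule

for entry in modernbert_layer_schedule(6):
    print(entry)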

ryan-mangeno avatar Sep 04 '25 22:09 ryan-mangeno

status update - I found out that modern bert uses an alternating rope method, per arxiv.org/pdf/2412.13663:

In ModernBERT, every third layer employs global attention with a RoPE theta of 160,000 and the remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this

IIUC this matches how sliding window attention is handled for Gemma3: https://github.com/ggml-org/llama.cpp/blob/5d6688de08e73acc2532d668380801ed79d704eb/src/llama-model.cpp#L1106

ehoogeveen-medweb avatar Sep 05 '25 06:09 ehoogeveen-medweb

Gemma3

hey, thanks for the heads up! I noticed that in the gemma3 implementation SWA is set up with

// TODO: is causal == true correct? might need some changes
        auto * inp_attn = build_attn_inp_kv_unified_iswa();

but it is not handled when looping over the layers in

llm_build_gemma3_iswa

Is this intentional, or is the actual logic of the SWA configuration happening elsewhere?

ryan-mangeno avatar Sep 06 '25 17:09 ryan-mangeno

There's some SWA configuration in the code I linked, starting here: https://github.com/ggml-org/llama.cpp/blob/5d6688de08e73acc2532d668380801ed79d704eb/src/llama-model.cpp#L1103

But I'm not sure whether that answers your question, as this PR already seems to set a similar configuration for the new architecture... unfortunately I'm not a true expert, I just remembered noticing that hardcoded RoPE base and scale for Gemma3 before.

ehoogeveen-medweb avatar Sep 06 '25 17:09 ehoogeveen-medweb

I have been working on the alternating attention, having some issues creating the local window and getting mostly non-matching dimension errors like

/Users/ryanmangeno/Projects/gits/llama.cpp/ggml/src/ggml.c:3901: GGML_ASSERT(a->ne[2] == b->ne[0]) failed

ryan-mangeno avatar Sep 08 '25 01:09 ryan-mangeno

currently failing on this line

K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
                                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                ext_factor, attn_factor, beta_fast, beta_slow);
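// note: the assert quoted above (a->ne[2] == b->ne[0]) is ggml checking that the
// positions tensor pos_k has exactly one entry per token of K_work - my reading
// of the failure, not verified against this branch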

ryan-mangeno avatar Sep 08 '25 01:09 ryan-mangeno

There's some SWA configuration in the code I linked, starting here:

https://github.com/ggml-org/llama.cpp/blob/5d6688de08e73acc2532d668380801ed79d704eb/src/llama-model.cpp#L1103

But I'm not sure whether that answers your question, as this PR already seems to set a similar configuration for the new architecture... unfortunately I'm not a true expert, I just remembered noticing that hardcoded RoPE base and scale for Gemma3 before.

alright, and yes it's been pretty helpful, I've been using it as a reference to implement SWA for modern bert, thanks!!! :)

ryan-mangeno avatar Sep 08 '25 19:09 ryan-mangeno

sorry if this has been a little slow, the alternating attention mechanism has been a little tough to implement, but I'm hoping to get it fixed soon

ryan-mangeno avatar Sep 11 '25 20:09 ryan-mangeno

@gabe-l-hart I believe this should be ready for review whenever you're available to check it out :)

ryan-mangeno avatar Sep 12 '25 16:09 ryan-mangeno

Awesome, thanks for your hard work on this @ryan-mangeno . I'll look it over soon!

gabe-l-hart avatar Sep 12 '25 17:09 gabe-l-hart

@ryan-mangeno Two requests:

  1. Can you merge in master and resolve the conflicts (I can help if you get stuck)
  2. Can you share what you've been doing to compare outputs between this version and transformers?

gabe-l-hart avatar Sep 12 '25 21:09 gabe-l-hart

@ryan-mangeno Two requests:

  1. Can you merge in master and resolve the conflicts (I can help if you get stuck)
  2. Can you share what you've been doing to compare outputs between this version and transformers?

yes will get on that 👍

ryan-mangeno avatar Sep 13 '25 18:09 ryan-mangeno

@ryan-mangeno Two requests:

  1. Can you merge in master and resolve the conflicts (I can help if you get stuck)
  2. Can you share what you've been doing to compare outputs between this version and transformers?

yes will get on that 👍

here is the command I run on llama.cpp

./build/bin/llama-embedding \
    -m models/modernbert.gguf \
    -p "hello world" \
    --temp 0.0 \
    --repeat_penalty 1.0 \
    --top_k 0 \
    --top_p 1.0

and here is my script for hf

import torch
from transformers import AutoModel, AutoTokenizer

torch.manual_seed(0) 
torch.use_deterministic_algorithms(True)  
model_path = "ibm-granite/granite-embedding-small-english-r2"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)
model.eval()

input_queries = ["hello world"]

tokenized_queries = tokenizer(
    input_queries,
    padding=True,
    truncation=True,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**tokenized_queries)
    embedding = outputs.last_hidden_state[:, 0, :]  # CLS token

print("Embedding shape:", embedding.shape)
print("Embedding vector:", embedding)


ryan-mangeno avatar Sep 13 '25 19:09 ryan-mangeno

I also have a script for the cosine similarity between the two resulting embeddings I get:

import numpy as np

def cosine_similarity(vec1, vec2):
    dot_product = np.dot(vec1, vec2)
    
    norm_v1 = np.linalg.norm(vec1)
    norm_v2 = np.linalg.norm(vec2)
    
    if norm_v1 == 0 or norm_v2 == 0:
        return 0.0
    
    similarity = dot_product / (norm_v1 * norm_v2)
    
    return similarity

hf_embds = np.array(<copy and paste tensor from hf output>)
llama_data_string = "< llama prints embeddings without comma separators, so treat it as a string then split >"
llama_embds = np.array([float(i) for i in llama_data_string.split()])

print(cosine_similarity(llama_embds, hf_embds))

it currently prints

0.0502

so pretty low similarity at face value; still working through it and hoping to get better results

ryan-mangeno avatar Sep 13 '25 20:09 ryan-mangeno

Just an update: I think I might be getting bad results because I did not implement flash attention, which is outlined in the modern bert research paper. I will try to update this.

ryan-mangeno avatar Sep 26 '25 15:09 ryan-mangeno

Just an update: I think I might be getting bad results because I did not implement flash attention, which is outlined in the modern bert research paper. I will try to update this.

found out flash attention is a flag you can pass in when running the model; results are still not great, so I will keep trying to hack at it.

ryan-mangeno avatar Sep 26 '25 16:09 ryan-mangeno

to my knowledge, since modern bert is an encoder, I shouldn't be using a KV cache and should instead use

auto * inp_attn = build_attn_inp_no_cache();

during the graph build, but since modern bert uses SWA, when the input is set during

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) 

this assert fails, and I am not really sure how long this will take to implement if it is a crucial step for the current implementation of modern bert:

    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");

ryan-mangeno avatar Oct 01 '25 20:10 ryan-mangeno

SWA support for cache-less context is not ready yet. For now use a SWA cache similar to llm_build_gemma_embedding_iswa and add a TODO to be fixed later.

ggerganov avatar Oct 02 '25 07:10 ggerganov

SWA support for cache-less context is not ready yet. For now use a SWA cache similar to llm_build_gemma_embedding_iswa and add a TODO to be fixed later.

ok will do, thank you so much!!

ryan-mangeno avatar Oct 04 '25 15:10 ryan-mangeno

Hey, wanted to see if this could be reviewed sometime. I am pretty sure I have gone through and addressed the requested corrections, let me know of anything to change/add :))

ryan-mangeno avatar Oct 08 '25 18:10 ryan-mangeno

thanks for the insight and suggestions! I also added support to convert the modern bert base model to GGUF

ryan-mangeno avatar Oct 10 '25 19:10 ryan-mangeno