
~2x slower than `Transformers` on CPU with `Bert` model

CrazyboyQCD opened this issue 1 year ago · 2 comments

OS: Windows 11
Model: maidalun1020/bce-embedding-base_v1

Command:

cargo run --features mkl --example bert --release -- --model-id maidalun1020/bce-embedding-base_v1 --use-pth

Candle (took ~15s):

    let s = std::fs::read_to_string("test.txt")?;

    // Split the text into chunks of at most 512 bytes, never splitting
    // inside a UTF-8 character.
    let mut splited_strs = vec![];
    let mut cur_length = 0;
    let mut tmp_str = String::new();
    for ch in s.chars() {
        let l = ch.len_utf8();
        if cur_length + l <= 512 {
            cur_length += l;
        } else {
            splited_strs.push(tmp_str.drain(..).collect::<String>());
            cur_length = l;
        }
        tmp_str.push(ch);
    }
    // Don't drop the trailing chunk.
    if !tmp_str.is_empty() {
        splited_strs.push(tmp_str);
    }

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;
    let tokenizer = tokenizer
        .with_padding(Some(Default::default()))
        .with_truncation(Some(Default::default()))
        .map_err(E::msg)?;
    // Tokenize every chunk up front so the timed loop below only measures
    // the forward passes.
    let mut tensors = vec![];
    for s in splited_strs {
        let e = tokenizer.encode(s, true).map_err(E::msg)?;
        let token_ids = Tensor::new(e.get_ids(), device)?.unsqueeze(0)?;
        let token_type_ids = token_ids.zeros_like()?;
        tensors.push([token_ids, token_type_ids]);
    }
 
    let start = std::time::Instant::now();
    for [token_ids, token_type_ids] in tensors {
        model.forward(&token_ids, &token_type_ids)?;
    }
    println!("Took {:?} ", start.elapsed());

Transformers (took ~8s):

import time
from typing import List

from BCEmbedding import EmbeddingModel

# init embedding model
model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")

# read the whole file, closing the handle when done
with open("test.txt", encoding="utf8", mode="r") as fp:
    f = fp.read()


# split text by the length of 512
def split_text(characters: List[str], length: int) -> List[str]:
    result: List[str] = []
    current_string = ""
    current_length = 0

    for char in characters:
        if current_length + (clen := len(char.encode("utf-8"))) <= length:
            current_string += char
            current_length += clen
        else:
            result.append(current_string)
            current_string = char
            current_length = clen

    if current_string:
        result.append(current_string)

    return result


# split the raw text (as a list of characters) into chunks of at most 512 bytes
sentences = split_text(list(f), 512)

st = time.time()

# extract embeddings
embeddings = model.encode(sentences)

print(time.time() - st)

text file: test.txt

CrazyboyQCD · May 22 '24 09:05

PyTorch is likely to use mkl by default so you probably want to enable it on the candle side too if you haven't done so already.
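For a project depending on candle directly (rather than running the bundled example, where `--features mkl` on the cargo command line is enough), a minimal sketch of wiring this up would be enabling the feature in Cargo.toml, something like the following (version number assumed):

    [dependencies]
    # enable Intel MKL as the BLAS backend (feature name per candle's docs)
    candle-core = { version = "0.5", features = ["mkl"] }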

LaurentMazare · May 22 '24 09:05

> PyTorch is likely to use mkl by default so you probably want to enable it on the candle side too if you haven't done so already.

Oops, I forgot to include the command I used. I've added the mkl feature and it shaved ~3s off the original 18s.

CrazyboyQCD · May 22 '24 10:05