~2x slower than `transformers` on CPU with a BERT model
OS: Windows 11
Model: maidalun1020/bce-embedding-base_v1
Command:
cargo run --features mkl --example bert --release -- --model-id maidalun1020/bce-embedding-base_v1 --use-pth
Candle (took ~15s):
```rust
let s = std::fs::read_to_string("test.txt")?;
// Split the text into chunks of at most 512 UTF-8 bytes.
let mut split_strs = vec![];
let mut cur_length = 0;
let mut tmp_str = String::new();
for ch in s.chars() {
    let l = ch.len_utf8();
    if cur_length + l <= 512 {
        cur_length += l;
    } else {
        split_strs.push(tmp_str.drain(..).collect::<String>());
        cur_length = l;
    }
    tmp_str.push(ch);
}
// Keep the final partial chunk as well.
if !tmp_str.is_empty() {
    split_strs.push(tmp_str);
}
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let tokenizer = tokenizer
    .with_padding(Some(Default::default()))
    .with_truncation(Some(Default::default()))
    .map_err(E::msg)?;
// Tokenize every chunk up front so that only the forward passes are timed.
let mut tensors = vec![];
for s in split_strs {
    let e = tokenizer.encode(s, true).map_err(E::msg)?;
    let token_ids = Tensor::new(e.get_ids(), device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    tensors.push([token_ids, token_type_ids]);
}
// Run each chunk through the model one at a time (batch size 1).
let start = std::time::Instant::now();
for [token_ids, token_type_ids] in tensors {
    model.forward(&token_ids, &token_type_ids)?;
}
println!("Took {:?}", start.elapsed());
```
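For reference, the timed loop above feeds the chunks through the model one at a time, while `model.encode` on the Python side typically processes them in batches. A batched candle-side variant of the timing loop (just a sketch, reusing the `model`/`tokenizer` set up above and relying on the tokenizer's default batch-longest padding) could look like this:

```rust
// Sketch only: encode all chunks together so padding gives them a common
// length, then run a single batched forward pass.
let encodings = tokenizer.encode_batch(split_strs, true).map_err(E::msg)?;
let mut ids = Vec::new();
for e in &encodings {
    ids.push(Tensor::new(e.get_ids(), device)?);
}
let token_ids = Tensor::stack(&ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let start = std::time::Instant::now();
model.forward(&token_ids, &token_type_ids)?;
println!("Batched forward took {:?}", start.elapsed());
```

Whether this narrows the gap depends on how much of the difference comes from per-call overhead versus the kernels themselves; splitting into smaller batches would keep memory bounded.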
Transformers (took ~8s):
```python
import time
from typing import List

from BCEmbedding import EmbeddingModel

# init embedding model
model = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")

f = open("test.txt", encoding="utf8", mode="r").read()

# split text into chunks of at most 512 UTF-8 bytes
def split_text(characters: List[str], length: int) -> List[str]:
    result: List[str] = []
    current_string = ""
    current_length = 0
    for char in characters:
        if current_length + (clen := len(char.encode("utf-8"))) <= length:
            current_string += char
            current_length += clen
        else:
            result.append(current_string)
            current_string = char
            current_length = clen
    if current_string:
        result.append(current_string)
    return result

sentences = split_text(list(f), 512)

st = time.time()
# extract embeddings
embeddings = model.encode(sentences)
print(time.time() - st)
```
text file: test.txt
PyTorch is likely to use mkl by default, so you probably want to enable it on the candle side too if you haven't done so already.
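One way to check that the feature actually made it into the binary (assuming a candle version that exposes these helpers under `candle_core::utils`) is to print the compile-time flags before running the benchmark:

```rust
// Sketch: `mkl` prints false if the feature wasn't enabled at build time.
println!(
    "mkl: {}, avx: {}, f16c: {}, threads: {}",
    candle_core::utils::has_mkl(),
    candle_core::utils::with_avx(),
    candle_core::utils::with_f16c(),
    candle_core::utils::get_num_threads(),
);
```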
Oops, forgot to add the command I used. I added the `mkl` feature and it shaved off ~3s from the original 18s.