FlagEmbedding
FlagEmbedding copied to clipboard
bge-m3,vllm和FlagEmbedding,同一query向量不一样
因为需要api,Online Serving,就看下vllm,发现支持XLMRobertaModel,用它加载bge-m3模型
vllm代码
from vllm import LLM
prompts = ['精通excel', '银行项目', '市场营销']
model = LLM(
model="/bge/bge-m3",
task="embed",
enforce_eager=True,
)
outputs = model.embed(prompts)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} | "
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
输出结果
Prompt: '精通excel' | Embeddings: [-0.032440185546875, 0.005889892578125, -0.0306549072265625, -0.001209259033203125, -0.0201568603515625, -0.02447509765625, 0.0341796875, -0.0017910003662109375, 0.005279541015625, -0.0124664306640625, -0.006496429443359375, -0.0012645721435546875, 0.0028133392333984375, 0.01546478271484375, 0.0235748291015625, -0.0225982666015625, ...] (size=1024)
Prompt: '银行项目' | Embeddings: [-0.030975341796875, -0.023101806640625, -0.035552978515625, 9.143352508544922e-05, -0.01654052734375, 0.0008797645568847656, 0.024444580078125, 0.005847930908203125, 0.040069580078125, 0.006481170654296875, 0.0401611328125, 0.0143890380859375, -0.012298583984375, -0.00902557373046875, 0.02740478515625, -0.026580810546875, ...] (size=1024)
Prompt: '市场营销' | Embeddings: [-0.06597900390625, -0.00835418701171875, -0.0174407958984375, -0.0255279541015625, -0.0012760162353515625, 0.05255126953125, -0.026153564453125, 0.007213592529296875, -0.0124969482421875, -0.00920867919921875, -0.029083251953125, -0.0008821487426757812, -0.01201629638671875, -0.00135040283203125, 0.05426025390625, -0.00839996337890625, ...] (size=1024)
用FlagEmbedding的代码
from FlagEmbedding import BGEM3FlagModel
import numpy as np
model_path = "/bge/bge-m3"
model = BGEM3FlagModel(model_path,
use_fp16=True,
devices=['cuda:0'])
queries = ['精通excel', '银行项目', '市场营销']
q_embeddings = model.encode(queries)['dense_vecs']
for prompt, embeds in zip(queries, q_embeddings):
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} | "
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
输出结果
Prompt: '精通excel' | Embeddings: [-0.03238, 0.005817, -0.03069, -0.001201, -0.02011, -0.02441, 0.03424, -0.001688, 0.005264, -0.01248, -0.006535, -0.001193, 0.002863, 0.01545, 0.02351, -0.0227, ...] (size=1024)
Prompt: '银行项目' | Embeddings: [-0.03104, -0.02313, -0.0356, 0.0001503, -0.01656, 0.000881, 0.02432, 0.00588, 0.04, 0.006462, 0.04016, 0.01441, -0.01224, -0.00901, 0.02747, -0.0266, ...] (size=1024)
Prompt: '市场营销' | Embeddings: [-0.0658, -0.008354, -0.0174, -0.02557, -0.001265, 0.0525, -0.02612, 0.00727, -0.01251, -0.00919, -0.02902, -0.0008845, -0.01208, -0.001312, 0.0543, -0.00839, ...] (size=1024)
结果不是完全一样的,请问有人知道是为什么吗?
可以检查一下,是否normalize embedding,以及是否用了[CLS] token的embedding作为最终输出