rust-bert
Use SimCSE model for sentence embeddings
Hi!
I am interested in using the SimCSE model to get sentence embeddings, as its embeddings have been shown to significantly outperform SBERT (which is currently provided by this library) on various tasks.
I have created a rust_model.ot as described in the README and copied over the local configs for princeton-nlp/sup-simcse-bert-base-uncased. However, I ran into the problem that the sentence embeddings pipeline in rust-bert expects several config files that SentenceTransformers models provide but SimCSE does not. I tried copying the missing configs from this model, and the pipeline now works, but it produces different embedding output than the Python version.
I think the output differences are due to different pooling strategies, but I am not sure how to make the Rust version behave like the Python version.
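For reference, my understanding of the two strategies (a sketch only; I am assuming rust-bert's BertModelOutput fields and the tch API, whose exact signatures vary by version): the Python snippet below reads pooler_output, i.e. the [CLS] hidden state passed through BERT's dense + tanh pooler, which rust-bert exposes as pooled_output, whereas the SBERT-style pipeline typically applies mask-weighted mean pooling over the token embeddings (depending on the pooling config that was copied over).

use tch::{Kind, Tensor};

// A minimal sketch of SBERT-style mean pooling, assuming `hidden_state` is the
// [batch, seq_len, hidden] output of a BertModel forward pass and `attention_mask`
// is the [batch, seq_len] mask; tch method signatures may differ between versions.
fn mean_pool(hidden_state: &Tensor, attention_mask: &Tensor) -> Tensor {
    let mask = attention_mask.unsqueeze(-1).to_kind(Kind::Float);
    let summed = (hidden_state * &mask).sum_dim_intlist(&[1i64][..], false, Kind::Float);
    let counts = mask.sum_dim_intlist(&[1i64][..], false, Kind::Float).clamp_min(1e-9);
    summed / counts
}

// SimCSE's `pooler_output`, by contrast, is what rust-bert already returns as
// `pooled_output` (the [CLS] state through a dense layer + tanh), so mirroring the
// Python code means skipping any extra pooling step like the one above.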
Python code I am using:
import torch
from transformers import AutoModel, AutoTokenizer

# Import our models. The package will take care of downloading the models automatically
tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")

# Tokenize input texts
texts = ["Hello, I am a sentence!"]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# Get the embeddings
with torch.no_grad():
    embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output

print(embeddings)
Rust code:
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Loading the model...");
    let model = SentenceEmbeddingsBuilder::local("bert-base-uncased").create_model()?;
    let sents = &["Hello, I am a sentence!"];
    let embeddings = model.encode(sents)?;
    println!("{embeddings:#?}");
    Ok(())
}
Is there a way to adapt the sentence_embeddings pipeline for SimCSE? Both use the BERT model under the hood, so I hope there is a more or less straightforward path for this.
Update: I was able to produce identical output to the Python version in Rust using this code:
use std::path::PathBuf;

use rust_bert::bert::{BertConfig, BertForSentenceEmbeddings};
use rust_bert::resources::{LocalResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Local paths to the converted checkpoint, its config, and the vocab.
    let config_resource = LocalResource {
        local_path: PathBuf::from("bert-base-uncased/config.json"),
    };
    let vocab_resource = LocalResource {
        local_path: PathBuf::from("bert-base-uncased/vocab.txt"),
    };
    let weights_resource = LocalResource {
        local_path: PathBuf::from("bert-base-uncased/rust_model.ot"),
    };
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;

    // Set up the tokenizer and a bare BERT model with the pooler head.
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
    let config = BertConfig::from_file(config_path);
    let model = BertForSentenceEmbeddings::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Tokenize the input and pad every sequence to the longest one in the batch.
    let tokenized_input = tokenizer.encode_list(
        &["Hello, I am a sentence!"],
        128,
        &TruncationStrategy::LongestFirst,
        0,
    );
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap_or(0);
    let pad_token_id = 0;
    let tokens_ids = tokenized_input
        .into_iter()
        .map(|input| {
            let mut token_ids = input.token_ids;
            token_ids.extend(vec![pad_token_id; max_len - token_ids.len()]);
            token_ids
        })
        .collect::<Vec<_>>();

    // Attention masks: 1 for real tokens, 0 for padding.
    let tokens_masks = tokens_ids
        .iter()
        .map(|input| {
            Tensor::of_slice(
                &input
                    .iter()
                    .map(|&e| i64::from(e != pad_token_id))
                    .collect::<Vec<_>>(),
            )
        })
        .collect::<Vec<_>>();
    let tokens_ids = tokens_ids
        .into_iter()
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let tokens_ids = Tensor::stack(&tokens_ids, 0);
    let tokens_masks = Tensor::stack(&tokens_masks, 0);

    // Forward pass; take the pooler output for the first (and only) sentence,
    // which corresponds to `pooler_output` in the Python example.
    let output = no_grad(|| {
        model
            .forward_t(
                Some(&tokens_ids),
                Some(&tokens_masks),
                None,
                None,
                None,
                None,
                None,
                false,
            )
            .unwrap()
    })
    .pooled_output
    .unwrap()
    .get(0);
    let embeddings: Vec<f32> = Vec::from(output);
    dbg!(embeddings);
    Ok(())
}
@guillaume-be would you be interested in having this added as a pipeline?
Hello @anna-hope ,
Thank you for testing SimCSE with this library - and glad you could make it work. It looks like the SentenceEmbeddings pipeline should be compatible with the model you are loading -- have you tested using this with the local model builder as well? If not - I believe only minor changes would be required to be able to generate embeddings from what looks to be a bare BertModel and skip all the additional layers from SentenceEmbeddings encoders.
The model would be a welcome addition to the library - it would be great if you could open a PR on the Huggingface model hub to upload the Rust-based weights as well.
Thank you,
@guillaume-be
It looks like the SentenceEmbeddings pipeline should be compatible with the model you are loading -- have you tested using this with the local model builder as well?
I believe that's what I tried originally, which didn't work because of incompatible configuration formats between SBERT and SimCSE. Unless the "local model builder" would be something different from SentenceEmbeddingsBuilder::local("bert-base-uncased")?
I believe only minor changes would be required to be able to generate embeddings from what looks to be a bare BertModel and skip all the additional layers from SentenceEmbeddings encoders.
The challenge I am facing with integrating the SimCSE models into the SentenceEmbeddings pipeline is that the pipeline builder expects a modules.json file, which SBERT provides but SimCSE does not; that file is in turn used to set up the pooling config, which the SimCSE model also does not use. There are perhaps other points where the two are incompatible as well. At the same time, I agree that the bulk of the pipeline could be reused.
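For concreteness, the files the builder appears to look for follow the usual SentenceTransformers layout; something along these lines (an assumption based on typical SBERT repos, not something I have verified against rust-bert's loader, and the builder may also expect sentence_bert_config.json for the maximum sequence length) would describe a bare transformer plus CLS pooling:

modules.json:
[
  { "idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer" },
  { "idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling" }
]

1_Pooling/config.json:
{
  "word_embedding_dimension": 768,
  "pooling_mode_cls_token": true,
  "pooling_mode_mean_tokens": false,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}

Even with pooling_mode_cls_token set, though, the result is the raw [CLS] hidden state, whereas SimCSE's pooler_output runs that state through the dense + tanh pooler, so the output would still not match the Python example exactly.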
Would it be worthwhile to introduce additional branching logic to have the SimCSE model be handled by the same pipeline, at the cost of complicating the code base?