rust-bert
How to run custom RobertaForSequenceClassification model
Hi guys, I'm trying to load a custom model with RobertaForSequenceClassification, but I don't know how to "predict". I assume I have to use the forward_t method, but I'm not sure how, or whether that is even the right approach. All I want to do is pass in a text and receive a prediction. Below you can see my code.
let config_resource = LocalResource {
    local_path: PathBuf::from("config.json"),
};
let vocab_resource = LocalResource {
    local_path: PathBuf::from("vocab.json"),
};
let merges_resource = LocalResource {
    local_path: PathBuf::from("merges.txt"),
};
let weights_resource = LocalResource {
    local_path: PathBuf::from("rust_model.ot"),
};
let config_path = config_resource.get_local_path().unwrap();
let vocab_path = vocab_resource.get_local_path().unwrap();
let merges_path = merges_resource.get_local_path().unwrap();
let weights_path = weights_resource.get_local_path().unwrap();
let vocab = RobertaVocab::from_file(vocab_path.to_str().unwrap()).unwrap();
let merges = BpePairVocab::from_file(merges_path.to_str().unwrap()).unwrap();
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_existing_vocab_and_merges(
    vocab,
    merges,
    true, // lower_case
    true, // add_prefix_space
);
let input = ["my_text"];
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let config = RobertaConfig::from_file(config_path);
let roberta = RobertaForSequenceClassification::new(&vs.root(), &config);
vs.load(weights_path).unwrap();
// TODO: predict on `input`
Hello @antonioualex ,
For sequence classification, the easiest approach is to use the pipelines, which take care of tokenization, batching and padding for you. Please check the example at https://github.com/guillaume-be/rust-bert/blob/master/examples/sequence_classification.rs, which illustrates how to do this with a set of defaults. You can update the configuration to use a custom model instead (see for example https://github.com/guillaume-be/rust-bert/blob/master/examples/sentiment_analysis_fnet.rs).
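As a rough, untested sketch of what that pipeline configuration could look like with your local files — note that the exact `SequenceClassificationConfig::new` argument list and the `Resource` type have changed between rust-bert releases (newer versions use `ModelResource` and `Box<dyn ResourceProvider>`), so treat the signatures below as assumptions to check against the version you have installed:

```rust
// Untested sketch: sequence classification pipeline with local model files.
// Assumption: a rust-bert release where resources are passed as
// `Resource::Local(...)` and `SequenceClassificationConfig::new` takes the
// arguments in this order.
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sequence_classification::{
    SequenceClassificationConfig, SequenceClassificationModel,
};
use rust_bert::resources::{LocalResource, Resource};
use std::path::PathBuf;

fn main() -> anyhow::Result<()> {
    // Small helper to build a local resource from a path
    let local = |p: &str| {
        Resource::Local(LocalResource {
            local_path: PathBuf::from(p),
        })
    };

    let config = SequenceClassificationConfig::new(
        ModelType::Roberta,
        local("rust_model.ot"),    // model weights
        local("config.json"),      // model configuration (incl. id2label)
        local("vocab.json"),       // vocabulary
        Some(local("merges.txt")), // BPE merges, required for RoBERTa
        false, // lower_case
        None,  // strip_accents
        None,  // add_prefix_space
    );

    // The pipeline handles tokenization, padding and the forward pass
    let model = SequenceClassificationModel::new(config)?;
    let predictions = model.predict(&["my_text"]);
    println!("{:?}", predictions); // one Label (class name + score) per input
    Ok(())
}
```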
Please let me know if this helps.
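If you would rather stay with the manual forward_t route from your snippet, something along these lines should work after the `vs.load(...)` call. This is an untested sketch: the exact forward_t argument list, the shape of its return value (a struct with a `logits` field in recent releases, a tuple in older ones), `Tensor::of_slice` vs `Tensor::from_slice`, and the RoBERTa pad token id of 1 are all assumptions to verify against your rust-bert/tch versions:

```rust
// Untested sketch: manual tokenization + forward pass, continuing from the
// `tokenizer`, `roberta` and `device` variables defined in the snippet above.
use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy};
use tch::{no_grad, Kind, Tensor};

let input = ["my_text"];

// Tokenize the batch, truncating each text to 256 tokens
let tokenized = tokenizer.encode_list(&input, 256, &TruncationStrategy::LongestFirst, 0);

// Pad every sequence to the longest one in the batch
let max_len = tokenized.iter().map(|t| t.token_ids.len()).max().unwrap_or(0);
let input_ids: Vec<Tensor> = tokenized
    .iter()
    .map(|t| {
        let mut ids = t.token_ids.clone();
        ids.resize(max_len, 1); // 1 = RoBERTa <pad> id (assumption)
        Tensor::of_slice(&ids)
    })
    .collect();
let input_tensor = Tensor::stack(&input_ids, 0).to(device);

// Forward pass without gradient tracking; final `false` = not training
let output = no_grad(|| {
    roberta.forward_t(Some(&input_tensor), None, None, None, None, false)
});

// Softmax over the last dimension turns logits into class probabilities
let probabilities = output.logits.softmax(-1, Kind::Float);
probabilities.print();
```

The predicted class for each input is then the argmax over the last dimension; the class names come from the id2label mapping in your config.json.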