rust-bert

Please expose tokenizer params on models where `forward_t` is exposed

Open HarryCaveMan opened this issue 1 year ago • 0 comments

If I want to use the SequenceClassifier pipeline for something like reranking, I can (sort of) do so through the exposed `forward_t` method. The problem is that I first need to encode the inputs with the model's tokenizer. I can get a reference to the tokenizer using `get_tokenizer`, but if I want to pass tokenizer params (i.e. `max_len` and `device`) to `tokenizer.tokenize`, I cannot get them from the `SequenceClassificationModel`: they are private fields, and there are no getter methods for them like there is for the tokenizer itself.

Alternatively, you could add a method that wraps calls to `SequenceClassificationModel.tokenizer.tokenize` and passes these parameters in from the model instance.
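To make the two options concrete, here is a minimal sketch of the shape I have in mind. It does not use rust-bert at all: `Tokenizer`, `Device`, and this `SequenceClassificationModel` are hypothetical stand-ins, the tokenizer is a toy whitespace splitter, and the names `get_max_length`, `get_device`, and the model-level `tokenize` are only suggestions, not existing API.

```rust
// Hypothetical stand-ins for illustration only; these are NOT rust-bert's real types.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

struct Tokenizer;

impl Tokenizer {
    // Toy tokenizer: splits on whitespace and truncates to `max_len` tokens.
    fn tokenize(&self, text: &str, max_len: usize) -> Vec<String> {
        text.split_whitespace()
            .take(max_len)
            .map(str::to_owned)
            .collect()
    }
}

struct SequenceClassificationModel {
    tokenizer: Tokenizer,
    max_length: usize, // private, as in the situation described above
    device: Device,    // private, as in the situation described above
}

impl SequenceClassificationModel {
    fn new(max_length: usize, device: Device) -> Self {
        Self { tokenizer: Tokenizer, max_length, device }
    }

    // Stand-in for the accessor that is already available.
    fn get_tokenizer(&self) -> &Tokenizer {
        &self.tokenizer
    }

    // Option 1 (suggested names): read-only getters for the private params.
    fn get_max_length(&self) -> usize {
        self.max_length
    }

    fn get_device(&self) -> Device {
        self.device
    }

    // Option 2 (suggested name): a wrapper that forwards the model's own params.
    fn tokenize(&self, text: &str) -> Vec<String> {
        self.tokenizer.tokenize(text, self.max_length)
    }
}

fn main() {
    let model = SequenceClassificationModel::new(4, Device::Cpu);

    // Option 1: the caller fetches the params and drives the tokenizer directly.
    let manual = model
        .get_tokenizer()
        .tokenize("query and candidate passage text here", model.get_max_length());

    // Option 2: the model supplies its own params.
    let wrapped = model.tokenize("query and candidate passage text here");

    assert_eq!(manual, wrapped);
    println!("{:?} on {:?}", wrapped, model.get_device());
}
```

Either option would let callers prepare inputs for `forward_t` with the same settings the model was configured with, instead of duplicating them by hand.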

HarryCaveMan, Oct 20 '23 17:10