rust-bert
Does the marian model have a method like huggingface generate?
Using the pipeline is slower than using the Python huggingface transformers library's generate function, even after the model file has been loaded, when running in a CPU environment.
The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.
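For reference, here is a minimal, untested sketch of driving Marian generation directly through that trait. It assumes a recent rust-bert release; the `GenerateConfig` fields and the exact `generate` signature (in particular whether it returns a `Result`) have changed between versions, so check the `generation_utils` docs for the version you use:

```rust
use rust_bert::marian::{
    MarianConfigResources, MarianGenerator, MarianModelResources, MarianSpmResources,
    MarianVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::RemoteResource;
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Point the generation config at the pretrained zh-en Marian resources.
    let generate_config = GenerateConfig {
        model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            MarianModelResources::CHINESE2ENGLISH,
        ))),
        config_resource: Box::new(RemoteResource::from_pretrained(
            MarianConfigResources::CHINESE2ENGLISH,
        )),
        vocab_resource: Box::new(RemoteResource::from_pretrained(
            MarianVocabResources::CHINESE2ENGLISH,
        )),
        merges_resource: Some(Box::new(RemoteResource::from_pretrained(
            MarianSpmResources::CHINESE2ENGLISH,
        ))),
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    let generator = MarianGenerator::new(generate_config)?;
    // generate is provided by the LanguageGenerator trait.
    let output = generator.generate(Some(&["你好，世界"]), None)?;
    for generated in output {
        println!("{}", generated.text);
    }
    Ok(())
}
```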
Thanks for your reply. When the Marian pipeline runs on the GPU (specifying Device::Cuda(3); nvidia-smi shows the GPU is occupied while the Rust program runs), it is still slower than the Python huggingface library running in a Docker environment without a GPU.
Python code (CPU: 0.3 s, averaged over 100 requests)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MarianTokenizer, MarianMTModel

# Load the Marian model and tokenizer once at startup
model_name = "Helsinki-NLP/opus-mt-zh-en"  # Replace with your desired model
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)

app = FastAPI()

class InputData(BaseModel):
    text: str

class OutputData(BaseModel):
    generation_text: str

@app.post("/v1/predict", response_model=OutputData)
async def predict(input_data: InputData):
    # Translate the input text
    input_ids = tokenizer.encode(input_data.text, return_tensors="pt")
    translation_ids = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generation_text = tokenizer.decode(translation_ids[0], skip_special_tokens=True)
    return {"generation_text": generation_text}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
Rust code (CPU: 0.71 s, GPU: 0.38 s, averaged over 100 requests)
use actix_web::{
    error, post, web,
    http::{header::ContentType, StatusCode},
    App, HttpResponse, HttpServer, Responder, Result,
};
use derive_more::{Display, Error};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{BufferResource, LocalResource};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use tch::Device;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[derive(Serialize)]
struct Output {
    generation_text: String,
}

struct ModelFile {
    config_resource: LocalResource,
    weights: Arc<RwLock<Vec<u8>>>,
    vocab_resource: LocalResource,
    merges_resource: LocalResource,
}

impl ModelFile {
    fn new(model_path: String, config_path: String, vocab_path: String, merges_path: String) -> Self {
        let weights = Arc::new(RwLock::new(get_weights(model_path).unwrap()));
        let config_resource = LocalResource { local_path: config_path.into() };
        let vocab_resource = LocalResource { local_path: vocab_path.into() };
        let merges_resource = LocalResource { local_path: merges_path.into() };
        Self {
            weights,
            config_resource,
            vocab_resource,
            merges_resource,
        }
    }

    fn generation(&self, input_context: &str) -> Result<impl Responder, MyError> {
        let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
        let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;
        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            ModelResource::Torch(Box::new(BufferResource {
                data: Arc::clone(&self.weights),
            })),
            self.config_resource.clone(),
            self.vocab_resource.clone(),
            Some(self.merges_resource.clone()),
            source_languages,
            target_languages,
            // Device::Cpu,
            Device::Cuda(3),
        );
        // A new TranslationModel is constructed from the buffered weights on every request.
        let model = TranslationModel::new(translation_config).map_err(|_e| MyError::ModelLoadError)?;
        let output = model
            .translate(&[input_context.to_string()], None, None)
            .map_err(|_e| MyError::TranslateError)?;
        match output.first() {
            Some(first_element) => Ok(web::Json(Output {
                generation_text: first_element.to_string(),
            })),
            None => Err(MyError::TranslateError),
        }
    }
}

#[derive(Debug, Display, Error)]
enum MyError {
    #[display(fmt = "translation model load error")]
    ModelLoadError,
    #[display(fmt = "translate error")]
    TranslateError,
}

impl error::ResponseError for MyError {
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .insert_header(ContentType::html())
            .body(self.to_string())
    }

    fn status_code(&self) -> StatusCode {
        match *self {
            MyError::ModelLoadError => StatusCode::INTERNAL_SERVER_ERROR,
            MyError::TranslateError => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

#[post("/v1/predict")]
async fn predict_post(
    info: web::Json<Input>,
    appdata: web::Data<ModelFile>,
) -> Result<impl Responder, MyError> {
    appdata.generation(&info.text)
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let appdata = web::Data::new(ModelFile::new(
        "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/config.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".to_string(),
    ));
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::clone(&appdata))
            .service(predict_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?
    .run()
    .await
}

fn get_weights(model_path: String) -> anyhow::Result<Vec<u8>> {
    Ok(std::fs::read(model_path)?)
}
Hello @wolf-li ,
Do you compile the code in release mode (cargo build --release) with all optimizations?
Each time you call the generation function, it creates a new model (loaded from your disk and initialized). I think that is the cause of the slowdown.
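A minimal sketch of the usual fix (untested, assuming the TranslationModelBuilder API from recent rust-bert releases): build the TranslationModel once at startup and share it across handlers behind a Mutex, since the model is Send but not Sync; each request then only pays for the forward pass:

```rust
use std::sync::Mutex;

use actix_web::{post, web, App, HttpServer, Responder};
use rust_bert::pipelines::translation::{Language, TranslationModel, TranslationModelBuilder};
use serde::Deserialize;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[post("/v1/predict")]
async fn predict(
    info: web::Json<Input>,
    model: web::Data<Mutex<TranslationModel>>,
) -> actix_web::Result<impl Responder> {
    // Only the forward pass runs per request; the model is reused.
    let output = model
        .lock()
        .unwrap()
        .translate(&[info.text.as_str()], None, None)
        .map_err(actix_web::error::ErrorInternalServerError)?;
    Ok(web::Json(output))
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // Load the model exactly once, before the server starts accepting requests.
    let model = TranslationModelBuilder::new()
        .with_source_languages(vec![Language::Chinese])
        .with_target_languages(vec![Language::English])
        .create_model()
        .expect("model load failed");
    let model = web::Data::new(Mutex::new(model));
    HttpServer::new(move || App::new().app_data(model.clone()).service(predict))
        .bind(("0.0.0.0", 8090))?
        .run()
        .await
}
```

For higher throughput you would typically move the blocking translate call off the async executor (e.g. via web::block or a dedicated worker thread with channels) instead of locking inside the handler.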