[WIP] Proposal: Create Mistral.rs Server Core Lib
Would you consider exposing the underlying mistral.rs server implementation as a lib so that it can be reused in existing axum projects (as opposed to running the mistral.rs server alongside and proxying to it)?
I have two use cases in mind:
- Provide the existing mistral.rs server endpoints as-is in an existing app (say for a self-hosted / internal inference API).
- Use the underlying mistral.rs server components in a custom endpoint without proxying / duplicating the implementation (say for chat where you want to log requests / responses or intercept the request for RAG).
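To make the second use case concrete, here is a rough, dependency-free sketch of a custom handler that adds retrieved context to a request (RAG) and logs the exchange around a shared completion step. `ChatRequest`, `ChatResponse`, and `handle_with_hooks` are hypothetical stand-ins for illustration only; they are not types or functions from mistral.rs, and the real request types in the `openai` module are much richer.

```rust
// Hypothetical, simplified stand-ins (assumptions for this sketch only).
#[derive(Debug, Clone, PartialEq)]
struct ChatRequest {
    messages: Vec<String>,
}

#[derive(Debug, Clone, PartialEq)]
struct ChatResponse {
    text: String,
}

/// Wraps a shared completion step with two app-specific hooks:
/// prepend retrieved context to the request, then log the exchange.
fn handle_with_hooks(
    mut request: ChatRequest,
    retrieve: impl Fn(&str) -> Vec<String>,
    complete: impl Fn(&ChatRequest) -> ChatResponse,
    log: &mut Vec<String>,
) -> ChatResponse {
    // Use the latest message as the retrieval query.
    let query = request.messages.last().cloned().unwrap_or_default();

    // Retrieved context goes first, followed by the original messages.
    let mut messages = retrieve(&query);
    messages.append(&mut request.messages);
    request.messages = messages;

    // The shared completion step runs on the augmented request.
    let response = complete(&request);

    log.push(format!(
        "{} message(s) in, {:?} out",
        request.messages.len(),
        response.text
    ));
    response
}

fn main() {
    let mut log = Vec::new();
    let request = ChatRequest {
        messages: vec!["What is ISQ?".to_string()],
    };
    let response = handle_with_hooks(
        request,
        |query| vec![format!("context for: {query}")],
        |req| ChatResponse {
            text: format!("answered using {} message(s)", req.messages.len()),
        },
        &mut log,
    );
    println!("{}", response.text);
    println!("{log:?}");
}
```

In a real app the `complete` closure would presumably be replaced by the reusable pieces this PR makes public (for example `parse_request`), with the hooks living in the calling app's own axum handler.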
Right now this PR is a rough proof of concept, but I wanted to get feedback on the idea before going too far. It is not fully implemented, and the commits are left unsquashed so you can see the thought process.
Summary of changes
- Create a new lib crate `mistralrs-server-core` that takes the existing `mistralrs-server` and implements it as a lib
  - The idea is that (if approved) then `mistralrs-server` could just provide the default implementation of `mistralrs-server-core` (plus interactive mode, etc.)
- Adjust the mistralrs server bootstrapping to be more modular, so that mistralrs can be bootstrapped separately, then stored and reused in the calling app. Also expose the routes and the utoipa OpenAPI doc so they can be loaded in an existing axum app
  - For bootstrapping / startup, the overhead of the additional function calls didn't seem like an issue (they may get inlined anyway)
- Adjust the chatcompletions functionality to be more modular so that its parts can be used in another app for a custom use case
  - Not sure if the additional function calls would be a concern here, but compared to the inference itself, function call overhead should be minimal
- Adjust the chatcompletions streamer to take an optional `on_done` callback hook and a flag for whether the chunks should be stored
  - Storing the chunks has overhead, so it is not done unless asked for
  - I also considered storing the chunks only when a callback is set, but being explicit seemed better
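The interaction between the `on_done` hook and the store flag can be sketched without any dependencies. This is a simplified model, not the actual `Streamer`: `Chunk` is a plain-string stand-in for `ChatCompletionChunkResponse`, and `ChunkRecorder` and `run` are hypothetical names used only for illustration. The callback alias has the same shape as the `OnDoneCallback` type in the `chat_completion.rs` diff.

```rust
use std::sync::{Arc, Mutex};

// Stand-in for `ChatCompletionChunkResponse` (assumption for this sketch).
type Chunk = String;

/// Same shape as the `OnDoneCallback` alias in the diff, over the stand-in type.
type OnDoneCallback = Box<dyn Fn(&[Chunk]) + Send + Sync>;

/// Hypothetical model of the three fields added to the streamer.
struct ChunkRecorder {
    store_chunks: bool,
    chunks: Vec<Chunk>,
    on_done: Option<OnDoneCallback>,
}

impl ChunkRecorder {
    fn new(store_chunks: bool, on_done: Option<OnDoneCallback>) -> Self {
        Self {
            store_chunks,
            chunks: Vec::new(),
            on_done,
        }
    }

    /// Called for every streamed chunk; clones it only when storing is enabled.
    fn on_chunk(&mut self, chunk: &Chunk) {
        if self.store_chunks {
            self.chunks.push(chunk.clone());
        }
    }

    /// Called once when the stream ends; the hook sees whatever was stored.
    fn finish(&self) {
        if let Some(on_done) = &self.on_done {
            on_done(&self.chunks);
        }
    }
}

/// Streams `chunks` through a recorder and returns what the callback received.
fn run(store_chunks: bool, chunks: &[&str]) -> Vec<Chunk> {
    let seen: Arc<Mutex<Vec<Chunk>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&seen);
    let callback: OnDoneCallback = Box::new(move |all: &[Chunk]| {
        sink.lock().unwrap().extend_from_slice(all);
    });

    let mut recorder = ChunkRecorder::new(store_chunks, Some(callback));
    for chunk in chunks {
        recorder.on_chunk(&chunk.to_string());
    }
    recorder.finish();

    let result = seen.lock().unwrap().clone();
    result
}

fn main() {
    println!("stored:     {:?}", run(true, &["Hel", "lo"]));
    println!("not stored: {:?}", run(false, &["Hel", "lo"]));
}
```

One consequence of keeping the flag explicit: with a callback set but storing disabled, the hook still runs and simply receives an empty slice.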
Diffs (WIP)
main.rs (lib.rs)
diff --git a/mistralrs-server/src/main.rs b/mistralrs-server-core/src/lib.rs
index c5c5418e..e367c419 100644
--- a/mistralrs-server/src/main.rs
+++ b/mistralrs-server-core/src/lib.rs
@@ -9,10 +9,10 @@ use candle_core::Device;
use clap::Parser;
use mistralrs_core::{
get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
- paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
- DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
- MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
- SchedulerConfig, TokenSource,
+ paged_attn_supported, parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel,
+ DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType,
+ Loader, LoaderBuilder, MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected,
+ PagedAttentionConfig, Pipeline, Request, SchedulerConfig, TokenSource,
};
use openai::{
ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, Message, ModelObjects,
@@ -22,11 +22,10 @@ use serde::{Deserialize, Serialize};
use speech_generation::speech_generation;
use std::{num::NonZeroUsize, sync::Arc};
-mod chat_completion;
+pub mod chat_completion;
mod completions;
mod image_generation;
-mod interactive_mode;
-mod openai;
+pub mod openai;
mod speech_generation;
mod util;
@@ -37,152 +36,217 @@ use crate::{
image_generation::image_generation,
};
-use interactive_mode::interactive_mode;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::info;
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;
+mod arg_defaults {
+ pub const TRUNCATE_SEQUENCE: bool = false;
+ pub const MAX_SEQS: usize = 16;
+ pub const NO_KV_CACHE: bool = false;
+ pub const INTERACTIVE_MODE: bool = false;
+ pub const PREFIX_CACHE_N: usize = 16;
+ pub const NO_PAGED_ATTN: bool = false;
+ pub const PAGED_ATTN: bool = false;
+ pub const CPU: bool = false;
+ pub const ENABLE_SEARCH: bool = false;
+ pub const ENABLE_THINKING: bool = false;
+
+ pub fn default_none<T>() -> Option<T> {
+ None
+ }
+
+ pub fn default_token_source() -> crate::TokenSource {
+ crate::TokenSource::CacheToken
+ }
+
+ // Helper function for placeholder model (used in Default impl)
+ pub fn placeholder_model() -> crate::ModelSelected {
+ crate::ModelSelected::Toml {
+ file: String::from("/this/is/just/a/placeholder"),
+ }
+ }
+}
+
// NOTE(EricLBuehler): Accept up to 50mb input
const N_INPUT_SIZE: usize = 50;
const MB_TO_B: usize = 1024 * 1024; // 1024 kb in a mb
-fn parse_token_source(s: &str) -> Result<TokenSource, String> {
- s.parse()
-}
+pub type SharedMistralState = Arc<MistralRs>;
+pub type ExtractedMistralState = State<SharedMistralState>;
+type LoadedPipeline = Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>;
#[derive(Parser)]
#[command(version, about, long_about = None)]
-struct Args {
+pub struct Args {
/// IP to serve on. Defaults to "0.0.0.0"
#[arg(long)]
- serve_ip: Option<String>,
+ pub serve_ip: Option<String>,
/// Integer seed to ensure reproducible random number generation.
#[arg(short, long)]
- seed: Option<u64>,
+ pub seed: Option<u64>,
/// Port to serve on.
#[arg(short, long)]
- port: Option<String>,
+ pub port: Option<String>,
/// Log all responses and requests to this file
#[clap(long, short)]
- log: Option<String>,
+ pub log: Option<String>,
/// If a sequence is larger than the maximum model length, truncate the number
/// of tokens such that the sequence will fit at most the maximum length.
/// If `max_tokens` is not specified in the request, space for 10 tokens will be reserved instead.
#[clap(long, short, action)]
- truncate_sequence: bool,
+ pub truncate_sequence: bool,
/// Model selector
#[clap(subcommand)]
- model: ModelSelected,
+ pub model: ModelSelected,
/// Maximum running sequences at any time. If the `tgt_non_granular_index` flag is set for X-LoRA models, this will be set to 1.
- #[arg(long, default_value_t = 16)]
- max_seqs: usize,
+ #[arg(long, default_value_t = arg_defaults::MAX_SEQS)]
+ pub max_seqs: usize,
/// Use no KV cache.
- #[arg(long, default_value_t = false)]
- no_kv_cache: bool,
+ #[arg(long, default_value_t = arg_defaults::NO_KV_CACHE)]
+ pub no_kv_cache: bool,
/// Chat template file with a JINJA file with `messages`, `add_generation_prompt`, `bos_token`, `eos_token`, and `unk_token` as inputs.
/// Used if the automatic deserialization fails. If this ends with `.json` (ie., it is a file) then that template is loaded.
#[arg(short, long)]
- chat_template: Option<String>,
+ pub chat_template: Option<String>,
/// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
#[arg(short, long)]
- jinja_explicit: Option<String>,
+ pub jinja_explicit: Option<String>,
/// Source of the token for authentication.
/// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, `cache` to use a cached token, or `none` to use no token.
/// Defaults to `cache`.
- #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
- token_source: TokenSource,
+ #[arg(long, default_value_t = arg_defaults::default_token_source(), value_parser = parse_token_source)]
+ pub token_source: TokenSource,
/// Enter interactive mode instead of serving a chat server.
#[clap(long, short, action)]
- interactive_mode: bool,
+ pub interactive_mode: bool,
/// Number of prefix caches to hold on the device. Other caches are evicted to the CPU based on a LRU strategy.
- #[arg(long, default_value_t = 16)]
- prefix_cache_n: usize,
+ #[arg(long, default_value_t = arg_defaults::PREFIX_CACHE_N)]
+ pub prefix_cache_n: usize,
/// NOTE: This can be omitted to use automatic device mapping!
/// Number of device layers to load and run on GPU(s). All others will be on the CPU.
/// If one GPU is used, then this value should be an integer. Otherwise, it follows the following pattern:
/// ORD:NUM;... Where ORD is a unique device ordinal and NUM is the number of layers for that device.
#[arg(short, long, value_parser, value_delimiter = ';')]
- num_device_layers: Option<Vec<String>>,
+ pub num_device_layers: Option<Vec<String>>,
/// In-situ quantization to apply.
#[arg(long = "isq", value_parser = parse_isq_value)]
- in_situ_quant: Option<IsqType>,
+ pub in_situ_quant: Option<IsqType>,
/// GPU memory to allocate for KV cache with PagedAttention in MBs.
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
/// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
#[arg(long = "pa-gpu-mem")]
- paged_attn_gpu_mem: Option<usize>,
+ pub paged_attn_gpu_mem: Option<usize>,
/// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
/// If this is not set and the device is CUDA, it will default to `0.9`.
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
/// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
#[arg(long = "pa-gpu-mem-usage")]
- paged_attn_gpu_mem_usage: Option<f32>,
+ pub paged_attn_gpu_mem_usage: Option<f32>,
/// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
/// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
/// This is the default setting, and it defaults to the `max-seq-len` specified in after the model type.
#[arg(long = "pa-ctxt-len")]
- paged_ctxt_len: Option<usize>,
+ pub paged_ctxt_len: Option<usize>,
/// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
#[arg(long = "pa-blk-size")]
- paged_attn_block_size: Option<usize>,
+ pub paged_attn_block_size: Option<usize>,
/// Disable PagedAttention on CUDA. Because PagedAttention is already disabled on Metal, this is only applicable on CUDA.
- #[arg(long = "no-paged-attn", default_value_t = false)]
- no_paged_attn: bool,
+ #[arg(long = "no-paged-attn", default_value_t = arg_defaults::NO_PAGED_ATTN)]
+ pub no_paged_attn: bool,
/// Enable PagedAttention on Metal. Because PagedAttention is already enabled on CUDA, this is only applicable on Metal.
- #[arg(long = "paged-attn", default_value_t = false)]
- paged_attn: bool,
+ #[arg(long = "paged-attn", default_value_t = arg_defaults::PAGED_ATTN)]
+ pub paged_attn: bool,
/// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance.
#[arg(long = "prompt-batchsize")]
- prompt_chunksize: Option<usize>,
+ pub prompt_chunksize: Option<usize>,
/// Use CPU only
#[arg(long)]
- cpu: bool,
+ pub cpu: bool,
/// Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
#[arg(long = "enable-search")]
- enable_search: bool,
+ pub enable_search: bool,
/// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
#[arg(long = "search-bert-model")]
- search_bert_model: Option<String>,
+ pub search_bert_model: Option<String>,
/// Enable thinking for interactive mode and models that support it.
#[arg(long = "enable-thinking")]
- enable_thinking: bool,
+ pub enable_thinking: bool,
+}
+
+impl Default for Args {
+ fn default() -> Self {
+ Self {
+ serve_ip: arg_defaults::default_none(),
+ seed: arg_defaults::default_none(),
+ port: arg_defaults::default_none(),
+ log: arg_defaults::default_none(),
+ truncate_sequence: arg_defaults::TRUNCATE_SEQUENCE,
+ // Default trait requires all fields to be provided, so provide a placeholder value
+ model: arg_defaults::placeholder_model(),
+ max_seqs: arg_defaults::MAX_SEQS,
+ no_kv_cache: arg_defaults::NO_KV_CACHE,
+ chat_template: arg_defaults::default_none(),
+ jinja_explicit: arg_defaults::default_none(),
+ token_source: arg_defaults::default_token_source(),
+ interactive_mode: arg_defaults::INTERACTIVE_MODE,
+ prefix_cache_n: arg_defaults::PREFIX_CACHE_N,
+ num_device_layers: arg_defaults::default_none(),
+ in_situ_quant: arg_defaults::default_none(),
+ paged_attn_gpu_mem: arg_defaults::default_none(),
+ paged_attn_gpu_mem_usage: arg_defaults::default_none(),
+ paged_ctxt_len: arg_defaults::default_none(),
+ paged_attn_block_size: arg_defaults::default_none(),
+ no_paged_attn: arg_defaults::NO_PAGED_ATTN,
+ paged_attn: arg_defaults::PAGED_ATTN,
+ prompt_chunksize: arg_defaults::default_none(),
+ cpu: arg_defaults::CPU,
+ enable_search: arg_defaults::ENABLE_SEARCH,
+ search_bert_model: arg_defaults::default_none(),
+ enable_thinking: arg_defaults::ENABLE_THINKING,
+ }
+ }
+}
+
+fn parse_token_source(s: &str) -> Result<TokenSource, String> {
+ s.parse()
}
#[utoipa::path(
- get,
- tag = "Mistral.rs",
- path = "/v1/models",
- responses((status = 200, description = "Served model info", body = ModelObjects))
+ get,
+ tag = "Mistral.rs",
+ path = "/v1/models",
+ responses((status = 200, description = "Served model info", body = ModelObjects))
)]
-async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
+async fn models(State(state): ExtractedMistralState) -> Json<ModelObjects> {
Json(ModelObjects {
object: "list",
data: vec![ModelObject {
@@ -195,10 +259,10 @@ async fn models(State(state): State<Arc<MistralRs>>) -> Json<ModelObjects> {
}
#[utoipa::path(
- get,
- tag = "Mistral.rs",
- path = "/health",
- responses((status = 200, description = "Server is healthy"))
+ get,
+ tag = "Mistral.rs",
+ path = "/health",
+ responses((status = 200, description = "Server is healthy"))
)]
async fn health() -> &'static str {
"OK"
@@ -211,14 +275,14 @@ struct ReIsqRequest {
}
#[utoipa::path(
- post,
- tag = "Mistral.rs",
- path = "/re_isq",
- request_body = ReIsqRequest,
- responses((status = 200, description = "Reapply ISQ to a non GGUF or GGML model."))
+ post,
+ tag = "Mistral.rs",
+ path = "/re_isq",
+ request_body = ReIsqRequest,
+ responses((status = 200, description = "Reapply ISQ to a non GGUF or GGML model."))
)]
async fn re_isq(
- State(state): State<Arc<MistralRs>>,
+ State(state): ExtractedMistralState,
Json(request): Json<ReIsqRequest>,
) -> Result<String, String> {
let repr = format!("Re ISQ: {:?}", request.ggml_type);
@@ -228,52 +292,109 @@ async fn re_isq(
Ok(repr)
}
-fn get_router(state: Arc<MistralRs>) -> Router {
+pub fn get_openapi_doc(base_path: Option<&str>) -> utoipa::openapi::OpenApi {
#[derive(OpenApi)]
#[openapi(
- paths(models, health, chatcompletions),
- components(
- schemas(ModelObjects, ModelObject, ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, StopTokens, Message)),
- tags(
- (name = "Mistral.rs", description = "Mistral.rs API")
- ),
- info(
- title = "Mistral.rs",
- license(
- name = "MIT",
- )
- )
- )]
+ paths(models, health, chatcompletions),
+ components(
+ schemas(ModelObjects, ModelObject, ChatCompletionRequest, CompletionRequest, ImageGenerationRequest, StopTokens, Message)),
+ tags(
+ (name = "Mistral.rs", description = "Mistral.rs API")
+ ),
+ info(
+ title = "Mistral.rs",
+ license(
+ name = "MIT",
+ )
+ )
+ )]
struct ApiDoc;
- let doc = { ApiDoc::openapi() };
+ let mut doc = ApiDoc::openapi();
- let allow_origin = AllowOrigin::any();
- let cors_layer = CorsLayer::new()
- .allow_methods([Method::GET, Method::POST])
- .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
- .allow_origin(allow_origin);
+ if let Some(prefix) = base_path {
+ if !prefix.is_empty() {
+ let mut prefixed_paths = utoipa::openapi::Paths::default();
- Router::new()
- .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
- .route("/v1/chat/completions", post(chatcompletions))
- .route("/v1/completions", post(completions))
- .route("/v1/models", get(models))
- .route("/health", get(health))
- .route("/", get(health))
- .route("/re_isq", post(re_isq))
- .route("/v1/images/generations", post(image_generation))
- .route("/v1/audio/speech", post(speech_generation))
- .layer(cors_layer)
- .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
- .with_state(state)
+ let original_paths = std::mem::take(&mut doc.paths.paths);
+
+ for (path, item) in original_paths {
+ let prefixed_path = format!("{}{}", prefix, path);
+ prefixed_paths.paths.insert(prefixed_path, item);
+ }
+
+ prefixed_paths.extensions = doc.paths.extensions.clone();
+
+ doc.paths = prefixed_paths;
+ }
+ }
+
+ doc
}
-#[tokio::main]
-async fn main() -> Result<()> {
- let mut args = Args::parse();
+pub async fn bootstrap_mistralrs_router_from_args(
+ args: Args,
+ include_swagger_routes: bool,
+ base_path: Option<&str>,
+) -> Result<Router> {
initialize_logging();
+ let mistralrs = bootstrap_mistralrs(args).await?;
+
+ // if args.interactive_mode {
+ // interactive_mode(mistralrs, args.throughput_log, args.interactive_search).await;
+ // return Ok(());
+ // }
+
+ // // Needs to be after the .build call as that is where the daemon waits.
+ // let setting_server = if !args.interactive_mode {
+ // let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
+ // let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());
+
+ // // Create listener early to validate address before model loading
+ // let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
+ // Some((listener, ip, port))
+ // } else {
+ // None
+ // };
+
+ let app = get_router(mistralrs, include_swagger_routes, base_path);
+
+ Ok(app)
+}
+
+pub async fn bootstrap_mistralrs_router_from_state(
+ mistralrs: SharedMistralState,
+ include_swagger_routes: bool,
+ base_path: Option<&str>,
+) -> Result<Router> {
+ initialize_logging();
+
+ // if args.interactive_mode {
+ // interactive_mode(mistralrs, args.throughput_log, args.interactive_search).await;
+ // return Ok(());
+ // }
+
+ // // Needs to be after the .build call as that is where the daemon waits.
+ // let setting_server = if !args.interactive_mode {
+ // let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
+ // let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());
+
+ // // Create listener early to validate address before model loading
+ // let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
+ // Some((listener, ip, port))
+ // } else {
+ // None
+ // };
+
+ let app = get_router(mistralrs, include_swagger_routes, base_path);
+
+ Ok(app)
+}
+
+pub async fn bootstrap_mistralrs(mut args: Args) -> Result<SharedMistralState> {
+ args = configure_args(args);
+
let tgt_non_granular_index = get_tgt_non_granular_index(&args.model);
let dtype = get_model_dtype(&args.model)?;
let auto_device_map_params = get_auto_device_map_params(&args.model)?;
@@ -292,6 +413,22 @@ async fn main() -> Result<()> {
let max_seq_len = auto_device_map_params.max_seq_len();
+ let device = init_device(args.cpu, args.seed)?;
+ let mapper = init_mapper(&args.num_device_layers, &auto_device_map_params);
+ let no_paged_attn = configure_no_paged_attn(&device, args.no_paged_attn, args.paged_attn);
+
+ // Allocate 0.5 GB of CPU memory just as a placeholder.
+ // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
+ let cache_config = init_cache_config(
+ args.paged_attn_block_size,
+ args.paged_attn_gpu_mem,
+ args.paged_attn_gpu_mem_usage,
+ args.paged_ctxt_len,
+ no_paged_attn,
+ max_seq_len,
+ )?;
+
+ // Configure this last to prevent arg moves
let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
.with_no_kv_cache(args.no_kv_cache)
.with_chat_template(args.chat_template)
@@ -299,11 +436,61 @@ async fn main() -> Result<()> {
.with_jinja_explicit(args.jinja_explicit)
.build()?;
+ print_mistral_server_info(&loader);
+
+ let pipeline: LoadedPipeline = loader.load_model_from_hf(
+ None,
+ args.token_source,
+ &dtype,
+ &device,
+ false,
+ mapper,
+ args.in_situ_quant,
+ cache_config,
+ )?;
+ info!("Model loaded.");
+
+ let scheduler_config = init_scheduler_config(&cache_config, &pipeline, args.max_seqs).await;
+
+ let bert_model = if args.enable_search {
+ Some(
+ args.search_bert_model
+ .map(BertEmbeddingModel::Custom)
+ .unwrap_or_default(),
+ )
+ } else {
+ None
+ };
+
+ // Throughput logging in the server
+ Ok(build_mistralrs(
+ pipeline,
+ scheduler_config,
+ args.interactive_mode,
+ bert_model.clone(),
+ args.log,
+ args.truncate_sequence,
+ args.no_kv_cache,
+ args.prefix_cache_n,
+ ))
+}
+
+// This was originally with the device config
+fn configure_args(mut args: Args) -> Args {
+ #[cfg(not(feature = "metal"))]
+ if args.cpu {
+ args.no_paged_attn = true;
+ }
+
+ args
+}
+
+fn init_device(force_cpu: bool, seed: Option<u64>) -> Result<candle_core::Device> {
#[cfg(feature = "metal")]
let device = Device::new_metal(0)?;
#[cfg(not(feature = "metal"))]
- let device = if args.cpu {
- args.no_paged_attn = true;
+ #[allow(clippy::if_same_then_else)]
+ let device = if force_cpu {
Device::Cpu
} else if mistralrs_core::distributed::use_nccl() {
Device::Cpu
@@ -311,22 +498,19 @@ async fn main() -> Result<()> {
Device::cuda_if_available(0)?
};
- if let Some(seed) = args.seed {
+ if let Some(seed) = seed {
device.set_seed(seed)?;
}
- info!(
- "avx: {}, neon: {}, simd128: {}, f16c: {}",
- candle_core::utils::with_avx(),
- candle_core::utils::with_neon(),
- candle_core::utils::with_simd128(),
- candle_core::utils::with_f16c()
- );
- info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
- info!("Model kind is: {}", loader.get_kind().to_string());
+ Ok(device)
+}
+fn init_mapper(
+ num_device_layers: &Option<Vec<String>>,
+ auto_device_map_params: &AutoDeviceMapParams,
+) -> DeviceMapSetting {
// Parse device mapper
- let mapper = if let Some(device_layers) = args.num_device_layers {
+ if let Some(device_layers) = num_device_layers {
if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
let layers = device_layers[0].parse::<usize>().unwrap();
DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
@@ -358,152 +542,176 @@ async fn main() -> Result<()> {
DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
}
} else {
- DeviceMapSetting::Auto(auto_device_map_params)
- };
+ DeviceMapSetting::Auto(auto_device_map_params.clone())
+ }
+}
+
+#[allow(clippy::borrowed_box)]
+fn print_mistral_server_info(loader: &Box<dyn Loader>) {
+ info!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle_core::utils::with_avx(),
+ candle_core::utils::with_neon(),
+ candle_core::utils::with_simd128(),
+ candle_core::utils::with_f16c()
+ );
+
+ info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
+ info!("Model kind is: {}", loader.get_kind().to_string());
+}
- let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
- args.no_paged_attn
+fn configure_no_paged_attn(device: &Device, no_paged_attn: bool, paged_attn: bool) -> bool {
+ if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
+ no_paged_attn
} else if device.is_metal() {
- !args.paged_attn
+ !paged_attn
} else {
true
- };
+ }
+}
- // Allocate 0.5 GB of CPU memory just as a placeholder.
- // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
- let cache_config = match (
- args.paged_attn_block_size,
- args.paged_attn_gpu_mem,
- args.paged_attn_gpu_mem_usage,
- args.paged_ctxt_len,
+fn init_cache_config(
+ paged_attn_block_size: Option<usize>,
+ paged_attn_gpu_mem: Option<usize>,
+ paged_attn_gpu_mem_usage: Option<f32>,
+ paged_ctxt_len: Option<usize>,
+ no_paged_attn: bool,
+ max_seq_len: usize,
+) -> Result<Option<PagedAttentionConfig>> {
+ match (
+ paged_attn_block_size,
+ paged_attn_gpu_mem,
+ paged_attn_gpu_mem_usage,
+ paged_ctxt_len,
paged_attn_supported(),
no_paged_attn,
) {
- (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
+ (block_size, None, None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::ContextSize(max_seq_len),
- )?),
- (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
+ )?)),
+ (block_size, None, None, Some(ctxt), true, false) => Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::ContextSize(ctxt),
- )?),
- (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
+ )?)),
+ (block_size, None, Some(f), None, true, false) => Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::Utilization(f),
- )?),
- (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
+ )?)),
+ (block_size, Some(m), None, None, true, false) => Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::MbAmount(m),
- )?),
+ )?)),
(block_size, Some(_m), Some(f), None, true, false) => {
info!("Both memory size, and usage were specified, defaulting to the usage value.");
- Some(PagedAttentionConfig::new(
+ Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::Utilization(f),
- )?)
+ )?))
}
(block_size, Some(_m), None, Some(ctxt), true, false) => {
info!("All memory size and ctxt len, defaulting to the context len value.");
- Some(PagedAttentionConfig::new(
+ Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::ContextSize(ctxt),
- )?)
+ )?))
}
(block_size, None, Some(f), Some(_ctxt), true, false) => {
info!("Both ctxt len and usage were specified, defaulting to the usage value.");
- Some(PagedAttentionConfig::new(
+ Ok(Some(PagedAttentionConfig::new(
block_size,
512,
MemoryGpuConfig::Utilization(f),
- )?)
+ )?))
}
- (_, _, _, _, _, _) => None,
- };
-
- let pipeline = loader.load_model_from_hf(
- None,
- args.token_source,
- &dtype,
- &device,
- false,
- mapper,
- args.in_situ_quant,
- cache_config,
- )?;
- info!("Model loaded.");
+ (_, _, _, _, _, _) => Ok(None),
+ }
+}
- let scheduler_config = if cache_config.is_some() {
+async fn init_scheduler_config(
+ cache_config: &Option<PagedAttentionConfig>,
+ pipeline: &LoadedPipeline,
+ args_max_seqs: usize,
+) -> SchedulerConfig {
+ if cache_config.is_some() {
// Handle case where we may have device mapping
if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
SchedulerConfig::PagedAttentionMeta {
- max_num_seqs: args.max_seqs,
+ max_num_seqs: args_max_seqs,
config: cache_config.clone(),
}
} else {
SchedulerConfig::DefaultScheduler {
- method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
+ method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
}
}
} else {
SchedulerConfig::DefaultScheduler {
- method: DefaultSchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
+ method: DefaultSchedulerMethod::Fixed(args_max_seqs.try_into().unwrap()),
}
- };
- let bert_model = if args.enable_search {
- Some(
- args.search_bert_model
- .map(BertEmbeddingModel::Custom)
- .unwrap_or_default(),
- )
- } else {
- None
- };
- // Throughput logging in the server
- let mistralrs = MistralRsBuilder::new(
- pipeline,
- scheduler_config,
- !args.interactive_mode,
- bert_model.clone(),
- )
- .with_opt_log(args.log)
- .with_truncate_sequence(args.truncate_sequence)
- .with_no_kv_cache(args.no_kv_cache)
- .with_prefix_cache_n(args.prefix_cache_n)
- .build();
-
- if args.interactive_mode {
- interactive_mode(
- mistralrs,
- bert_model.is_some(),
- args.enable_thinking.then_some(true),
- )
- .await;
- return Ok(());
}
+}
- // Needs to be after the .build call as that is where the daemon waits.
- let setting_server = if !args.interactive_mode {
- let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
- let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string());
+#[allow(clippy::too_many_arguments)]
+fn build_mistralrs(
+ pipeline: LoadedPipeline,
+ scheduler_config: SchedulerConfig,
+ interactive_mode: bool,
+ bert_model: Option<BertEmbeddingModel>,
+ log: Option<String>,
+ truncate_sequence: bool,
+ no_kv_cache: bool,
+ prefix_cache_n: usize,
+) -> SharedMistralState {
+ MistralRsBuilder::new(pipeline, scheduler_config, !interactive_mode, bert_model)
+ .with_opt_log(log)
+ .with_truncate_sequence(truncate_sequence)
+ .with_no_kv_cache(no_kv_cache)
+ .with_prefix_cache_n(prefix_cache_n)
+ .build()
+}
- // Create listener early to validate address before model loading
- let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?;
- Some((listener, ip, port))
- } else {
- None
- };
+fn get_router(
+ state: SharedMistralState,
+ include_swagger_routes: bool,
+ base_path: Option<&str>,
+) -> Router {
+ let allow_origin = AllowOrigin::any();
+ let cors_layer = CorsLayer::new()
+ .allow_methods([Method::GET, Method::POST])
+ .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
+ .allow_origin(allow_origin);
- let app = get_router(mistralrs);
- if let Some((listener, ip, port)) = setting_server {
- info!("Serving on http://{ip}:{}.", port);
- axum::serve(listener, app).await?;
- };
+ // Use the provided base path or default to ""
+ let prefix = base_path.unwrap_or("");
+
+ let mut router = Router::new()
+ .route("/v1/chat/completions", post(chatcompletions))
+ .route("/v1/completions", post(completions))
+ .route("/v1/models", get(models))
+ .route("/health", get(health))
+ .route("/", get(health))
+ .route("/re_isq", post(re_isq))
+ .route("/v1/images/generations", post(image_generation))
+ .route("/v1/audio/speech", post(speech_generation))
+ .layer(cors_layer)
+ .layer(DefaultBodyLimit::max(N_INPUT_SIZE * MB_TO_B))
+ .with_state(state);
+
+ if include_swagger_routes {
+ let doc = get_openapi_doc(None);
+
+ router = router.merge(
+ SwaggerUi::new(format!("{prefix}/docs"))
+ .url(format!("{prefix}/api-doc/openapi.json"), doc),
+ );
+ }
- Ok(())
+ router
}
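The `base_path` handling in `get_openapi_doc` amounts to rewriting the keys of the path map. The sketch below models that step with a plain `BTreeMap` instead of `utoipa::openapi::Paths`; `Paths`, `prefix_paths`, and `sample_paths` are hypothetical names for illustration, and the behavior shown is only what the diff's `format!("{}{}", prefix, path)` implies.

```rust
use std::collections::BTreeMap;

/// Stand-in for the OpenAPI path map: route -> description (assumption).
type Paths = BTreeMap<String, String>;

/// Models the prefixing step: `None` or an empty prefix leaves the map
/// untouched; otherwise every key becomes `prefix + path`.
fn prefix_paths(paths: Paths, base_path: Option<&str>) -> Paths {
    match base_path {
        Some(prefix) if !prefix.is_empty() => paths
            .into_iter()
            .map(|(path, item)| (format!("{prefix}{path}"), item))
            .collect(),
        _ => paths,
    }
}

/// Two of the routes registered by the server, used as sample data.
fn sample_paths() -> Paths {
    [
        ("/v1/models", "Served model info"),
        ("/health", "Server is healthy"),
    ]
    .into_iter()
    .map(|(path, desc)| (path.to_string(), desc.to_string()))
    .collect()
}

fn main() {
    for (path, desc) in prefix_paths(sample_paths(), Some("/api")) {
        println!("{path}: {desc}");
    }
}
```

Because the prefix is concatenated as-is, a value with a trailing slash such as `/api/` would yield keys like `/api//health`, so callers would likely want to pass the prefix without a trailing slash.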
chat_completion.rs
diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server-core/src/chat_completion.rs
index 7ae0d76a..07b0655f 100644
--- a/mistralrs-server/src/chat_completion.rs
+++ b/mistralrs-server-core/src/chat_completion.rs
@@ -1,5 +1,5 @@
use serde_json::Value;
-use std::{env, error::Error, ops::Deref, pin::Pin, sync::Arc, task::Poll, time::Duration};
+use std::{env, error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use crate::{
@@ -7,7 +7,7 @@ use crate::{
ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
ResponseFormat, StopTokens,
},
- util,
+ util, ExtractedMistralState, SharedMistralState,
};
use anyhow::Context;
use anyhow::Result;
@@ -23,11 +23,15 @@ use either::Either;
use indexmap::IndexMap;
use itertools::Itertools;
use mistralrs_core::{
- ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
- RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
+ ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs,
+ NormalRequest, Request, RequestMessage, Response, SamplingParams,
+ StopTokens as InternalStopTokens,
};
use serde::Serialize;
+/// A hook that runs when the stream finishes, receiving all of the chunks.
+pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
+
#[derive(Debug)]
struct ModelErrorMessage(String);
impl std::fmt::Display for ModelErrorMessage {
@@ -46,7 +50,10 @@ enum DoneState {
pub struct Streamer {
rx: Receiver<Response>,
done_state: DoneState,
- state: Arc<MistralRs>,
+ state: SharedMistralState,
+ store_chunks: bool,
+ chunks: Vec<ChatCompletionChunkResponse>,
+ on_done: Option<OnDoneCallback>,
}
impl futures::Stream for Streamer {
@@ -64,6 +71,9 @@ impl futures::Stream for Streamer {
return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
}
DoneState::Done => {
+ if let Some(on_done) = &self.on_done {
+ on_done(&self.chunks);
+ }
return Poll::Ready(None);
}
DoneState::Running => (),
@@ -93,6 +103,11 @@ impl futures::Stream for Streamer {
}
// Done now, just need to send the [DONE]
MistralRs::maybe_log_response(self.state.clone(), &response);
+
+ if self.store_chunks {
+ self.chunks.push(response.clone());
+ }
+
Poll::Ready(Some(Event::default().json_data(response)))
}
Response::Done(_) => unreachable!(),
@@ -172,9 +187,9 @@ impl IntoResponse for ChatCompletionResponder {
}
}
-async fn parse_request(
+pub async fn parse_request(
oairequest: ChatCompletionRequest,
- state: Arc<MistralRs>,
+ state: SharedMistralState,
tx: Sender<Response>,
) -> Result<(Request, bool)> {
let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
@@ -418,6 +433,9 @@ async fn parse_request(
))
}
+pub const CHANNEL_BUFFER_SIZE: usize = 10_000;
+pub const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 10_000;
+
#[utoipa::path(
post,
tag = "Mistral.rs",
@@ -426,72 +444,108 @@ async fn parse_request(
responses((status = 200, description = "Chat completions"))
)]
pub async fn chatcompletions(
- State(state): State<Arc<MistralRs>>,
+ State(state): ExtractedMistralState,
Json(oairequest): Json<ChatCompletionRequest>,
) -> ChatCompletionResponder {
- let (tx, mut rx) = channel(10_000);
+ let (tx, mut rx) = create_response_channel();
+
let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
Ok(x) => x,
- Err(e) => {
- let e = anyhow::Error::msg(e.to_string());
- MistralRs::maybe_log_error(state, &*e);
- return ChatCompletionResponder::InternalError(e.into());
- }
+ Err(e) => return handle_error(state, e.into()),
};
- let sender = state.get_sender().unwrap();
- if let Err(e) = sender.send(request).await {
- let e = anyhow::Error::msg(e.to_string());
- MistralRs::maybe_log_error(state, &*e);
- return ChatCompletionResponder::InternalError(e.into());
+ if let Err(e) = send_request(&state, request).await {
+ return handle_error(state, e.into());
}
if is_streaming {
- let streamer = Streamer {
- rx,
- done_state: DoneState::Running,
- state,
- };
-
- let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
- .map(|val| val.parse::<u64>().unwrap_or(10000))
- .unwrap_or(10000);
- ChatCompletionResponder::Sse(
- Sse::new(streamer)
- .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
- )
+ ChatCompletionResponder::Sse(create_chat_streamer(rx, state, None))
} else {
- let response = match rx.recv().await {
- Some(response) => response,
- None => {
- let e = anyhow::Error::msg("No response received from the model.");
- MistralRs::maybe_log_error(state, &*e);
- return ChatCompletionResponder::InternalError(e.into());
- }
- };
+ process_non_streaming_chat_response(&mut rx, state).await
+ }
+}
- match response {
- Response::InternalError(e) => {
- MistralRs::maybe_log_error(state, &*e);
- ChatCompletionResponder::InternalError(e)
- }
- Response::ModelError(msg, response) => {
- MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
- MistralRs::maybe_log_response(state, &response);
- ChatCompletionResponder::ModelError(msg, response)
- }
- Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
- Response::Done(response) => {
- MistralRs::maybe_log_response(state, &response);
- ChatCompletionResponder::Json(response)
- }
- Response::Chunk(_) => unreachable!(),
- Response::CompletionDone(_) => unreachable!(),
- Response::CompletionModelError(_, _) => unreachable!(),
- Response::CompletionChunk(_) => unreachable!(),
- Response::ImageGeneration(_) => unreachable!(),
- Response::Speech { .. } => unreachable!(),
- Response::Raw { .. } => unreachable!(),
+pub fn handle_error(
+ state: SharedMistralState,
+ e: Box<dyn std::error::Error + Send + Sync + 'static>,
+) -> ChatCompletionResponder {
+ let e = anyhow::Error::msg(e.to_string());
+ MistralRs::maybe_log_error(state, &*e);
+ ChatCompletionResponder::InternalError(e.into())
+}
+
+pub fn create_response_channel() -> (Sender<Response>, Receiver<Response>) {
+ channel(CHANNEL_BUFFER_SIZE)
+}
+
+pub fn get_keep_alive_interval() -> u64 {
+ env::var("KEEP_ALIVE_INTERVAL")
+ .map(|val| val.parse::<u64>().unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL))
+ .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL)
+}
+
+pub async fn send_request(state: &SharedMistralState, request: Request) -> Result<()> {
+ let sender = state.get_sender().unwrap();
+ sender.send(request).await.map_err(|e| e.into())
+}
+
+pub fn create_chat_streamer(
+ rx: Receiver<Response>,
+ state: SharedMistralState,
+ on_done: Option<OnDoneCallback>,
+) -> Sse<Streamer> {
+ let streamer = Streamer {
+ rx,
+ done_state: DoneState::Running,
+ store_chunks: true,
+ state,
+ chunks: Vec::new(),
+ on_done,
+ };
+
+ let keep_alive_interval = get_keep_alive_interval();
+
+ Sse::new(streamer)
+ .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
+}
+
+pub async fn process_non_streaming_chat_response(
+ rx: &mut Receiver<Response>,
+ state: SharedMistralState,
+) -> ChatCompletionResponder {
+ let response = match rx.recv().await {
+ Some(response) => response,
+ None => {
+ let e = anyhow::Error::msg("No response received from the model.");
+ return handle_error(state, e.into());
+ }
+ };
+
+ match_responses(state, response)
+}
+
+pub fn match_responses(state: SharedMistralState, response: Response) -> ChatCompletionResponder {
+ match response {
+ Response::InternalError(e) => {
+ MistralRs::maybe_log_error(state, &*e);
+ ChatCompletionResponder::InternalError(e)
+ }
+ Response::ModelError(msg, response) => {
+ MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
+ MistralRs::maybe_log_response(state, &response);
+ ChatCompletionResponder::ModelError(msg, response)
+ }
+ Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
+ Response::Done(response) => {
+ MistralRs::maybe_log_response(state, &response);
+ ChatCompletionResponder::Json(response)
}
+ Response::Chunk(_) => unreachable!(),
+ Response::CompletionDone(_) => unreachable!(),
+ Response::CompletionModelError(_, _) => unreachable!(),
+ Response::CompletionChunk(_) => unreachable!(),
+ Response::ImageGeneration(_) => unreachable!(),
+ Response::Speech { .. } => unreachable!(),
+ Response::Raw { .. } => unreachable!(),
}
}
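The `get_keep_alive_interval` helper in the diff above falls back to `DEFAULT_KEEP_ALIVE_INTERVAL` both when `KEEP_ALIVE_INTERVAL` is unset and when its value does not parse as a `u64`. As a rough, std-only sketch of that fallback (not the PR's code), the parsing can be split into a hypothetical `parse_keep_alive` function so it can be exercised without touching the process environment; `parse_keep_alive` and `keep_alive_interval` are names invented here for illustration.

```rust
use std::env;
use std::time::Duration;

/// Same value as `DEFAULT_KEEP_ALIVE_INTERVAL` in the diff (milliseconds).
const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 10_000;

/// Hypothetical helper: turn an optional raw value into an interval in ms,
/// falling back to the default when the value is missing or not a valid u64.
fn parse_keep_alive(raw: Option<&str>) -> u64 {
    raw.and_then(|val| val.parse::<u64>().ok())
        .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL)
}

/// Reads `KEEP_ALIVE_INTERVAL` from the environment and wraps it in a Duration.
fn keep_alive_interval() -> Duration {
    let raw = env::var("KEEP_ALIVE_INTERVAL").ok();
    Duration::from_millis(parse_keep_alive(raw.as_deref()))
}

fn main() {
    println!("keep-alive interval: {:?}", keep_alive_interval());
}
```

Collapsing the two `unwrap_or` calls into one `and_then(..).unwrap_or(..)` chain keeps the default in a single place; the observable behavior is intended to match the helper in the diff.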
Example
I have an example repo here to demonstrate these use cases: https://github.com/matthewhaynesonline/mistral-server-core-test
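To illustrate the logging use case without pulling in the server crates, here is a minimal, synchronous sketch of the `store_chunks` / `on_done` bookkeeping that the `Streamer` in the diff performs. It assumes a plain `String` as a stand-in for `ChatCompletionChunkResponse`, and `ChunkRecorder` and `run_demo` are hypothetical names that do not exist in the PR.

```rust
use std::sync::{Arc, Mutex};

/// Stand-in for `ChatCompletionChunkResponse` (the real type is in mistralrs_core).
type Chunk = String;

/// Same shape as the PR's `OnDoneCallback`, but over the stand-in chunk type.
type OnDone = Box<dyn Fn(&[Chunk]) + Send + Sync>;

/// Simplified, synchronous model of the Streamer's chunk bookkeeping.
struct ChunkRecorder {
    store_chunks: bool,
    chunks: Vec<Chunk>,
    on_done: Option<OnDone>,
}

impl ChunkRecorder {
    fn new(store_chunks: bool, on_done: Option<OnDone>) -> Self {
        Self { store_chunks, chunks: Vec::new(), on_done }
    }

    /// Called once per streamed chunk; only clones when storage was requested.
    fn record(&mut self, chunk: &Chunk) {
        if self.store_chunks {
            self.chunks.push(chunk.clone());
        }
    }

    /// Called when the stream ends; hands the stored chunks to the hook.
    fn finish(&self) {
        if let Some(on_done) = &self.on_done {
            on_done(&self.chunks);
        }
    }
}

/// Runs a fake three-chunk stream and returns what the hook logged.
fn run_demo(store_chunks: bool) -> Vec<String> {
    let log = Arc::new(Mutex::new(Vec::new()));
    let log_for_hook = Arc::clone(&log);
    let hook: OnDone = Box::new(move |chunks: &[Chunk]| {
        // For example: persist the assembled response for request/response logging.
        log_for_hook.lock().unwrap().push(chunks.concat());
    });

    let mut recorder = ChunkRecorder::new(store_chunks, Some(hook));
    for piece in ["Hel", "lo", "!"] {
        recorder.record(&piece.to_string());
    }
    recorder.finish();

    let logged = log.lock().unwrap().clone();
    logged
}

fn main() {
    println!("with storage:    {:?}", run_demo(true));
    println!("without storage: {:?}", run_demo(false));
}
```

With storage disabled the hook still runs, but it receives an empty slice, so a caller that wants the full response in `on_done` needs storage enabled. Note that `create_chat_streamer` in the diff currently sets `store_chunks: true` unconditionally.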
Questions
- What do you think of the general concept?
- Assuming you're okay with the general concept, do you have concerns with the implementation?
- Would it make sense to do the same treatment on the other endpoints / features?
Walkthrough
This update introduces a new mistralrs-server-core crate, providing core server logic, builder patterns, and OpenAI-compatible API types for the mistral.rs server. The main server binary and its dependencies are refactored to use these abstractions, with routing, configuration, and OpenAPI documentation now modularized. Dependency management and feature flags are aligned across the workspace.
Changes
| File(s) / Path(s) | Change Summary |
|---|---|
| Cargo.toml, mistralrs-server/Cargo.toml | Updated workspace members and dependencies; added mistralrs-server-core, reorganized and aligned dependencies and feature flags for workspace-centric management. |
| mistralrs-server-core/Cargo.toml, README.md | Introduced new crate manifest and README for mistralrs-server-core, specifying metadata, dependencies, features, and documentation placeholder. |
| mistralrs-server-core/src/lib.rs | Added main library file exposing public modules for chat completions, builders, OpenAI compatibility, OpenAPI docs, types, and utilities. |
| mistralrs-server-core/src/types.rs | Added type aliases for shared state, extracted state, and loaded pipeline for server context. |
| mistralrs-server-core/src/util.rs | Added utility function for parsing and loading images from URLs, file paths, or data URLs. |
| mistralrs-server-core/src/openai.rs | Introduced and expanded OpenAI-compatible API types, schema helpers, response formats, and detailed documentation for API interoperability. |
| mistralrs-server-core/src/openapi_doc.rs | Added OpenAPI documentation generator function, supporting optional base path prefixing for API integration. |
| mistralrs-server-core/src/handlers.rs | Added HTTP route handlers for models listing, health check, and ISQ operations, with OpenAPI annotations. |
| mistralrs-server-core/src/chat_completion.rs | Refactored and modularized chat completion logic, adding streaming support, extensible callbacks, improved error handling, and public helper functions. |
| mistralrs-server-core/src/completions.rs, image_generation.rs | Updated handler parameter types to use new extracted shared state alias; no logic changes. |
| mistralrs-server-core/src/mistralrs_for_server_builder.rs | Added comprehensive builder for configuring and constructing a mistral.rs instance with device, model, caching, quantization, and search options. |
| mistralrs-server-core/src/mistralrs_server_router_builder.rs | Added builder for constructing Axum HTTP router with configurable routes, CORS, body size limits, and optional Swagger/OpenAPI docs. |
| mistralrs-server/src/main.rs | Refactored main function to use new builders for model/server setup and router construction, removing manual setup code and delegating to core abstractions. |
| mistralrs-server/src/interactive_mode.rs | Updated import to use utility from mistralrs_server_core. |
Sequence Diagram(s)
sequenceDiagram
participant Main as main.rs
participant Builder as MistralRsForServerBuilder
participant RouterB as MistralRsServerRouterBuilder
participant Axum as Axum Server
Main->>Builder: Configure and build mistral.rs instance (.build().await)
Builder-->>Main: Returns SharedMistralState
Main->>RouterB: Configure and build router with SharedMistralState (.build().await)
RouterB-->>Main: Returns Router
Main->>Axum: Serve router (axum::serve)
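The two-step flow in the diagram above (build the shared state once, then hand it to a router builder) can be sketched with std-only stand-ins. Everything here is an assumption for illustration: `Engine`, `EngineBuilder`, `RouterBuilder`, their method names, and the `example/model` id are hypothetical and are not the API of `mistralrs-server-core`; only the `/docs` and `/api-doc/openapi.json` paths and the `include_swagger_routes` flag are taken from the diff.

```rust
use std::sync::Arc;

/// Stand-in for the loaded model (the role `MistralRs` plays in the PR).
#[derive(Debug)]
struct Engine {
    model_id: String,
}

/// Stand-in for the PR's `SharedMistralState` alias.
type SharedEngine = Arc<Engine>;

/// Hypothetical engine builder; the method names are illustrative only.
#[derive(Default)]
struct EngineBuilder {
    model_id: Option<String>,
}

impl EngineBuilder {
    fn with_model_id(mut self, id: &str) -> Self {
        self.model_id = Some(id.to_string());
        self
    }

    /// Fails early if required configuration is missing.
    fn build(self) -> Result<SharedEngine, String> {
        let model_id = self.model_id.ok_or("model_id is required")?;
        Ok(Arc::new(Engine { model_id }))
    }
}

/// Stand-in for an axum `Router`: mounted paths plus the shared state.
struct Router {
    routes: Vec<String>,
    state: SharedEngine,
}

/// Hypothetical router builder with a base path and optional docs routes.
struct RouterBuilder {
    state: SharedEngine,
    prefix: String,
    include_swagger_routes: bool,
}

impl RouterBuilder {
    fn new(state: SharedEngine) -> Self {
        Self { state, prefix: String::new(), include_swagger_routes: false }
    }

    fn with_base_path(mut self, prefix: &str) -> Self {
        self.prefix = prefix.to_string();
        self
    }

    fn with_swagger_routes(mut self, include: bool) -> Self {
        self.include_swagger_routes = include;
        self
    }

    fn build(self) -> Router {
        let RouterBuilder { state, prefix, include_swagger_routes } = self;
        let mut routes = vec![format!("{prefix}/v1/chat/completions")];
        if include_swagger_routes {
            routes.push(format!("{prefix}/docs"));
            routes.push(format!("{prefix}/api-doc/openapi.json"));
        }
        Router { routes, state }
    }
}

fn main() -> Result<(), String> {
    // Build once and keep a handle for the app's own custom endpoints.
    let engine = EngineBuilder::default().with_model_id("example/model").build()?;
    let router = RouterBuilder::new(Arc::clone(&engine))
        .with_base_path("/api/mistral")
        .with_swagger_routes(true)
        .build();
    println!("{} serves {:?}", router.state.model_id, router.routes);
    println!("handles to the engine: {}", Arc::strong_count(&engine));
    Ok(())
}
```

Returning the shared state from the first builder, rather than hiding it inside the router, is what lets a host application reuse the same loaded model in its own handlers.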
sequenceDiagram
participant Client
participant Axum as Axum Router
participant Chat as chatcompletions handler
participant Streamer as Streamer
participant Model as MistralRs
Client->>Axum: POST /v1/chat/completions
Axum->>Chat: Extract request and state
Chat->>Model: Send request via channel
Model-->>Chat: Responds via channel (stream or single)
Chat->>Streamer: Create Streamer for SSE (if streaming)
Streamer-->>Client: Stream SSE events (chunks, [DONE])
Chat-->>Client: Return JSON or error (if not streaming)
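The request/response flow in the diagram above relies on each request carrying its own response channel. A minimal sketch of that pattern follows, assuming `std::sync::mpsc` and a thread in place of the tokio channels and async engine the PR uses; `Request`, `Response`, `spawn_engine`, and `handle` are simplified stand-ins written for this example, and the fake engine just echoes the prompt.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

/// Simplified stand-in for mistralrs_core's `Response`.
#[derive(Debug, PartialEq)]
enum Response {
    Chunk(String),
    Done(String),
}

/// Simplified stand-in for `Request`: the prompt plus where to send replies.
struct Request {
    prompt: String,
    is_streaming: bool,
    response: Sender<Response>,
}

/// Fake "model" loop: echoes the prompt word by word, or all at once.
fn spawn_engine() -> Sender<Request> {
    let (tx, rx) = channel::<Request>();
    thread::spawn(move || {
        for req in rx {
            if req.is_streaming {
                for word in req.prompt.split_whitespace() {
                    let _ = req.response.send(Response::Chunk(word.to_string()));
                }
            } else {
                let _ = req.response.send(Response::Done(req.prompt.clone()));
            }
            // `req` is dropped here, which closes its response channel.
        }
    });
    tx
}

/// Handler side: create the response channel, send the request, collect replies.
fn handle(engine: &Sender<Request>, prompt: &str, is_streaming: bool) -> Vec<Response> {
    let (tx, rx): (Sender<Response>, Receiver<Response>) = channel();
    engine
        .send(Request { prompt: prompt.to_string(), is_streaming, response: tx })
        .expect("engine thread stopped");
    // Iteration ends once the engine drops the sender for this request.
    rx.iter().collect()
}

fn main() {
    let engine = spawn_engine();
    println!("{:?}", handle(&engine, "hello from axum", true));
    println!("{:?}", handle(&engine, "hello from axum", false));
}
```

In this sketch the end of a stream is signalled by the channel closing, whereas the PR's `Streamer` tracks a `DoneState` and emits a final `[DONE]` event; the per-request channel is the part the two have in common.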
Possibly related PRs
- EricLBuehler/mistral.rs#1353: Adds and refines OpenAPI schema derivations for chat completion API types used in the server core, complementing this PR's implementation of server core functionality and API handlers.
Suggested reviewers
- EricLBuehler
Code Metrics Report
| Language | Files | Lines | Code | Comments | Blanks |
|---|---|---|---|---|---|
| C Header | 3 | 62 | 53 | 0 | 9 |
| CSS | 1 | 428 | 366 | 12 | 50 |
| Dockerfile | 1 | 39 | 22 | 9 | 8 |
| HTML | 1 | 58 | 46 | 4 | 8 |
| JavaScript | 7 | 1221 | 915 | 169 | 137 |
| JSON | 14 | 123 | 122 | 0 | 1 |
| Makefile | 1 | 6 | 5 | 0 | 1 |
| Python | 86 | 4045 | 3410 | 161 | 474 |
| Shell | 1 | 63 | 26 | 18 | 19 |
| Plain Text | 3 | 3723 | 0 | 2413 | 1310 |
| TOML | 20 | 623 | 565 | 10 | 48 |
| YAML | 2 | 21 | 19 | 2 | 0 |
| Jupyter Notebooks | 3 | 0 | 0 | 0 | 0 |
| \|- Markdown | 2 | 77 | 32 | 31 | 14 |
| \|- Python | 2 | 205 | 178 | 1 | 26 |
| (Total) | | 282 | 210 | 32 | 40 |
| Markdown | 58 | 5029 | 0 | 3832 | 1197 |
| \|- BASH | 10 | 107 | 101 | 2 | 4 |
| \|- JSON | 2 | 42 | 42 | 0 | 0 |
| \|- Python | 7 | 121 | 109 | 0 | 12 |
| \|- Rust | 22 | 757 | 634 | 1 | 122 |
| \|- TOML | 2 | 75 | 63 | 0 | 12 |
| (Total) | | 6131 | 949 | 3835 | 1347 |
| Rust | 367 | 130240 | 115946 | 2835 | 11459 |
| \|- Markdown | 162 | 2233 | 29 | 1995 | 209 |
| (Total) | | 132473 | 115975 | 4830 | 11668 |
| Total | 568 | 145681 | 121495 | 9465 | 14721 |
@matthewhaynesonline this is interesting!
What do you think of the general concept?
I think it's a great concept!
Assuming you're okay with the general concept, do you have concerns with the implementation?
It looks like (right now) the code is a duplicate of mistralrs-server, including the clap parts. Ideally, if you could abstract those away a bit and make mistralrs-server-core its own crate, that would be great. A builder of some kind comes to mind here.
Would it make sense to do the same treatment on the other endpoints / features?
Not sure what you mean - can you please elaborate?
@EricLBuehler Sounds good! Let me take another pass at this and yup builder does make a lot more sense.
Re. the other endpoints: the first draft of this PR only makes the chat completions endpoint / functionality reusable and extensible from another axum project via the server core (since that was my immediate use case), not the others such as regular completions, images, etc. My gut reaction is to be consistent: if chat completions can be reused / extended, then the others (completions, images, etc.) should be too. But there's also the question of what API surface you'd want to expose, the added complexity, etc.
Here's an example of using the chatcompletions from the server lib in a different axum project (just for illustrative purposes): https://github.com/matthewhaynesonline/mistral-server-core-test/blob/b65da9a7b91d94aae632e2ff0c671931dd8f122a/src/controllers/mod.rs#L23-L66
Maybe I can take a pass implementing it fully / merge ready so you can see it in totality? E.g. cleaned up, builder pattern, server implementing server-core lib, benchmarks, etc.
P.S. the web chat looks really cool!
Maybe I can take a pass implementing it fully / merge ready so you can see it in totality? E.g. cleaned up, builder pattern, server implementing server-core lib, benchmarks, etc.
@matthewhaynesonline sounds great! Following along with the progress - 66ea5fc looks very nice! Let me know when this is ready for review
P.S. the web chat looks really cool!
Thanks! Also added vision support and file uploading, as well as TTS :)
@EricLBuehler thanks so much! Almost ready; I'm hoping to have it ready for review later today. Here's what I think is still left:
- Clean up the code (and docs) and make sure I didn't miss anything / leave any messes
- Benchmark against the original mistralrs server to make sure it maintains perf
- Nice to have? Make the other controller actions (completion, image, audio) more modular like chatcompletions is
  - Assuming there aren't perf drawbacks (and I'd think / hope a release build would optimize away the overhead of the added function calls), I think it would be nice to have a consistent approach where all the functionality can be extended / reused, as opposed to chatcompletions being the only special case
  - That said, it might add some time and this PR is getting somewhat large already
  - Do you have any preference between one PR for everything, or focusing on the PR as is / merging in what's ready now and opening follow-ups for further enhancements?
- Clean up the PR itself
  - Notes
  - Clean up crates (right now the original and new mistralrs-server crates exist in parallel, just to make syncing the fork and benchmarking easier, but naturally that would consolidate down to just one)
  - Check for dependency dupes
Still WIP, but I threw together a quick benchmark notebook to compare the existing and new implementations:
https://gist.github.com/matthewhaynesonline/e66415ed253e4d731d09f45195ee9ee1
Was going to take a look at adapting the official benchmark, but wanted a quick check as well before going any further.
The difference seems to be within the noise (at least on my laptop, it's within the margin of error between runs), but I need to keep testing to be sure.
@EricLBuehler I think this is ready for a first review, so I'll mark it as ready in case you want to take a look; that way CodeRabbit can review it too.