
external middleware (plugins/libraries)

Open tronikelis opened this issue 1 year ago • 26 comments

Hi, I wanted to try writing a couple of middlewares for Thruster, but I'm hitting a roadblock: some middleware needs its own custom config.

For example, the rate limit middleware:

use async_trait::async_trait;
use redis::{aio::Connection, AsyncCommands, RedisError, RedisResult};
use thruster::{middleware_fn, Context, MiddlewareNext, MiddlewareResult};

pub struct RateLimiter<T: Store> {
    pub max: usize,
    pub per_ms: usize,
    pub store: T,
}

#[async_trait]
pub trait Store {
    type Error;
    async fn get(&mut self, key: &str) -> Result<Option<usize>, Self::Error>;
    async fn set(&mut self, key: &str, value: usize, expiry_ms: usize) -> Result<(), Self::Error>;
}

pub struct RedisStore {
    url: String,
    connection: Connection,
}

impl RedisStore {
    pub async fn new(url: String) -> RedisResult<Self> {
        let client = redis::Client::open(url.as_str())?;
        let connection = client.get_async_connection().await?;
        return Ok(Self { connection, url });
    }
}

#[async_trait]
impl Store for RedisStore {
    type Error = RedisError;

    async fn get(&mut self, key: &str) -> Result<Option<usize>, Self::Error> {
        let current: Option<usize> = self.connection.get(key).await?;
        return Ok(current);
    }

    async fn set(&mut self, key: &str, value: usize, expiry_ms: usize) -> Result<(), Self::Error> {
        // `set_ex` takes its expiry in seconds; `pset_ex` takes milliseconds, matching `expiry_ms`.
        let _: () = self.connection.pset_ex(key, value, expiry_ms).await?;
        return Ok(());
    }
}

#[middleware_fn]
pub async fn rate_limiter_middleware<T: 'static + Context + Store + Send>(
    mut context: T,
    next: MiddlewareNext<T>,
) -> MiddlewareResult<T> {
    // how do I get the RateLimiter struct data here and write logic?
    // I would need middleware with pre-defined state for it, not global server state 🤔

    todo!("?");
}
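
The check that the `todo!` above stands in for could be a fixed-window counter. The following is a minimal std-only sketch of that logic, independent of Thruster and Redis: the store is folded into an in-memory map, everything is synchronous, and `FixedWindowLimiter`, `allow`, and `windows` are hypothetical names chosen for illustration. The current time is passed in explicitly so the behavior is easy to test.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Std-only fixed-window limiter (hypothetical; not part of Thruster).
pub struct FixedWindowLimiter {
    pub max: usize,
    pub per_ms: u64,
    // key -> (hits in the current window, instant the window ends)
    windows: HashMap<String, (usize, Instant)>,
}

impl FixedWindowLimiter {
    pub fn new(max: usize, per_ms: u64) -> Self {
        Self { max, per_ms, windows: HashMap::new() }
    }

    /// Records one hit for `key` at time `now`; returns false once the
    /// current window already holds `max` hits.
    pub fn allow(&mut self, key: &str, now: Instant) -> bool {
        let window_len = Duration::from_millis(self.per_ms);
        let window = self
            .windows
            .entry(key.to_string())
            .or_insert((0, now + window_len));

        // The previous window has ended: start a fresh one.
        if now >= window.1 {
            *window = (0, now + window_len);
        }

        if window.0 >= self.max {
            return false;
        }
        window.0 += 1;
        true
    }
}
```

A middleware would call something like `allow` with a per-client key (an IP address, for instance) and short-circuit the request when it returns false. A Redis-backed version would need the count and the expiry to be updated atomically, which this sketch does not address.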

How would I get the RateLimiter struct into fn rate_limiter_middleware?

Ideally there would be something like this:

#[tokio::main]
async fn main() {
    let rate_limiter = RateLimiter {
        max: 100,
        per_ms: 1000,
        store: RedisStore::new("redis://127.0.0.1".to_string())
            .await
            .unwrap(),
    };

    let app = App::<HyperRequest, ReqContext, ServerState>::create(
        init_context,
        ServerState::new().await,
    )
    // I am thinking of something similar to this
    .middleware_with_state("/", rate_limiter, rate_limiter_middleware)
    .get("/", m![root]);

    let server = HyperServer::new(app);
    server.build("127.0.0.1", 3000).await;
}

I am learning Rust, so I can't say whether this would be a simple and scalable approach, but let me know if there is a way to pass config to middlewares 👍
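
One general Rust pattern for giving a middleware its own config, separate from global server state, is a factory function that returns a closure capturing the config behind an `Arc`. The sketch below illustrates only that pattern with toy types: `LimiterConfig`, `Request`, `Outcome`, and `make_limiter` are all hypothetical, and whether Thruster's `#[middleware_fn]` machinery can accept a closure built this way is an open question here, not something the sketch demonstrates.

```rust
use std::sync::Arc;

/// Hypothetical per-middleware configuration (not a Thruster type).
pub struct LimiterConfig {
    pub max: usize,
    pub per_ms: usize,
}

/// Toy stand-in for a request context; carries a precomputed hit count.
pub struct Request {
    pub hits_so_far: usize,
}

/// Toy stand-in for "call next" versus "reject the request".
#[derive(Debug, PartialEq)]
pub enum Outcome {
    Continue,
    TooManyRequests,
}

/// Builds a middleware-like closure that owns a handle to its own config.
pub fn make_limiter(config: Arc<LimiterConfig>) -> impl Fn(&Request) -> Outcome {
    move |request: &Request| {
        // `config` was moved into the closure, so no global state is needed.
        if request.hits_so_far >= config.max {
            Outcome::TooManyRequests
        } else {
            Outcome::Continue
        }
    }
}
```

Each call to `make_limiter` yields an independent closure, so two routes could be given different limits by constructing two configs.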

tronikelis • May 16 '23 05:05