Add `PoolOptions::before_connect` to execute a function before creating a new connection in the pool
Is your feature request related to a problem? Please describe.
When connecting to RDS using IAM, I need SQLx to call a function that provides a fresh connection string (or PoolOptions) before each new connection is created.
I have looked at some workarounds, but none of them work in my case. For example, using set_connect_options() does not solve my problem, since the password in the PgConnectOptions is invalidated 15 minutes after my app starts:
#[tokio::main]
async fn main() {
    let pool = get_pg_pool().await;
    // This sets the connect options only once, at startup; the IAM password
    // inside them expires 15 minutes later and is never refreshed.
    pool.set_connect_options(get_pg_options().await);
    let state = SharedState {
        pool,
        ...
    };
    let app = Router::new()...;
    ...
    axum::serve(listener, app).await.unwrap();
}
Describe the solution you'd like
Be able to execute a function that creates a new URL string (or PoolOptions object) right before a new connection is created in the pool. Example of how I could use it:
#[tokio::main]
async fn main() {
    let pool = get_pg_pool()
        .await
        .before_connect(function_to_execute_right_before_creating_a_new_connection);
    ....
Describe alternatives you've considered
If I create all the connections at startup and any of them dies, it won't be possible to open a replacement, since the IAM auth password is only valid for 15 minutes. So the following does not solve my problem either:
PgPoolOptions::new()
    .min_connections(max_connections)
    .max_connections(max_connections)
    ...
> If any connection is reaped by max_lifetime or idle_timeout or explicitly closed, and it brings the connection count below this amount, a new connection will be opened to replace it.

https://docs.rs/sqlx/latest/sqlx/pool/struct.PoolOptions.html#method.min_connections
Another alternative, coming from a conversation on Discord, would be to create my own custom Database implementation to override the connect method. This sounds like a fair amount of work and pretty risky without much context about how SQLx is built:

> Hmm, actually, I guess you could do it by having a custom Database implementation which wraps the underlying postgres one. Then that Database implementation would have a ConnectOptions::connect implementation which fetches a fresh password.
Additional context
This is because we use AWS RDS IAM, for which the password is valid for only 15 minutes. So if any connection dies after 15 minutes, creating a new connection with the initial PgPoolOptions or connection string will simply fail.
How did y'all end up solving this? I see that #3562 and #3851 both attempted to solve this and were set aside in favour of some upcoming work that will make the whole thing more flexible - but what is everyone doing in the meantime?
👋 High level:
- A function to generate the RDS IAM password
- Get it and create the connection pool when you boot up your service
- Have a background process that renews the password before the 15-minute token expires (every 14 minutes in the code below)

Note: this has been running in production for 14 months without any issue.
Note 2: the bottom line is to use set_connect_options and pass in a new PgConnectOptions object.
Generate the RDS IAM password
// https://github.com/awslabs/aws-sdk-rust/issues/951#issuecomment-1838117702
// https://github.com/awslabs/aws-sdk-rust/issues/951#issuecomment-1961010056
use std::time::{Duration, SystemTime};

use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
use aws_sigv4::sign::v4;

async fn generate_rds_iam_token(host: &str, port: u16, user: &str) -> Result<String, KarmaError> {
    let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;
    let credentials = config
        .credentials_provider()
        .expect("no credentials provider found")
        .provide_credentials()
        .await
        .expect("unable to load credentials");
    let identity = credentials.into();
    let region = config.region().unwrap().to_string();

    // Sign a GET request against the rds-db service with the signature placed
    // in the query string; the resulting token is valid for 900 seconds.
    let mut signing_settings = SigningSettings::default();
    signing_settings.expires_in = Some(Duration::from_secs(900));
    signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;
    let signing_params = v4::SigningParams::builder()
        .identity(&identity)
        .region(&region)
        .name("rds-db")
        .time(SystemTime::now())
        .settings(signing_settings)
        .build()?;

    let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");
    let signable_request =
        SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))
            .expect("signable request");
    let (signing_instructions, _signature) =
        sign(signable_request, &signing_params.into())?.into_parts();

    let mut url = url::Url::parse(&url).unwrap();
    for (name, value) in signing_instructions.params() {
        url.query_pairs_mut().append_pair(name, value);
    }

    // RDS expects the token without the leading "https://" scheme.
    let response = url.to_string().split_off("https://".len());
    Ok(response)
}
PG connect options
// Assumes `use std::env;`, `sqlx::postgres::PgConnectOptions`, and a percent-encoding
// helper `encode` (e.g. from the urlencoding crate) are in scope; `EndpointType` and
// `Environment` are application-specific types.
async fn get_pg_options(endpoint_type: EndpointType) -> PgConnectOptions {
    let user = env::var("POSTGRES_USER").expect("Environment variable POSTGRES_USER not found");
    let port = env::var("POSTGRES_PORT")
        .expect("Environment variable POSTGRES_PORT not found")
        .parse::<u16>()
        .expect("Environment variable POSTGRES_PORT is not a valid port number");
    let database = env::var("POSTGRES_DB").expect("Environment variable POSTGRES_DB not found");
    let host = match endpoint_type {
        EndpointType::Write => env::var("POSTGRES_HOST_WRITE")
            .expect("Environment variable POSTGRES_HOST_WRITE not found"),
        EndpointType::ReadOnly => env::var("POSTGRES_HOST_READ_ONLY")
            .expect("Environment variable POSTGRES_HOST_READ_ONLY not found"),
    };
    let password = match Environment::new() {
        Environment::Local => env::var("POSTGRES_USER").unwrap_or("postgres".to_string()),
        _ => generate_rds_iam_token(&host, port, &user).await.unwrap(),
    };
    // The token contains URL metacharacters (?, &, =), so it has to be
    // percent-encoded before being embedded as the password in the URL.
    let encoded_password = encode(&password).to_string();
    let url = format!("postgres://{user}:{encoded_password}@{host}:{port}/{database}");
    url.parse().unwrap()
}
Update the pool in the background
// `SharedState` is the application's shared state (here behind an RwLock) holding
// wrappers around the read-only and write pools; `time` is `tokio::time`.
pub fn update_pool_password_in_background(state: SharedState) {
    tokio::spawn(async move {
        // Fetch a new pool password every 14 minutes (the token expires after 15)
        let mut interval = time::interval(time::Duration::from_secs(60 * 14));
        loop {
            interval.tick().await;
            let read_only_opts = ReadOnlyPgPool::get_pg_options().await;
            let write_opts = WritePgPool::get_pg_options().await;
            // Swap in the new connect options; existing connections are untouched,
            // but any connection opened from now on uses the fresh token.
            let state = state.write().unwrap();
            state
                .read_only_pool
                .clone()
                .inner()
                .set_connect_options(read_only_opts);
            state
                .write_pool
                .clone()
                .inner()
                .set_connect_options(write_opts);
        }
    });
}
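If you only have one pool, the same strategy can be condensed into a single hypothetical helper (sketch only; `boot_pg_pool` is the placeholder from the earlier sketch): `PgPool` is cheap to clone and `Pool::set_connect_options` takes `&self`, so no lock around the pool is strictly needed.

use std::time::Duration;
use sqlx::postgres::PgPool;

// Condensed single-pool variant of the approach above (assumes `EndpointType`,
// `get_pg_options`, and the `boot_pg_pool` sketch from earlier).
async fn boot_write_pool_with_refresh() -> PgPool {
    let pool = boot_pg_pool(EndpointType::Write).await;

    let refresh_pool = pool.clone();
    tokio::spawn(async move {
        // Regenerate the IAM token every 14 minutes, one minute before it expires.
        let mut interval = tokio::time::interval(Duration::from_secs(60 * 14));
        loop {
            interval.tick().await;
            refresh_pool.set_connect_options(get_pg_options(EndpointType::Write).await);
        }
    });

    pool
}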
Thanks @MattDelac! This is also what I ended up with as an overall strategy, so super happy to hear it's working in production!