redis-async-rs
redis-async-rs copied to clipboard
Improve `PubsubConnection` handling for multiple subscribers to the same topic (or document limitation)
Currently, when a user calls `pubsub_connection.subscribe("hello")` twice in a row, the pubsub connection ends the first stream and replaces it with the second one.
This was unexpected behavior for me because I expected each pubsub message to be delivered to every subscriber when the same topic is subscribed to multiple times. This broadcast behavior would be in line with the behavior of the Redis server itself (it sends each message to every client that has called `SUBSCRIBE hello`).
My use case is a websocket server that notifies on pubsub events. Currently, with redis-async-rs, the only way to handle this is to create a client per connection. That is a lot of overhead, and it generally feels like fighting the tools. I believe this should be improved in the long run, or at least documented in the short run.
After looking through the source here to see whether I could make a PR, I can see that there are pretty deep assumptions that each topic has only a single consumer. Examples include `subscribers` being a map from topic to `mpsc::Sender`, and `PubsubStream` calling unsubscribe automatically on drop.
I wrote my own client, similar to `PubsubConnection`, that has the broadcast semantics. It won't be as easy as copying it into redis-async-rs, though, because its API surface differs in a way that isn't obvious to remedy. It could be added to redis-async-rs with a major version bump, or it could be used as a reference for future work, so I am sharing it here. Note that this client handles reconnection and resubscription entirely on its own. It only notifies subscribers when reconnection is happening and when they miss messages. By contrast, the current client leaves re-subscribing to the subscribers.
use std::{future::Future, sync::Arc, time::Duration};
use ahash::{AHashMap, AHashSet};
use anyhow::Context;
use futures::{SinkExt, Stream, StreamExt};
use redis_async::{
client::connect::{connect_with_auth, RespConnection},
resp::RespValue,
resp_array,
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
use tokio_util::sync::CancellationToken;
/// Sending half of a per-topic broadcast channel; held by the driver while the topic is live.
pub type RespSender = broadcast::Sender<PubsubEvent>;
/// Receiving half of a per-topic broadcast channel; every `subscribe` call gets its own.
pub type RespReceiver = broadcast::Receiver<PubsubEvent>;
/// One-shot reply slot through which the driver answers a subscription request.
pub type SubscribeResponder = oneshot::Sender<Result<RespReceiver, anyhow::Error>>;
/// An event delivered on a subscription stream returned by `PubsubClient::subscribe`.
// `Debug` so callers can log events; `Clone` is required by `tokio::sync::broadcast`.
#[derive(Clone, Debug)]
pub enum PubsubEvent {
    /// A payload published to the subscribed channel.
    Message(RespValue),
    /// This subscriber did not receive one or more events. Either it lagged behind the
    /// per-channel buffer, or messages may have been lost while the driver reconnected.
    MessageMissed,
    /// The driver lost its connection, reconnected, and re-sent `SUBSCRIBE`.
    Reconnecting,
}
/// Cloneable handle for subscribing to Redis Pub/Sub channels.
///
/// Every clone forwards its requests to the single driver future that
/// `PubsubClient::new` returns alongside the first handle.
#[derive(Clone)]
pub struct PubsubClient {
    /// Subscription requests, consumed by the driver future.
    work_sender: mpsc::Sender<PubsubSubscribe>,
    // Never read; it exists only to be dropped together with the last clone. The driver
    // awaits `closed()` on the paired `oneshot::Sender`, which resolves at that point.
    #[allow(dead_code)]
    phanton_rx: Arc<oneshot::Receiver<()>>,
}
/// A request sent from a `PubsubClient` handle to the driver to join a channel.
pub struct PubsubSubscribe {
    /// Where the driver sends the new receiver (or the error that prevented subscribing).
    notify: SubscribeResponder,
    /// Name of the Redis channel to subscribe to.
    topic: String,
}
/// A decoded frame received from Redis on a connection in Pub/Sub mode.
enum PubSubMessage {
    /// Redis confirmed a `SUBSCRIBE` for this channel.
    Subscribe(String),
    /// Redis confirmed an `UNSUBSCRIBE` for this channel.
    Unsubscribe(String),
    /// A payload published to a channel: `(channel, payload)`.
    Message(String, RespValue),
}
impl TryFrom<RespValue> for PubSubMessage {
    type Error = anyhow::Error;

    /// Decodes one frame from Redis.
    ///
    /// The frame must be an array of at least three elements, whose first two elements
    /// are bulk strings naming the kind and the channel. The third element is kept as the
    /// payload of `message` frames. Extra elements are ignored.
    ///
    /// # Errors
    ///
    /// The frame is rejected in any of these cases:
    /// - it is not an array, or has fewer than three elements;
    /// - the kind or channel is not a bulk string, or is not valid UTF-8;
    /// - the kind is not `subscribe`, `unsubscribe` or `message`.
    fn try_from(value: RespValue) -> Result<Self, anyhow::Error> {
        let RespValue::Array(mut parts) = value else {
            return Err(anyhow::anyhow!(
                "Invalid PubSub message, Non-array response"
            ));
        };
        // Keep only the first three elements; shorter arrays fail the conversion below.
        parts.truncate(3);
        let [kind, topic, payload]: [RespValue; 3] = parts
            .try_into()
            .map_err(|_| anyhow::anyhow!("Invalid PubSub message, too few elements"))?;
        let (RespValue::BulkString(kind), RespValue::BulkString(topic)) = (kind, topic) else {
            return Err(anyhow::anyhow!(
                "Invalid PubSub message, Non-bulkstring response"
            ));
        };
        // The kind is decoded before the channel so UTF-8 errors are reported in that order.
        let kind = std::str::from_utf8(&kind)?;
        let topic = std::str::from_utf8(&topic)?;
        match kind {
            "subscribe" => Ok(PubSubMessage::Subscribe(topic.to_string())),
            "unsubscribe" => Ok(PubSubMessage::Unsubscribe(topic.to_string())),
            "message" => Ok(PubSubMessage::Message(topic.to_string(), payload)),
            other => Err(anyhow::anyhow!(
                "Invalid PubSub message, unknown kind: {}",
                other
            )),
        }
    }
}
/// The driver's view of one Redis channel.
enum TopicState {
    /// `SUBSCRIBE` was sent but not yet confirmed. The vector holds the requests to
    /// answer once Redis confirms it.
    WaitingSubscription(Vec<SubscribeResponder>),
    /// Confirmed by Redis. Messages fan out to every receiver of this broadcast sender.
    Subscribed(RespSender),
    /// `UNSUBSCRIBE` was sent because the topic had no receivers. The vector holds
    /// requests that arrived in the meantime; if it is non-empty, the topic is
    /// subscribed again after the unsubscribe is confirmed.
    Unsubscribing(Vec<SubscribeResponder>),
}
impl PubsubClient {
pub async fn subscribe(
&self,
channel: &str,
) -> anyhow::Result<impl Stream<Item = PubsubEvent>> {
let (notify_tx, notify_rx) = oneshot::channel();
let work = PubsubSubscribe {
topic: channel.to_string(),
notify: notify_tx,
};
if let Err(e) = self.work_sender.send(work).await {
return Err(anyhow::Error::new(e));
}
match notify_rx.await {
Ok(Ok(receiver)) => {
let stream = BroadcastStream::new(receiver).filter_map(|x| {
futures_util::future::ready(match x {
Ok(x) => Some(x),
Err(BroadcastStreamRecvError::Lagged(lagged)) => {
tracing::error!("Lagged in subscription: {}", lagged);
Some(PubsubEvent::MessageMissed)
}
})
});
Ok(stream)
}
Ok(Err(e)) => Err(e),
Err(e) => Err(anyhow::anyhow!(e)),
}
}
pub async fn new(
host: String,
port: u16,
cancellation_token: CancellationToken,
) -> anyhow::Result<(Self, impl Future<Output = ()>)> {
let (mut msg_tx, msg_rx) = oneshot::channel::<()>();
let (work_tx, mut work_rx) = mpsc::channel::<PubsubSubscribe>(1);
let driver_future = async move {
let loop_ct = cancellation_token.clone();
let mut topic_states: AHashMap<String, TopicState> = AHashMap::new();
while !cancellation_token.is_cancelled() {
let mut resp_client = match pubsub_connect(host.clone(), port).await {
Ok(resp_client) => resp_client,
Err(e) => {
tracing::error!("Error in Redis PubSub: {:?}", e);
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
let should_subscribe: AHashSet<String> = topic_states
.iter()
.filter_map(|(topic, state)| match state {
TopicState::Subscribed(_) => Some(topic.to_string()),
TopicState::Unsubscribing(s) => {
if s.is_empty() {
None
} else {
Some(topic.to_string())
}
}
TopicState::WaitingSubscription(_) => Some(topic.to_string()),
})
.collect();
// Subscribe to all
if !should_subscribe.is_empty() {
tracing::info!("Resubscribing to topics: {:?}", should_subscribe);
let mut command: Vec<RespValue> = vec!["SUBSCRIBE".into()];
for topic in should_subscribe {
command.push(topic.into());
}
if let Err(e) = resp_client.send(RespValue::Array(command)).await {
tracing::error!("Error in Redis PubSub: {:?}", e);
continue;
}
}
// Now if this succeeds, then we can move all unsubscribing with subscribers to subscribed and remove others
let total_keys = topic_states.keys().cloned().collect::<Vec<String>>();
for key in total_keys {
if let Some(state) = topic_states.remove(&key) {
match state {
TopicState::Subscribed(subscriber) => {
subscriber.send(PubsubEvent::Reconnecting).ok();
topic_states.insert(key, TopicState::Subscribed(subscriber));
}
TopicState::Unsubscribing(subscribers) => {
if subscribers.is_empty() {
continue;
}
topic_states
.insert(key, TopicState::WaitingSubscription(subscribers));
}
TopicState::WaitingSubscription(subscribers) => {
if subscribers.is_empty() {
continue;
}
topic_states
.insert(key, TopicState::WaitingSubscription(subscribers));
}
}
}
}
async {
loop {
tokio::select! {
_ = msg_tx.closed() => {
break;
}
_ = loop_ct.cancelled() => {
break;
}
// Receive packets from the server
pkt = resp_client.next() => {
// Forward it to msg_tx. If msg_tx is full, drop it.
match pkt {
Some(Ok(pkt)) => {
let pubsub_message = match PubSubMessage::try_from(pkt) {
Ok(m) => m,
Err(e) => {
tracing::error!("Error in Redis PubSub: {:?}", e);
break;
}
};
match pubsub_message {
PubSubMessage::Subscribe(topic) => {
let Some(current_state) = topic_states.remove(&topic) else {
tracing::error!("Received subscribe message for unknown topic: {}", topic);
continue;
};
match current_state {
TopicState::Subscribed(s) => {
tracing::info!("Received subscribe message for already subscribed topic: {}. Notifying", topic);
s.send(PubsubEvent::MessageMissed).ok();
topic_states.insert(topic, TopicState::Subscribed(s));
continue;
}
TopicState::Unsubscribing(subscribers) => {
tracing::error!("Received subscribe message for unsubscribing topic: {}", topic);
let (tx, _) = broadcast::channel(1);
for subscriber in subscribers {
let _ = subscriber.send(
Ok(tx.subscribe()),
);
}
topic_states.insert(topic, TopicState::Subscribed(tx));
}
TopicState::WaitingSubscription(subscribers) => {
let (tx, _) = broadcast::channel(1);
for subscriber in subscribers {
let _ = subscriber.send(
Ok(tx.subscribe()),
);
}
topic_states.insert(topic, TopicState::Subscribed(tx));
}
}
}
PubSubMessage::Unsubscribe(topic) => {
let Some(current_state) = topic_states.remove(&topic) else {
tracing::error!("Received unsubscribe message for unknown topic: {}", topic);
continue;
};
match current_state {
TopicState::Subscribed(tx) => {
tracing::error!("Received unsubscribe message for subscribed topic: {}. Restarting...", topic);
topic_states.insert(topic, TopicState::Subscribed(tx));
break;
}
TopicState::Unsubscribing(subscribers) => {
if subscribers.is_empty() {
continue;
}
topic_states.insert(topic.clone(), TopicState::WaitingSubscription(subscribers));
// Try to send a new subscribe message
let command = resp_array!["SUBSCRIBE", topic.clone()];
tracing::info!("Resubscribing to topic bc somebody wants to: {}", topic);
if let Err(e) = resp_client.send(command).await {
tracing::error!("Error in Redis PubSub: {:?}", e);
break;
}
}
TopicState::WaitingSubscription(s) => {
tracing::error!("Received unsubscribe message for waiting topic: {}. Restarting...", topic);
topic_states.insert(topic.clone(), TopicState::WaitingSubscription(s));
break;
}
}
}
PubSubMessage::Message(topic, message) => {
// Forward to subscribers
match topic_states.remove(&topic) {
Some(TopicState::Subscribed(subscriber)) => {
// Check if the subscriber is still alive
if subscriber.receiver_count() == 0 {
tracing::warn!("No subscribers, so we should unsubscribe");
// No subscribers, so we should unsubscribe
let command = resp_array!["UNSUBSCRIBE", topic.clone()];
if let Err(e) = resp_client.send(command).await {
tracing::error!("Error in Redis PubSub: {:?}", e);
break;
}
topic_states.insert(topic.clone(), TopicState::Unsubscribing(vec![]));
continue;
} else {
let _ = subscriber.send(PubsubEvent::Message(message));
topic_states.insert(topic.clone(), TopicState::Subscribed(subscriber));
}
}
Some(TopicState::Unsubscribing(s)) => {
topic_states.insert(topic.clone(), TopicState::Unsubscribing(s));
tracing::warn!("Received message for unsubscribing topic: {}", topic);
}
Some(TopicState::WaitingSubscription(s)) => {
topic_states.insert(topic.clone(), TopicState::WaitingSubscription(s));
tracing::warn!("Received message for waiting topic: {}", topic);
}
None => {
tracing::error!("Received message for unknown topic: {}", topic);
// Unsubscribe here
let command = resp_array!["UNSUBSCRIBE", topic.clone()];
if let Err(e) = resp_client.send(command).await {
tracing::error!("Error in Redis PubSub: {:?}", e);
break;
}
}
}
}
}
}
Some(Err(e)) => {
tracing::error!("Error in Redis PubSub: {:?}", e);
break;
}
None => {
tracing::info!("Redis PubSub closed, restarting...");
break;
}
}
}
// Receive packets from my packet sender
work_tx = work_rx.recv() => {
match work_tx {
Some(PubsubSubscribe { topic, notify }) => {
let state = topic_states.remove(&topic);
match state {
Some(TopicState::Subscribed(subscriber)) => {
let _ = notify.send(Ok(subscriber.subscribe()));
topic_states.insert(topic, TopicState::Subscribed(subscriber));
}
Some(TopicState::Unsubscribing(mut subscribers)) => {
subscribers.push(notify);
topic_states.insert(topic, TopicState::Unsubscribing(subscribers));
}
Some(TopicState::WaitingSubscription(mut subscribers)) => {
subscribers.push(notify);
topic_states.insert(topic, TopicState::WaitingSubscription(subscribers));
}
None => {
let command = resp_array!["SUBSCRIBE", topic.clone()];
if let Err(e) = resp_client.send(command).await {
let _ = notify.send(Err(anyhow::Error::new(e)));
continue;
}
topic_states.insert(topic, TopicState::WaitingSubscription(vec![notify]));
}
}
}
None => {
tracing::info!("work_rx closed, shutting down");
break;
}
}
}
}
}
}
.await;
}
};
Ok((
PubsubClient {
phanton_rx: Arc::new(msg_rx),
work_sender: work_tx,
},
driver_future,
))
}
}
/// Opens a new Redis connection to `host:port` for use in Pub/Sub mode.
///
/// The connection is opened with a 60-second keepalive and a 30-second timeout.
///
/// # Errors
///
/// Any connection failure is returned with the context "Failed to connect to Redis PubSub".
pub async fn pubsub_connect(host: String, port: u16) -> anyhow::Result<RespConnection> {
    const KEEPALIVE: Duration = Duration::from_secs(60);
    const TIMEOUT: Duration = Duration::from_secs(30);

    // NOTE(review): the positional `None, None, false` are presumably username, password
    // and the TLS flag; confirm against the redis-async `connect_with_auth` signature.
    let connection = connect_with_auth(
        &host,
        port,
        None,
        None,
        false,
        Some(KEEPALIVE),
        Some(TIMEOUT),
    )
    .await;
    connection.context("Failed to connect to Redis PubSub")
}