diff options
Diffstat (limited to 'src/broadcast.rs')
| -rw-r--r-- | src/broadcast.rs | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/src/broadcast.rs b/src/broadcast.rs new file mode 100644 index 0000000..083a301 --- /dev/null +++ b/src/broadcast.rs @@ -0,0 +1,78 @@ +use std::sync::{Arc, Mutex}; + +use futures::{future, stream::StreamExt as _, Stream}; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; + +// Clones will share the same sender. +#[derive(Clone)] +pub struct Broadcaster<M> { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that + // lock it must be sync. + senders: Arc<Mutex<Sender<M>>>, +} + +impl<M> Default for Broadcaster<M> +where + M: Clone + Send + std::fmt::Debug + 'static, +{ + fn default() -> Self { + let sender = Self::make_sender(); + + Self { + senders: Arc::new(Mutex::new(sender)), + } + } +} + +impl<M> Broadcaster<M> +where + M: Clone + Send + std::fmt::Debug + 'static, +{ + // panic: if ``message.channel.id`` has not been previously registered, + // and was not part of the initial set of channels. + pub fn broadcast(&self, message: &M) { + let tx = self.sender(); + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + // + // The successful return value, which includes the number of active + // receivers, also isn't that interesting to us. + let _ = tx.send(message.clone()); + } + + // panic: if ``channel`` has not been previously registered, and was not + // part of the initial set of channels. + pub fn subscribe(&self) -> impl Stream<Item = M> + std::fmt::Debug { + let rx = self.sender().subscribe(); + + BroadcastStream::from(rx).scan((), |(), r| { + future::ready(match r { + Ok(event) => Some(event), + // Stop the stream here. This will disconnect SSE clients + // (see `routes.rs`), who will then resume from + // `Last-Event-ID`, allowing them to catch up by reading + // the skipped messages from the database. + // + // See also: + // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854> + Err(BroadcastStreamRecvError::Lagged(_)) => None, + }) + }) + } + + fn sender(&self) -> Sender<M> { + self.senders.lock().unwrap().clone() + } + + fn make_sender() -> Sender<M> { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} |
