summaryrefslogtreecommitdiff
path: root/src/broadcast.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/broadcast.rs')
-rw-r--r--src/broadcast.rs78
1 files changed, 78 insertions, 0 deletions
diff --git a/src/broadcast.rs b/src/broadcast.rs
new file mode 100644
index 0000000..083a301
--- /dev/null
+++ b/src/broadcast.rs
@@ -0,0 +1,78 @@
+use std::sync::{Arc, Mutex};
+
+use futures::{future, stream::StreamExt as _, Stream};
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
+
+// Clones will share the same sender.
+#[derive(Clone)]
+pub struct Broadcaster<M> {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<Sender<M>>>,
+}
+
+impl<M> Default for Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ fn default() -> Self {
+ let sender = Self::make_sender();
+
+ Self {
+ senders: Arc::new(Mutex::new(sender)),
+ }
+ }
+}
+
+impl<M> Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ // panic: if ``message.channel.id`` has not been previously registered,
+ // and was not part of the initial set of channels.
+ pub fn broadcast(&self, message: &M) {
+ let tx = self.sender();
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ //
+ // The successful return value, which includes the number of active
+ // receivers, also isn't that interesting to us.
+ let _ = tx.send(message.clone());
+ }
+
+ // panic: if ``channel`` has not been previously registered, and was not
+ // part of the initial set of channels.
+ pub fn subscribe(&self) -> impl Stream<Item = M> + std::fmt::Debug {
+ let rx = self.sender().subscribe();
+
+ BroadcastStream::from(rx).scan((), |(), r| {
+ future::ready(match r {
+ Ok(event) => Some(event),
+ // Stop the stream here. This will disconnect SSE clients
+ // (see `routes.rs`), who will then resume from
+ // `Last-Event-ID`, allowing them to catch up by reading
+ // the skipped messages from the database.
+ //
+ // See also:
+ // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854>
+ Err(BroadcastStreamRecvError::Lagged(_)) => None,
+ })
+ })
+ }
+
+ fn sender(&self) -> Sender<M> {
+ self.senders.lock().unwrap().clone()
+ }
+
+ fn make_sender() -> Sender<M> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}