use std::collections::{hash_map::Entry, HashMap}; use std::sync::{Arc, Mutex, MutexGuard}; use futures::{ future, stream::{self, StreamExt as _}, Stream, }; use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast::{channel, Sender}; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; use super::repo::broadcast::{self, Provider as _}; use crate::{ clock::DateTime, repo::{ channel::{self, Channel, Provider as _}, error::NotFound as _, login::Login, }, }; pub struct Channels<'a> { db: &'a SqlitePool, broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { Self { db, broadcaster } } pub async fn create(&self, name: &str) -> Result<(), InternalError> { let mut tx = self.db.begin().await?; let channel = tx.channels().create(name).await?; self.broadcaster.register_channel(&channel); tx.commit().await?; Ok(()) } pub async fn all(&self) -> Result, InternalError> { let mut tx = self.db.begin().await?; let channels = tx.channels().all().await?; tx.commit().await?; Ok(channels) } pub async fn send( &self, login: &Login, channel: &channel::Id, body: &str, sent_at: &DateTime, ) -> Result<(), EventsError> { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; let message = tx .broadcast() .create(login, &channel, body, sent_at) .await?; tx.commit().await?; self.broadcaster.broadcast(&channel.id, message); Ok(()) } pub async fn events( &self, channel: &channel::Id, resume_at: Option<&DateTime>, ) -> Result + 'static, EventsError> { fn skip_stale( resume_at: Option<&DateTime>, ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready { let resume_at = resume_at.cloned(); move |msg| { future::ready(match resume_at { None => false, Some(resume_at) => msg.sent_at <= resume_at, }) } } let mut tx = self .db .begin() .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; let channel = tx.channels().by_id(channel).await?; let live_messages = self .broadcaster .listen(&channel.id) .skip_while(skip_stale(resume_at)); let stored_messages = tx.broadcast().replay(&channel, resume_at).await?; tx.commit().await?; let stored_messages = stream::iter(stored_messages); Ok(stored_messages.chain(live_messages)) } } #[derive(Debug, thiserror::Error)] pub enum InternalError { #[error(transparent)] DatabaseError(#[from] sqlx::Error), } #[derive(Debug, thiserror::Error)] pub enum EventsError { #[error("channel {0} not found")] ChannelNotFound(channel::Id), #[error(transparent)] DatabaseError(#[from] sqlx::Error), } // Clones will share the same senders collection. #[derive(Clone)] pub struct Broadcaster { // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's // own advice: . Methods that // lock it must be sync. senders: Arc>>>, } impl Broadcaster { pub async fn from_database(db: &SqlitePool) -> Result { let mut tx = db.begin().await?; let channels = tx.channels().all().await?; tx.commit().await?; let channels = channels.iter().map(|c| &c.id); let broadcaster = Self::new(channels); Ok(broadcaster) } fn new<'i>(channels: impl IntoIterator) -> Self { let senders: HashMap<_, _> = channels .into_iter() .cloned() .map(|id| (id, Self::make_sender())) .collect(); Self { senders: Arc::new(Mutex::new(senders)), } } // panic: if ``channel`` is already registered. pub fn register_channel(&self, channel: &channel::Id) { match self.senders().entry(channel.clone()) { // This ever happening indicates a serious logic error. Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"), Entry::Vacant(entry) => { entry.insert(Self::make_sender()); } } } // panic: if ``channel`` has not been previously registered, and was not // part of the initial set of channels. pub fn broadcast(&self, channel: &channel::Id, message: broadcast::Message) { let tx = self.sender(channel); // Per the Tokio docs, the returned error is only used to indicate that // there are no receivers. In this use case, that's fine; a lack of // listening consumers (chat clients) when a message is sent isn't an // error. // // The successful return value, which includes the number of active // receivers, also isn't that interesting to us. let _ = tx.send(message); } // panic: if ``channel`` has not been previously registered, and was not // part of the initial set of channels. pub fn listen(&self, channel: &channel::Id) -> impl Stream { let rx = self.sender(channel).subscribe(); BroadcastStream::from(rx) .take_while(|r| { future::ready(match r { Ok(_) => true, // Stop the stream here. This will disconnect SSE clients // (see `routes.rs`), who will then resume from // `Last-Event-ID`, allowing them to catch up by reading // the skipped messages from the database. Err(BroadcastStreamRecvError::Lagged(_)) => false, }) }) .map(|r| { // Since the previous transform stops at the first error, this // should always hold. // // See also . debug_assert!(r.is_ok()); r.unwrap() }) } // panic: if ``channel`` has not been previously registered, and was not // part of the initial set of channels. fn sender(&self, channel: &channel::Id) -> Sender { self.senders()[channel].clone() } fn senders(&self) -> MutexGuard>> { self.senders.lock().unwrap() // propagate panics when mutex is poisoned } fn make_sender() -> Sender { // Queue depth of 16 chosen entirely arbitrarily. Don't read too much // into it. let (tx, _) = channel(16); tx } }