use std::collections::{hash_map::Entry, HashMap}; use std::sync::{Arc, Mutex}; use futures::{ stream::{self, StreamExt as _, TryStreamExt as _}, Stream, }; use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast::{channel, Sender}; use tokio_stream::wrappers::BroadcastStream; use super::repo::{ channels::{Id as ChannelId, Provider as _}, messages::{Id as MessageId, Message as StoredMessage, Provider as _}, }; use crate::{ clock::DateTime, error::BoxedError, login::repo::logins::{Id as LoginId, Login, Provider as _}, }; pub struct Channels<'a> { db: &'a SqlitePool, broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { Self { db, broadcaster } } pub async fn create(&self, name: &str) -> Result<(), BoxedError> { let mut tx = self.db.begin().await?; let channel = tx.channels().create(name).await?; tx.commit().await?; self.broadcaster.register_channel(&channel)?; Ok(()) } pub async fn send( &self, login: &Login, channel: &ChannelId, body: &str, sent_at: &DateTime, ) -> Result<(), BoxedError> { let mut tx = self.db.begin().await?; let message = tx .messages() .create(&login.id, channel, body, sent_at) .await?; let message = Message::from_login(login, message)?; tx.commit().await?; self.broadcaster.broadcast(channel, message)?; Ok(()) } pub async fn events( &self, channel: &ChannelId, ) -> Result>, BoxedError> { let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from); let db = self.db.clone(); let mut tx = self.db.begin().await?; let stored_messages = tx.messages().all(channel).await?; let stored_messages = stream::iter(stored_messages).then(move |msg| { // The exact series of moves and clones here is the result of trial // and error, and is likely the best I can do, given: // // * This closure _can't_ keep a reference to self, for lifetime // reasons; // * The closure will be executed multiple times, so it can't give // up `db`; and // * The returned future can't keep a reference to `db` as doing // so would allow refs to the closure's `db` to outlive the // closure itself. // // Fortunately, cloning the pool is acceptable - sqlx pools were // designed to be cloned and the only thing actually cloned is a // single `Arc`. This whole chain of clones just ends up producing // cheap handles to a single underlying "real" pool. let db = db.clone(); async move { let mut tx = db.begin().await?; let msg = Message::from_stored(&mut tx, msg).await?; tx.commit().await?; Ok(msg) } }); tx.commit().await?; Ok(stored_messages.chain(live_messages)) } } #[derive(Clone, Debug, serde::Serialize)] pub struct Message { pub id: MessageId, pub sender: Login, pub body: String, pub sent_at: DateTime, } impl Message { async fn from_stored( tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, message: StoredMessage, ) -> Result { let sender = tx.logins().by_id(&message.sender).await?; let message = Self { sender, id: message.id, body: message.body, sent_at: message.sent_at, }; Ok(message) } fn from_login(sender: &Login, message: StoredMessage) -> Result { if sender.id != message.sender { // This functionally can't happen, but the funny thing about "This // can never happen" comments is that they're usually wrong. return Err(MessageError::LoginMismatched { sender: sender.id.clone(), message: message.sender, }); } let message = Self { sender: sender.clone(), id: message.id, body: message.body, sent_at: message.sent_at, }; Ok(message) } } #[derive(Debug, thiserror::Error)] enum MessageError { #[error("sender login id {sender} did not match message login id {message}")] LoginMismatched { sender: LoginId, message: LoginId }, } // Clones will share the same senders collection. #[derive(Clone)] pub struct Broadcaster { // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's // own advice: . Methods that // lock it must be sync. senders: Arc>>>, } impl Broadcaster { pub async fn from_database(db: &SqlitePool) -> Result { let mut tx = db.begin().await?; let channels = tx.channels().all().await?; tx.commit().await?; let channels = channels.iter().map(|c| &c.id); let broadcaster = Broadcaster::new(channels); Ok(broadcaster) } fn new<'i>(channels: impl IntoIterator) -> Self { let senders: HashMap<_, _> = channels .into_iter() .cloned() .map(|id| (id, Self::make_sender())) .collect(); Self { senders: Arc::new(Mutex::new(senders)), } } pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> { match self.senders.lock().unwrap().entry(channel.clone()) { Entry::Occupied(_) => Err(RegisterError::Duplicate), vacant => { vacant.or_insert_with(Self::make_sender); Ok(()) } } } pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> { let lock = self.senders.lock().unwrap(); let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; // Per the Tokio docs, the returned error is only used to indicate that // there are no receivers. In this use case, that's fine; a lack of // listening consumers (chat clients) when a message is sent isn't an // error. let _ = tx.send(message); Ok(()) } pub fn listen(&self, channel: &ChannelId) -> Result, BroadcastError> { let lock = self.senders.lock().unwrap(); let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; let rx = tx.subscribe(); let stream = BroadcastStream::from(rx); Ok(stream) } fn make_sender() -> Sender { // Queue depth of 16 chosen entirely arbitrarily. Don't read too much // into it. let (tx, _) = channel(16); tx } } #[derive(Debug, thiserror::Error)] pub enum RegisterError { #[error("duplicate channel registered")] Duplicate, } #[derive(Debug, thiserror::Error)] pub enum BroadcastError { #[error("requested channel not registered")] Unregistered, }