diff options
Diffstat (limited to 'src/channel/app.rs')
| -rw-r--r-- | src/channel/app.rs | 141 |
1 files changed, 136 insertions, 5 deletions
diff --git a/src/channel/app.rs b/src/channel/app.rs index 84822cb..7b02300 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,21 +1,152 @@ +use std::collections::{hash_map::Entry, HashMap}; +use std::sync::{Arc, Mutex}; + +use futures::{ + stream::{self, StreamExt as _, TryStreamExt as _}, + Stream, +}; use sqlx::sqlite::SqlitePool; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::BroadcastStream; -use super::repo::Provider as _; -use crate::error::BoxedError; +use super::repo::{ + channels::{Id as ChannelId, Provider as _}, + messages::{Message, Provider as _}, +}; +use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login}; pub struct Channels<'a> { db: &'a SqlitePool, + broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { - pub fn new(db: &'a SqlitePool) -> Self { - Self { db } + pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { + Self { db, broadcaster } } pub async fn create(&self, name: &str) -> Result<(), BoxedError> { let mut tx = self.db.begin().await?; - tx.channels().create(name).await?; + let channel = tx.channels().create(name).await?; + tx.commit().await?; + + self.broadcaster.register_channel(&channel)?; + Ok(()) + } + + pub async fn send( + &self, + login: &Login, + channel: &ChannelId, + body: &str, + sent_at: &DateTime, + ) -> Result<(), BoxedError> { + let mut tx = self.db.begin().await?; + let message = tx + .messages() + .create(&login.id, channel, body, sent_at) + .await?; + tx.commit().await?; + + self.broadcaster.broadcast(channel, message)?; + Ok(()) + } + + pub async fn events( + &self, + channel: &ChannelId, + ) -> Result<impl Stream<Item = Result<Message, BoxedError>>, BoxedError> { + let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from); + + let mut tx = self.db.begin().await?; + let stored_messages = tx.messages().all(channel).await?; + tx.commit().await?; + + let stored_messages = stream::iter(stored_messages.into_iter().map(Ok)); + + Ok(stored_messages.chain(live_messages)) + } +} + +// Clones will share the same senders collection. +#[derive(Clone)] +pub struct Broadcaster { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that + // lock it must be sync. + senders: Arc<Mutex<HashMap<ChannelId, Sender<Message>>>>, +} + +impl Broadcaster { + pub async fn from_database(db: &SqlitePool) -> Result<Broadcaster, BoxedError> { + let mut tx = db.begin().await?; + let channels = tx.channels().all().await?; tx.commit().await?; + + let channels = channels.iter().map(|c| &c.id); + let broadcaster = Broadcaster::new(channels); + Ok(broadcaster) + } + + fn new<'i>(channels: impl IntoIterator<Item = &'i ChannelId>) -> Self { + let senders: HashMap<_, _> = channels + .into_iter() + .cloned() + .map(|id| (id, Self::make_sender())) + .collect(); + + Self { + senders: Arc::new(Mutex::new(senders)), + } + } + + pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> { + match self.senders.lock().unwrap().entry(channel.clone()) { + Entry::Occupied(_) => Err(RegisterError::Duplicate), + vacant => { + vacant.or_insert_with(Self::make_sender); + Ok(()) + } + } + } + + pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + let _ = tx.send(message); Ok(()) } + + pub fn listen(&self, channel: &ChannelId) -> Result<BroadcastStream<Message>, BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + let rx = tx.subscribe(); + let stream = BroadcastStream::from(rx); + + Ok(stream) + } + + fn make_sender() -> Sender<Message> { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RegisterError { + #[error("duplicate channel registered")] + Duplicate, +} + +#[derive(Debug, thiserror::Error)] +pub enum BroadcastError { + #[error("requested channel not registered")] + Unregistered, } |
