summaryrefslogtreecommitdiff
path: root/src/channel/app.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/channel/app.rs')
-rw-r--r--src/channel/app.rs141
1 files changed, 136 insertions, 5 deletions
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 84822cb..7b02300 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,21 +1,152 @@
+use std::collections::{hash_map::Entry, HashMap};
+use std::sync::{Arc, Mutex};
+
+use futures::{
+ stream::{self, StreamExt as _, TryStreamExt as _},
+ Stream,
+};
use sqlx::sqlite::SqlitePool;
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::BroadcastStream;
-use super::repo::Provider as _;
-use crate::error::BoxedError;
+use super::repo::{
+ channels::{Id as ChannelId, Provider as _},
+ messages::{Message, Provider as _},
+};
+use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login};
pub struct Channels<'a> {
db: &'a SqlitePool,
+ broadcaster: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub fn new(db: &'a SqlitePool) -> Self {
- Self { db }
+ pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
+ Self { db, broadcaster }
}
pub async fn create(&self, name: &str) -> Result<(), BoxedError> {
let mut tx = self.db.begin().await?;
- tx.channels().create(name).await?;
+ let channel = tx.channels().create(name).await?;
+ tx.commit().await?;
+
+ self.broadcaster.register_channel(&channel)?;
+ Ok(())
+ }
+
+ pub async fn send(
+ &self,
+ login: &Login,
+ channel: &ChannelId,
+ body: &str,
+ sent_at: &DateTime,
+ ) -> Result<(), BoxedError> {
+ let mut tx = self.db.begin().await?;
+ let message = tx
+ .messages()
+ .create(&login.id, channel, body, sent_at)
+ .await?;
+ tx.commit().await?;
+
+ self.broadcaster.broadcast(channel, message)?;
+ Ok(())
+ }
+
+ pub async fn events(
+ &self,
+ channel: &ChannelId,
+ ) -> Result<impl Stream<Item = Result<Message, BoxedError>>, BoxedError> {
+ let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from);
+
+ let mut tx = self.db.begin().await?;
+ let stored_messages = tx.messages().all(channel).await?;
+ tx.commit().await?;
+
+ let stored_messages = stream::iter(stored_messages.into_iter().map(Ok));
+
+ Ok(stored_messages.chain(live_messages))
+ }
+}
+
+// Clones will share the same senders collection.
+#[derive(Clone)]
+pub struct Broadcaster {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<HashMap<ChannelId, Sender<Message>>>>,
+}
+
+impl Broadcaster {
+ pub async fn from_database(db: &SqlitePool) -> Result<Broadcaster, BoxedError> {
+ let mut tx = db.begin().await?;
+ let channels = tx.channels().all().await?;
tx.commit().await?;
+
+ let channels = channels.iter().map(|c| &c.id);
+ let broadcaster = Broadcaster::new(channels);
+ Ok(broadcaster)
+ }
+
+ fn new<'i>(channels: impl IntoIterator<Item = &'i ChannelId>) -> Self {
+ let senders: HashMap<_, _> = channels
+ .into_iter()
+ .cloned()
+ .map(|id| (id, Self::make_sender()))
+ .collect();
+
+ Self {
+ senders: Arc::new(Mutex::new(senders)),
+ }
+ }
+
+ pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> {
+ match self.senders.lock().unwrap().entry(channel.clone()) {
+ Entry::Occupied(_) => Err(RegisterError::Duplicate),
+ vacant => {
+ vacant.or_insert_with(Self::make_sender);
+ Ok(())
+ }
+ }
+ }
+
+ pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ let _ = tx.send(message);
Ok(())
}
+
+ pub fn listen(&self, channel: &ChannelId) -> Result<BroadcastStream<Message>, BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+ let rx = tx.subscribe();
+ let stream = BroadcastStream::from(rx);
+
+ Ok(stream)
+ }
+
+ fn make_sender() -> Sender<Message> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum RegisterError {
+ #[error("duplicate channel registered")]
+ Duplicate,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum BroadcastError {
+ #[error("requested channel not registered")]
+ Unregistered,
}