summaryrefslogtreecommitdiff
path: root/src/channel
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-13 00:26:03 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-13 02:42:27 -0400
commit067e3da1900d052a416c56e1c047640aa23441ae (patch)
tree8baad4240d2532216f2530f5c974479e557c675a /src/channel
parent5d76d0712e07040d9aeeebccb189d75636a07c7a (diff)
Transmit messages via `/:chan/send` and `/:chan/events`.
Diffstat (limited to 'src/channel')
-rw-r--r--src/channel/app.rs141
-rw-r--r--src/channel/repo/channels.rs (renamed from src/channel/repo.rs)11
-rw-r--r--src/channel/repo/messages.rs111
-rw-r--r--src/channel/repo/mod.rs2
-rw-r--r--src/channel/routes.rs56
5 files changed, 306 insertions, 15 deletions
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 84822cb..7b02300 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,21 +1,152 @@
+use std::collections::{hash_map::Entry, HashMap};
+use std::sync::{Arc, Mutex};
+
+use futures::{
+ stream::{self, StreamExt as _, TryStreamExt as _},
+ Stream,
+};
use sqlx::sqlite::SqlitePool;
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::BroadcastStream;
-use super::repo::Provider as _;
-use crate::error::BoxedError;
+use super::repo::{
+ channels::{Id as ChannelId, Provider as _},
+ messages::{Message, Provider as _},
+};
+use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login};
pub struct Channels<'a> {
db: &'a SqlitePool,
+ broadcaster: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub fn new(db: &'a SqlitePool) -> Self {
- Self { db }
+ pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
+ Self { db, broadcaster }
}
pub async fn create(&self, name: &str) -> Result<(), BoxedError> {
let mut tx = self.db.begin().await?;
- tx.channels().create(name).await?;
+ let channel = tx.channels().create(name).await?;
+ tx.commit().await?;
+
+ self.broadcaster.register_channel(&channel)?;
+ Ok(())
+ }
+
+ pub async fn send(
+ &self,
+ login: &Login,
+ channel: &ChannelId,
+ body: &str,
+ sent_at: &DateTime,
+ ) -> Result<(), BoxedError> {
+ let mut tx = self.db.begin().await?;
+ let message = tx
+ .messages()
+ .create(&login.id, channel, body, sent_at)
+ .await?;
+ tx.commit().await?;
+
+ self.broadcaster.broadcast(channel, message)?;
+ Ok(())
+ }
+
+ pub async fn events(
+ &self,
+ channel: &ChannelId,
+ ) -> Result<impl Stream<Item = Result<Message, BoxedError>>, BoxedError> {
+ let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from);
+
+ let mut tx = self.db.begin().await?;
+ let stored_messages = tx.messages().all(channel).await?;
+ tx.commit().await?;
+
+ let stored_messages = stream::iter(stored_messages.into_iter().map(Ok));
+
+ Ok(stored_messages.chain(live_messages))
+ }
+}
+
+// Clones will share the same senders collection.
+#[derive(Clone)]
+pub struct Broadcaster {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<HashMap<ChannelId, Sender<Message>>>>,
+}
+
+impl Broadcaster {
+ pub async fn from_database(db: &SqlitePool) -> Result<Broadcaster, BoxedError> {
+ let mut tx = db.begin().await?;
+ let channels = tx.channels().all().await?;
tx.commit().await?;
+
+ let channels = channels.iter().map(|c| &c.id);
+ let broadcaster = Broadcaster::new(channels);
+ Ok(broadcaster)
+ }
+
+ fn new<'i>(channels: impl IntoIterator<Item = &'i ChannelId>) -> Self {
+ let senders: HashMap<_, _> = channels
+ .into_iter()
+ .cloned()
+ .map(|id| (id, Self::make_sender()))
+ .collect();
+
+ Self {
+ senders: Arc::new(Mutex::new(senders)),
+ }
+ }
+
+ pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> {
+ match self.senders.lock().unwrap().entry(channel.clone()) {
+ Entry::Occupied(_) => Err(RegisterError::Duplicate),
+ vacant => {
+ vacant.or_insert_with(Self::make_sender);
+ Ok(())
+ }
+ }
+ }
+
+ pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ let _ = tx.send(message);
Ok(())
}
+
+ pub fn listen(&self, channel: &ChannelId) -> Result<BroadcastStream<Message>, BroadcastError> {
+ let lock = self.senders.lock().unwrap();
+ let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?;
+ let rx = tx.subscribe();
+ let stream = BroadcastStream::from(rx);
+
+ Ok(stream)
+ }
+
+ fn make_sender() -> Sender<Message> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum RegisterError {
+ #[error("duplicate channel registered")]
+ Duplicate,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum BroadcastError {
+ #[error("requested channel not registered")]
+ Unregistered,
}
diff --git a/src/channel/repo.rs b/src/channel/repo/channels.rs
index a04cac5..6fb0c23 100644
--- a/src/channel/repo.rs
+++ b/src/channel/repo/channels.rs
@@ -25,22 +25,23 @@ pub struct Channel {
impl<'c> Channels<'c> {
/// Create a new channel.
- pub async fn create(&mut self, name: &str) -> Result<(), BoxedError> {
+ pub async fn create(&mut self, name: &str) -> Result<Id, BoxedError> {
let id = Id::generate();
- sqlx::query!(
+ let channel = sqlx::query_scalar!(
r#"
insert
into channel (id, name)
values ($1, $2)
+ returning id as "id: Id"
"#,
id,
name,
)
- .execute(&mut *self.0)
+ .fetch_one(&mut *self.0)
.await?;
- Ok(())
+ Ok(channel)
}
pub async fn all(&mut self) -> Result<Vec<Channel>, BoxedError> {
@@ -62,7 +63,7 @@ impl<'c> Channels<'c> {
}
/// Stable identifier for a [Channel]. Prefixed with `C`.
-#[derive(Debug, sqlx::Type, serde::Deserialize)]
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
#[serde(transparent)]
pub struct Id(BaseId);
diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs
new file mode 100644
index 0000000..bdb0d29
--- /dev/null
+++ b/src/channel/repo/messages.rs
@@ -0,0 +1,111 @@
+use std::fmt;
+
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use super::channels::Id as ChannelId;
+use crate::{
+ clock::DateTime, error::BoxedError, id::Id as BaseId, login::repo::logins::Id as LoginId,
+};
+
+pub trait Provider {
+ fn messages(&mut self) -> Messages;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn messages(&mut self) -> Messages {
+ Messages(self)
+ }
+}
+
+pub struct Messages<'t>(&'t mut SqliteConnection);
+
+#[derive(Clone, Debug, serde::Serialize)]
+pub struct Message {
+ pub id: Id,
+ pub sender: LoginId,
+ pub channel: ChannelId,
+ pub body: String,
+ pub sent_at: DateTime,
+}
+
+impl<'c> Messages<'c> {
+ pub async fn create(
+ &mut self,
+ sender: &LoginId,
+ channel: &ChannelId,
+ body: &str,
+ sent_at: &DateTime,
+ ) -> Result<Message, BoxedError> {
+ let id = Id::generate();
+
+ let message = sqlx::query_as!(
+ Message,
+ r#"
+ insert into message
+ (id, sender, channel, body, sent_at)
+ values ($1, $2, $3, $4, $5)
+ returning
+ id as "id: Id",
+ sender as "sender: LoginId",
+ channel as "channel: ChannelId",
+ body,
+ sent_at as "sent_at: DateTime"
+ "#,
+ id,
+ sender,
+ channel,
+ body,
+ sent_at,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(message)
+ }
+
+ pub async fn all(&mut self, channel: &ChannelId) -> Result<Vec<Message>, BoxedError> {
+ let messages = sqlx::query_as!(
+ Message,
+ r#"
+ select
+ id as "id: Id",
+ sender as "sender: LoginId",
+ channel as "channel: ChannelId",
+ body,
+ sent_at as "sent_at: DateTime"
+ from message
+ where channel = $1
+ order by sent_at asc
+ "#,
+ channel,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(messages)
+ }
+}
+
+/// Stable identifier for a [Message]. Prefixed with `M`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("M")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/channel/repo/mod.rs b/src/channel/repo/mod.rs
new file mode 100644
index 0000000..345897d
--- /dev/null
+++ b/src/channel/repo/mod.rs
@@ -0,0 +1,2 @@
+pub mod channels;
+pub mod messages;
diff --git a/src/channel/routes.rs b/src/channel/routes.rs
index 864f1b3..83c733c 100644
--- a/src/channel/routes.rs
+++ b/src/channel/routes.rs
@@ -1,14 +1,23 @@
use axum::{
- extract::{Form, State},
- response::{IntoResponse, Redirect},
- routing::post,
+ extract::{Form, Path, State},
+ http::StatusCode,
+ response::{
+ sse::{self, Sse},
+ IntoResponse, Redirect,
+ },
+ routing::{get, post},
Router,
};
+use futures::stream::{StreamExt as _, TryStreamExt as _};
-use crate::{app::App, error::InternalError, login::repo::logins::Login};
+use super::repo::channels::Id as ChannelId;
+use crate::{app::App, clock::RequestedAt, error::InternalError, login::repo::logins::Login};
pub fn router() -> Router<App> {
- Router::new().route("/create", post(on_create))
+ Router::new()
+ .route("/create", post(on_create))
+ .route("/:channel/send", post(on_send))
+ .route("/:channel/events", get(on_events))
}
#[derive(serde::Deserialize)]
@@ -25,3 +34,40 @@ async fn on_create(
Ok(Redirect::to("/"))
}
+
+#[derive(serde::Deserialize)]
+struct SendRequest {
+ message: String,
+}
+
+async fn on_send(
+ Path(channel): Path<ChannelId>,
+ RequestedAt(sent_at): RequestedAt,
+ State(app): State<App>,
+ login: Login,
+ Form(form): Form<SendRequest>,
+) -> Result<impl IntoResponse, InternalError> {
+ app.channels()
+ .send(&login, &channel, &form.message, &sent_at)
+ .await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
+
+async fn on_events(
+ Path(channel): Path<ChannelId>,
+ State(app): State<App>,
+ _: Login, // requires auth, but doesn't actually care who you are
+) -> Result<impl IntoResponse, InternalError> {
+ let stream = app
+ .channels()
+ .events(&channel)
+ .await?
+ .map(|msg| match msg {
+ Ok(msg) => Ok(serde_json::to_string(&msg)?),
+ Err(err) => Err(err),
+ })
+ .map_ok(|msg| sse::Event::default().data(&msg));
+
+ Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default()))
+}