From 067e3da1900d052a416c56e1c047640aa23441ae Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Fri, 13 Sep 2024 00:26:03 -0400 Subject: Transmit messages via `/:chan/send` and `/:chan/events`. --- src/app.rs | 16 +++-- src/channel/app.rs | 141 +++++++++++++++++++++++++++++++++++++++++-- src/channel/repo.rs | 86 -------------------------- src/channel/repo/channels.rs | 87 ++++++++++++++++++++++++++ src/channel/repo/messages.rs | 111 ++++++++++++++++++++++++++++++++++ src/channel/repo/mod.rs | 2 + src/channel/routes.rs | 56 +++++++++++++++-- src/cli.rs | 3 +- src/id.rs | 2 +- src/index/app.rs | 2 +- src/index/templates.rs | 2 +- src/login/repo/logins.rs | 2 +- 12 files changed, 405 insertions(+), 105 deletions(-) delete mode 100644 src/channel/repo.rs create mode 100644 src/channel/repo/channels.rs create mode 100644 src/channel/repo/messages.rs create mode 100644 src/channel/repo/mod.rs (limited to 'src') diff --git a/src/app.rs b/src/app.rs index 4195fdc..f349fd4 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,15 +1,23 @@ use sqlx::sqlite::SqlitePool; -use crate::{channel::app::Channels, index::app::Index, login::app::Logins}; +use crate::error::BoxedError; + +use crate::{ + channel::app::{Broadcaster, Channels}, + index::app::Index, + login::app::Logins, +}; #[derive(Clone)] pub struct App { db: SqlitePool, + broadcaster: Broadcaster, } impl App { - pub fn from(db: SqlitePool) -> Self { - Self { db } + pub async fn from(db: SqlitePool) -> Result { + let broadcaster = Broadcaster::from_database(&db).await?; + Ok(Self { db, broadcaster }) } } @@ -23,6 +31,6 @@ impl App { } pub fn channels(&self) -> Channels { - Channels::new(&self.db) + Channels::new(&self.db, &self.broadcaster) } } diff --git a/src/channel/app.rs b/src/channel/app.rs index 84822cb..7b02300 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,21 +1,152 @@ +use std::collections::{hash_map::Entry, HashMap}; +use std::sync::{Arc, Mutex}; + +use futures::{ + stream::{self, StreamExt as _, TryStreamExt as _}, + Stream, +}; use sqlx::sqlite::SqlitePool; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::BroadcastStream; -use super::repo::Provider as _; -use crate::error::BoxedError; +use super::repo::{ + channels::{Id as ChannelId, Provider as _}, + messages::{Message, Provider as _}, +}; +use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login}; pub struct Channels<'a> { db: &'a SqlitePool, + broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { - pub fn new(db: &'a SqlitePool) -> Self { - Self { db } + pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { + Self { db, broadcaster } } pub async fn create(&self, name: &str) -> Result<(), BoxedError> { let mut tx = self.db.begin().await?; - tx.channels().create(name).await?; + let channel = tx.channels().create(name).await?; + tx.commit().await?; + + self.broadcaster.register_channel(&channel)?; + Ok(()) + } + + pub async fn send( + &self, + login: &Login, + channel: &ChannelId, + body: &str, + sent_at: &DateTime, + ) -> Result<(), BoxedError> { + let mut tx = self.db.begin().await?; + let message = tx + .messages() + .create(&login.id, channel, body, sent_at) + .await?; + tx.commit().await?; + + self.broadcaster.broadcast(channel, message)?; + Ok(()) + } + + pub async fn events( + &self, + channel: &ChannelId, + ) -> Result>, BoxedError> { + let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from); + + let mut tx = self.db.begin().await?; + let stored_messages = tx.messages().all(channel).await?; + tx.commit().await?; + + let stored_messages = stream::iter(stored_messages.into_iter().map(Ok)); + + Ok(stored_messages.chain(live_messages)) + } +} + +// Clones will share the same senders collection. +#[derive(Clone)] +pub struct Broadcaster { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: . Methods that + // lock it must be sync. + senders: Arc>>>, +} + +impl Broadcaster { + pub async fn from_database(db: &SqlitePool) -> Result { + let mut tx = db.begin().await?; + let channels = tx.channels().all().await?; tx.commit().await?; + + let channels = channels.iter().map(|c| &c.id); + let broadcaster = Broadcaster::new(channels); + Ok(broadcaster) + } + + fn new<'i>(channels: impl IntoIterator) -> Self { + let senders: HashMap<_, _> = channels + .into_iter() + .cloned() + .map(|id| (id, Self::make_sender())) + .collect(); + + Self { + senders: Arc::new(Mutex::new(senders)), + } + } + + pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> { + match self.senders.lock().unwrap().entry(channel.clone()) { + Entry::Occupied(_) => Err(RegisterError::Duplicate), + vacant => { + vacant.or_insert_with(Self::make_sender); + Ok(()) + } + } + } + + pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + let _ = tx.send(message); Ok(()) } + + pub fn listen(&self, channel: &ChannelId) -> Result, BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + let rx = tx.subscribe(); + let stream = BroadcastStream::from(rx); + + Ok(stream) + } + + fn make_sender() -> Sender { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RegisterError { + #[error("duplicate channel registered")] + Duplicate, +} + +#[derive(Debug, thiserror::Error)] +pub enum BroadcastError { + #[error("requested channel not registered")] + Unregistered, } diff --git a/src/channel/repo.rs b/src/channel/repo.rs deleted file mode 100644 index a04cac5..0000000 --- a/src/channel/repo.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::fmt; - -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::error::BoxedError; -use crate::id::Id as BaseId; - -pub trait Provider { - fn channels(&mut self) -> Channels; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn channels(&mut self) -> Channels { - Channels(self) - } -} - -pub struct Channels<'t>(&'t mut SqliteConnection); - -#[derive(Debug)] -pub struct Channel { - pub id: Id, - pub name: String, -} - -impl<'c> Channels<'c> { - /// Create a new channel. - pub async fn create(&mut self, name: &str) -> Result<(), BoxedError> { - let id = Id::generate(); - - sqlx::query!( - r#" - insert - into channel (id, name) - values ($1, $2) - "#, - id, - name, - ) - .execute(&mut *self.0) - .await?; - - Ok(()) - } - - pub async fn all(&mut self) -> Result, BoxedError> { - let channels = sqlx::query_as!( - Channel, - r#" - select - channel.id as "id: Id", - channel.name - from channel - order by channel.name - "#, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } -} - -/// Stable identifier for a [Channel]. Prefixed with `C`. -#[derive(Debug, sqlx::Type, serde::Deserialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("C") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/channel/repo/channels.rs b/src/channel/repo/channels.rs new file mode 100644 index 0000000..6fb0c23 --- /dev/null +++ b/src/channel/repo/channels.rs @@ -0,0 +1,87 @@ +use std::fmt; + +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::error::BoxedError; +use crate::id::Id as BaseId; + +pub trait Provider { + fn channels(&mut self) -> Channels; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn channels(&mut self) -> Channels { + Channels(self) + } +} + +pub struct Channels<'t>(&'t mut SqliteConnection); + +#[derive(Debug)] +pub struct Channel { + pub id: Id, + pub name: String, +} + +impl<'c> Channels<'c> { + /// Create a new channel. + pub async fn create(&mut self, name: &str) -> Result { + let id = Id::generate(); + + let channel = sqlx::query_scalar!( + r#" + insert + into channel (id, name) + values ($1, $2) + returning id as "id: Id" + "#, + id, + name, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn all(&mut self) -> Result, BoxedError> { + let channels = sqlx::query_as!( + Channel, + r#" + select + channel.id as "id: Id", + channel.name + from channel + order by channel.name + "#, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } +} + +/// Stable identifier for a [Channel]. Prefixed with `C`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("C") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs new file mode 100644 index 0000000..bdb0d29 --- /dev/null +++ b/src/channel/repo/messages.rs @@ -0,0 +1,111 @@ +use std::fmt; + +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use super::channels::Id as ChannelId; +use crate::{ + clock::DateTime, error::BoxedError, id::Id as BaseId, login::repo::logins::Id as LoginId, +}; + +pub trait Provider { + fn messages(&mut self) -> Messages; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn messages(&mut self) -> Messages { + Messages(self) + } +} + +pub struct Messages<'t>(&'t mut SqliteConnection); + +#[derive(Clone, Debug, serde::Serialize)] +pub struct Message { + pub id: Id, + pub sender: LoginId, + pub channel: ChannelId, + pub body: String, + pub sent_at: DateTime, +} + +impl<'c> Messages<'c> { + pub async fn create( + &mut self, + sender: &LoginId, + channel: &ChannelId, + body: &str, + sent_at: &DateTime, + ) -> Result { + let id = Id::generate(); + + let message = sqlx::query_as!( + Message, + r#" + insert into message + (id, sender, channel, body, sent_at) + values ($1, $2, $3, $4, $5) + returning + id as "id: Id", + sender as "sender: LoginId", + channel as "channel: ChannelId", + body, + sent_at as "sent_at: DateTime" + "#, + id, + sender, + channel, + body, + sent_at, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn all(&mut self, channel: &ChannelId) -> Result, BoxedError> { + let messages = sqlx::query_as!( + Message, + r#" + select + id as "id: Id", + sender as "sender: LoginId", + channel as "channel: ChannelId", + body, + sent_at as "sent_at: DateTime" + from message + where channel = $1 + order by sent_at asc + "#, + channel, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} + +/// Stable identifier for a [Message]. Prefixed with `M`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("M") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/channel/repo/mod.rs b/src/channel/repo/mod.rs new file mode 100644 index 0000000..345897d --- /dev/null +++ b/src/channel/repo/mod.rs @@ -0,0 +1,2 @@ +pub mod channels; +pub mod messages; diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 864f1b3..83c733c 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -1,14 +1,23 @@ use axum::{ - extract::{Form, State}, - response::{IntoResponse, Redirect}, - routing::post, + extract::{Form, Path, State}, + http::StatusCode, + response::{ + sse::{self, Sse}, + IntoResponse, Redirect, + }, + routing::{get, post}, Router, }; +use futures::stream::{StreamExt as _, TryStreamExt as _}; -use crate::{app::App, error::InternalError, login::repo::logins::Login}; +use super::repo::channels::Id as ChannelId; +use crate::{app::App, clock::RequestedAt, error::InternalError, login::repo::logins::Login}; pub fn router() -> Router { - Router::new().route("/create", post(on_create)) + Router::new() + .route("/create", post(on_create)) + .route("/:channel/send", post(on_send)) + .route("/:channel/events", get(on_events)) } #[derive(serde::Deserialize)] @@ -25,3 +34,40 @@ async fn on_create( Ok(Redirect::to("/")) } + +#[derive(serde::Deserialize)] +struct SendRequest { + message: String, +} + +async fn on_send( + Path(channel): Path, + RequestedAt(sent_at): RequestedAt, + State(app): State, + login: Login, + Form(form): Form, +) -> Result { + app.channels() + .send(&login, &channel, &form.message, &sent_at) + .await?; + + Ok(StatusCode::ACCEPTED) +} + +async fn on_events( + Path(channel): Path, + State(app): State, + _: Login, // requires auth, but doesn't actually care who you are +) -> Result { + let stream = app + .channels() + .events(&channel) + .await? + .map(|msg| match msg { + Ok(msg) => Ok(serde_json::to_string(&msg)?), + Err(err) => Err(err), + }) + .map_ok(|msg| sse::Event::default().data(&msg)); + + Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default())) +} diff --git a/src/cli.rs b/src/cli.rs index e374834..fa7c499 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -28,9 +28,10 @@ impl Args { sqlx::migrate!().run(&pool).await?; + let app = App::from(pool).await?; let app = routers() .route_layer(middleware::from_fn(clock::middleware)) - .with_state(App::from(pool)); + .with_state(app); let listener = self.listener().await?; let started_msg = started_msg(&listener)?; diff --git a/src/id.rs b/src/id.rs index 4e12f2a..c69b341 100644 --- a/src/id.rs +++ b/src/id.rs @@ -27,7 +27,7 @@ pub const ID_SIZE: usize = 15; // // By convention, the prefix should be UPPERCASE - note that the alphabet for this // is entirely lowercase. -#[derive(Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize, serde::Serialize)] #[sqlx(transparent)] #[serde(transparent)] pub struct Id(String); diff --git a/src/index/app.rs b/src/index/app.rs index 79f5a9a..d6eef18 100644 --- a/src/index/app.rs +++ b/src/index/app.rs @@ -1,7 +1,7 @@ use sqlx::sqlite::SqlitePool; use crate::{ - channel::repo::{Channel, Provider as _}, + channel::repo::channels::{Channel, Provider as _}, error::BoxedError, }; diff --git a/src/index/templates.rs b/src/index/templates.rs index fdb750b..38cd93f 100644 --- a/src/index/templates.rs +++ b/src/index/templates.rs @@ -1,6 +1,6 @@ use maud::{html, Markup, DOCTYPE}; -use crate::{channel::repo::Channel, login::repo::logins::Login}; +use crate::{channel::repo::channels::Channel, login::repo::logins::Login}; pub fn authenticated<'c>(login: Login, channels: impl IntoIterator) -> Markup { html! { diff --git a/src/login/repo/logins.rs b/src/login/repo/logins.rs index 26a5b09..142d8fb 100644 --- a/src/login/repo/logins.rs +++ b/src/login/repo/logins.rs @@ -90,7 +90,7 @@ impl<'c> Logins<'c> { } /// Stable identifier for a [Login]. Prefixed with `L`. -#[derive(Debug, sqlx::Type)] +#[derive(Clone, Debug, sqlx::Type, serde::Serialize)] #[sqlx(transparent)] pub struct Id(BaseId); -- cgit v1.2.3