diff options
| author | Owen Jacobson <owen@grimoire.ca> | 2024-09-13 00:26:03 -0400 |
|---|---|---|
| committer | Owen Jacobson <owen@grimoire.ca> | 2024-09-13 02:42:27 -0400 |
| commit | 067e3da1900d052a416c56e1c047640aa23441ae (patch) | |
| tree | 8baad4240d2532216f2530f5c974479e557c675a /src | |
| parent | 5d76d0712e07040d9aeeebccb189d75636a07c7a (diff) | |
Transmit messages via `/:chan/send` and `/:chan/events`.
Diffstat (limited to 'src')
| -rw-r--r-- | src/app.rs | 16 | ||||
| -rw-r--r-- | src/channel/app.rs | 141 | ||||
| -rw-r--r-- | src/channel/repo/channels.rs (renamed from src/channel/repo.rs) | 11 | ||||
| -rw-r--r-- | src/channel/repo/messages.rs | 111 | ||||
| -rw-r--r-- | src/channel/repo/mod.rs | 2 | ||||
| -rw-r--r-- | src/channel/routes.rs | 56 | ||||
| -rw-r--r-- | src/cli.rs | 3 | ||||
| -rw-r--r-- | src/id.rs | 2 | ||||
| -rw-r--r-- | src/index/app.rs | 2 | ||||
| -rw-r--r-- | src/index/templates.rs | 2 | ||||
| -rw-r--r-- | src/login/repo/logins.rs | 2 |
11 files changed, 324 insertions, 24 deletions
@@ -1,15 +1,23 @@ use sqlx::sqlite::SqlitePool; -use crate::{channel::app::Channels, index::app::Index, login::app::Logins}; +use crate::error::BoxedError; + +use crate::{ + channel::app::{Broadcaster, Channels}, + index::app::Index, + login::app::Logins, +}; #[derive(Clone)] pub struct App { db: SqlitePool, + broadcaster: Broadcaster, } impl App { - pub fn from(db: SqlitePool) -> Self { - Self { db } + pub async fn from(db: SqlitePool) -> Result<Self, BoxedError> { + let broadcaster = Broadcaster::from_database(&db).await?; + Ok(Self { db, broadcaster }) } } @@ -23,6 +31,6 @@ impl App { } pub fn channels(&self) -> Channels { - Channels::new(&self.db) + Channels::new(&self.db, &self.broadcaster) } } diff --git a/src/channel/app.rs b/src/channel/app.rs index 84822cb..7b02300 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,21 +1,152 @@ +use std::collections::{hash_map::Entry, HashMap}; +use std::sync::{Arc, Mutex}; + +use futures::{ + stream::{self, StreamExt as _, TryStreamExt as _}, + Stream, +}; use sqlx::sqlite::SqlitePool; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::BroadcastStream; -use super::repo::Provider as _; -use crate::error::BoxedError; +use super::repo::{ + channels::{Id as ChannelId, Provider as _}, + messages::{Message, Provider as _}, +}; +use crate::{clock::DateTime, error::BoxedError, login::repo::logins::Login}; pub struct Channels<'a> { db: &'a SqlitePool, + broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { - pub fn new(db: &'a SqlitePool) -> Self { - Self { db } + pub fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { + Self { db, broadcaster } } pub async fn create(&self, name: &str) -> Result<(), BoxedError> { let mut tx = self.db.begin().await?; - tx.channels().create(name).await?; + let channel = tx.channels().create(name).await?; + tx.commit().await?; + + self.broadcaster.register_channel(&channel)?; + Ok(()) + } + + pub async fn send( + &self, + login: &Login, + channel: &ChannelId, + body: &str, + sent_at: &DateTime, + ) -> Result<(), BoxedError> { + let mut tx = self.db.begin().await?; + let message = tx + .messages() + .create(&login.id, channel, body, sent_at) + .await?; + tx.commit().await?; + + self.broadcaster.broadcast(channel, message)?; + Ok(()) + } + + pub async fn events( + &self, + channel: &ChannelId, + ) -> Result<impl Stream<Item = Result<Message, BoxedError>>, BoxedError> { + let live_messages = self.broadcaster.listen(channel)?.map_err(BoxedError::from); + + let mut tx = self.db.begin().await?; + let stored_messages = tx.messages().all(channel).await?; + tx.commit().await?; + + let stored_messages = stream::iter(stored_messages.into_iter().map(Ok)); + + Ok(stored_messages.chain(live_messages)) + } +} + +// Clones will share the same senders collection. +#[derive(Clone)] +pub struct Broadcaster { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that + // lock it must be sync. + senders: Arc<Mutex<HashMap<ChannelId, Sender<Message>>>>, +} + +impl Broadcaster { + pub async fn from_database(db: &SqlitePool) -> Result<Broadcaster, BoxedError> { + let mut tx = db.begin().await?; + let channels = tx.channels().all().await?; tx.commit().await?; + + let channels = channels.iter().map(|c| &c.id); + let broadcaster = Broadcaster::new(channels); + Ok(broadcaster) + } + + fn new<'i>(channels: impl IntoIterator<Item = &'i ChannelId>) -> Self { + let senders: HashMap<_, _> = channels + .into_iter() + .cloned() + .map(|id| (id, Self::make_sender())) + .collect(); + + Self { + senders: Arc::new(Mutex::new(senders)), + } + } + + pub fn register_channel(&self, channel: &ChannelId) -> Result<(), RegisterError> { + match self.senders.lock().unwrap().entry(channel.clone()) { + Entry::Occupied(_) => Err(RegisterError::Duplicate), + vacant => { + vacant.or_insert_with(Self::make_sender); + Ok(()) + } + } + } + + pub fn broadcast(&self, channel: &ChannelId, message: Message) -> Result<(), BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + let _ = tx.send(message); Ok(()) } + + pub fn listen(&self, channel: &ChannelId) -> Result<BroadcastStream<Message>, BroadcastError> { + let lock = self.senders.lock().unwrap(); + let tx = lock.get(channel).ok_or(BroadcastError::Unregistered)?; + let rx = tx.subscribe(); + let stream = BroadcastStream::from(rx); + + Ok(stream) + } + + fn make_sender() -> Sender<Message> { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RegisterError { + #[error("duplicate channel registered")] + Duplicate, +} + +#[derive(Debug, thiserror::Error)] +pub enum BroadcastError { + #[error("requested channel not registered")] + Unregistered, } diff --git a/src/channel/repo.rs b/src/channel/repo/channels.rs index a04cac5..6fb0c23 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo/channels.rs @@ -25,22 +25,23 @@ pub struct Channel { impl<'c> Channels<'c> { /// Create a new channel. - pub async fn create(&mut self, name: &str) -> Result<(), BoxedError> { + pub async fn create(&mut self, name: &str) -> Result<Id, BoxedError> { let id = Id::generate(); - sqlx::query!( + let channel = sqlx::query_scalar!( r#" insert into channel (id, name) values ($1, $2) + returning id as "id: Id" "#, id, name, ) - .execute(&mut *self.0) + .fetch_one(&mut *self.0) .await?; - Ok(()) + Ok(channel) } pub async fn all(&mut self) -> Result<Vec<Channel>, BoxedError> { @@ -62,7 +63,7 @@ impl<'c> Channels<'c> { } /// Stable identifier for a [Channel]. Prefixed with `C`. -#[derive(Debug, sqlx::Type, serde::Deserialize)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] #[sqlx(transparent)] #[serde(transparent)] pub struct Id(BaseId); diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs new file mode 100644 index 0000000..bdb0d29 --- /dev/null +++ b/src/channel/repo/messages.rs @@ -0,0 +1,111 @@ +use std::fmt; + +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use super::channels::Id as ChannelId; +use crate::{ + clock::DateTime, error::BoxedError, id::Id as BaseId, login::repo::logins::Id as LoginId, +}; + +pub trait Provider { + fn messages(&mut self) -> Messages; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn messages(&mut self) -> Messages { + Messages(self) + } +} + +pub struct Messages<'t>(&'t mut SqliteConnection); + +#[derive(Clone, Debug, serde::Serialize)] +pub struct Message { + pub id: Id, + pub sender: LoginId, + pub channel: ChannelId, + pub body: String, + pub sent_at: DateTime, +} + +impl<'c> Messages<'c> { + pub async fn create( + &mut self, + sender: &LoginId, + channel: &ChannelId, + body: &str, + sent_at: &DateTime, + ) -> Result<Message, BoxedError> { + let id = Id::generate(); + + let message = sqlx::query_as!( + Message, + r#" + insert into message + (id, sender, channel, body, sent_at) + values ($1, $2, $3, $4, $5) + returning + id as "id: Id", + sender as "sender: LoginId", + channel as "channel: ChannelId", + body, + sent_at as "sent_at: DateTime" + "#, + id, + sender, + channel, + body, + sent_at, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn all(&mut self, channel: &ChannelId) -> Result<Vec<Message>, BoxedError> { + let messages = sqlx::query_as!( + Message, + r#" + select + id as "id: Id", + sender as "sender: LoginId", + channel as "channel: ChannelId", + body, + sent_at as "sent_at: DateTime" + from message + where channel = $1 + order by sent_at asc + "#, + channel, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} + +/// Stable identifier for a [Message]. Prefixed with `M`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("M") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/channel/repo/mod.rs b/src/channel/repo/mod.rs new file mode 100644 index 0000000..345897d --- /dev/null +++ b/src/channel/repo/mod.rs @@ -0,0 +1,2 @@ +pub mod channels; +pub mod messages; diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 864f1b3..83c733c 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -1,14 +1,23 @@ use axum::{ - extract::{Form, State}, - response::{IntoResponse, Redirect}, - routing::post, + extract::{Form, Path, State}, + http::StatusCode, + response::{ + sse::{self, Sse}, + IntoResponse, Redirect, + }, + routing::{get, post}, Router, }; +use futures::stream::{StreamExt as _, TryStreamExt as _}; -use crate::{app::App, error::InternalError, login::repo::logins::Login}; +use super::repo::channels::Id as ChannelId; +use crate::{app::App, clock::RequestedAt, error::InternalError, login::repo::logins::Login}; pub fn router() -> Router<App> { - Router::new().route("/create", post(on_create)) + Router::new() + .route("/create", post(on_create)) + .route("/:channel/send", post(on_send)) + .route("/:channel/events", get(on_events)) } #[derive(serde::Deserialize)] @@ -25,3 +34,40 @@ async fn on_create( Ok(Redirect::to("/")) } + +#[derive(serde::Deserialize)] +struct SendRequest { + message: String, +} + +async fn on_send( + Path(channel): Path<ChannelId>, + RequestedAt(sent_at): RequestedAt, + State(app): State<App>, + login: Login, + Form(form): Form<SendRequest>, +) -> Result<impl IntoResponse, InternalError> { + app.channels() + .send(&login, &channel, &form.message, &sent_at) + .await?; + + Ok(StatusCode::ACCEPTED) +} + +async fn on_events( + Path(channel): Path<ChannelId>, + State(app): State<App>, + _: Login, // requires auth, but doesn't actually care who you are +) -> Result<impl IntoResponse, InternalError> { + let stream = app + .channels() + .events(&channel) + .await? + .map(|msg| match msg { + Ok(msg) => Ok(serde_json::to_string(&msg)?), + Err(err) => Err(err), + }) + .map_ok(|msg| sse::Event::default().data(&msg)); + + Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default())) +} @@ -28,9 +28,10 @@ impl Args { sqlx::migrate!().run(&pool).await?; + let app = App::from(pool).await?; let app = routers() .route_layer(middleware::from_fn(clock::middleware)) - .with_state(App::from(pool)); + .with_state(app); let listener = self.listener().await?; let started_msg = started_msg(&listener)?; @@ -27,7 +27,7 @@ pub const ID_SIZE: usize = 15; // // By convention, the prefix should be UPPERCASE - note that the alphabet for this // is entirely lowercase. -#[derive(Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, sqlx::Type, serde::Deserialize, serde::Serialize)] #[sqlx(transparent)] #[serde(transparent)] pub struct Id(String); diff --git a/src/index/app.rs b/src/index/app.rs index 79f5a9a..d6eef18 100644 --- a/src/index/app.rs +++ b/src/index/app.rs @@ -1,7 +1,7 @@ use sqlx::sqlite::SqlitePool; use crate::{ - channel::repo::{Channel, Provider as _}, + channel::repo::channels::{Channel, Provider as _}, error::BoxedError, }; diff --git a/src/index/templates.rs b/src/index/templates.rs index fdb750b..38cd93f 100644 --- a/src/index/templates.rs +++ b/src/index/templates.rs @@ -1,6 +1,6 @@ use maud::{html, Markup, DOCTYPE}; -use crate::{channel::repo::Channel, login::repo::logins::Login}; +use crate::{channel::repo::channels::Channel, login::repo::logins::Login}; pub fn authenticated<'c>(login: Login, channels: impl IntoIterator<Item = &'c Channel>) -> Markup { html! { diff --git a/src/login/repo/logins.rs b/src/login/repo/logins.rs index 26a5b09..142d8fb 100644 --- a/src/login/repo/logins.rs +++ b/src/login/repo/logins.rs @@ -90,7 +90,7 @@ impl<'c> Logins<'c> { } /// Stable identifier for a [Login]. Prefixed with `L`. -#[derive(Debug, sqlx::Type)] +#[derive(Clone, Debug, sqlx::Type, serde::Serialize)] #[sqlx(transparent)] pub struct Id(BaseId); |
