From 0a05491930fb34ce7c93c33ea0b7599360483fc7 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Fri, 20 Sep 2024 23:01:18 -0400 Subject: Push events into a module structure consistent with the rest of the project. --- src/app.rs | 5 +- src/channel/app.rs | 111 ++--------------------------------- src/channel/mod.rs | 1 - src/channel/repo/broadcast.rs | 121 -------------------------------------- src/channel/repo/mod.rs | 1 - src/events.rs | 132 ----------------------------------------- src/events/app.rs | 111 +++++++++++++++++++++++++++++++++++ src/events/mod.rs | 5 ++ src/events/repo/broadcast.rs | 121 ++++++++++++++++++++++++++++++++++++++ src/events/repo/mod.rs | 1 + src/events/routes.rs | 133 ++++++++++++++++++++++++++++++++++++++++++ 11 files changed, 376 insertions(+), 366 deletions(-) delete mode 100644 src/channel/repo/broadcast.rs delete mode 100644 src/channel/repo/mod.rs delete mode 100644 src/events.rs create mode 100644 src/events/app.rs create mode 100644 src/events/mod.rs create mode 100644 src/events/repo/broadcast.rs create mode 100644 src/events/repo/mod.rs create mode 100644 src/events/routes.rs (limited to 'src') diff --git a/src/app.rs b/src/app.rs index 0823a0c..e448436 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,9 +1,6 @@ use sqlx::sqlite::SqlitePool; -use crate::{ - channel::app::{Broadcaster, Channels}, - login::app::Logins, -}; +use crate::{channel::app::Channels, events::app::Broadcaster, login::app::Logins}; #[derive(Clone)] pub struct App { diff --git a/src/channel/app.rs b/src/channel/app.rs index 8ae0c3c..f9a75d7 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,6 +1,3 @@ -use std::collections::{hash_map::Entry, HashMap}; -use std::sync::{Arc, Mutex, MutexGuard}; - use chrono::TimeDelta; use futures::{ future, @@ -8,12 +5,13 @@ use futures::{ Stream, }; use sqlx::sqlite::SqlitePool; -use tokio::sync::broadcast::{channel, Sender}; -use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use super::repo::broadcast::{self, Provider as _}; use crate::{ clock::DateTime, + events::{ + app::Broadcaster, + repo::broadcast::{self, Provider as _}, + }, repo::{ channel::{self, Channel, Provider as _}, error::NotFound as _, @@ -158,104 +156,3 @@ pub enum EventsError { #[error(transparent)] DatabaseError(#[from] sqlx::Error), } - -// Clones will share the same senders collection. -#[derive(Clone)] -pub struct Broadcaster { - // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's - // own advice: . Methods that - // lock it must be sync. - senders: Arc>>>, -} - -impl Broadcaster { - pub async fn from_database(db: &SqlitePool) -> Result { - let mut tx = db.begin().await?; - let channels = tx.channels().all().await?; - tx.commit().await?; - - let channels = channels.iter().map(|c| &c.id); - let broadcaster = Self::new(channels); - Ok(broadcaster) - } - - fn new<'i>(channels: impl IntoIterator) -> Self { - let senders: HashMap<_, _> = channels - .into_iter() - .cloned() - .map(|id| (id, Self::make_sender())) - .collect(); - - Self { - senders: Arc::new(Mutex::new(senders)), - } - } - - // panic: if ``channel`` is already registered. - pub fn register_channel(&self, channel: &channel::Id) { - match self.senders().entry(channel.clone()) { - // This ever happening indicates a serious logic error. - Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"), - Entry::Vacant(entry) => { - entry.insert(Self::make_sender()); - } - } - } - - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn broadcast(&self, channel: &channel::Id, message: broadcast::Message) { - let tx = self.sender(channel); - - // Per the Tokio docs, the returned error is only used to indicate that - // there are no receivers. In this use case, that's fine; a lack of - // listening consumers (chat clients) when a message is sent isn't an - // error. - // - // The successful return value, which includes the number of active - // receivers, also isn't that interesting to us. - let _ = tx.send(message); - } - - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn listen(&self, channel: &channel::Id) -> impl Stream { - let rx = self.sender(channel).subscribe(); - - BroadcastStream::from(rx) - .take_while(|r| { - future::ready(match r { - Ok(_) => true, - // Stop the stream here. This will disconnect SSE clients - // (see `routes.rs`), who will then resume from - // `Last-Event-ID`, allowing them to catch up by reading - // the skipped messages from the database. - Err(BroadcastStreamRecvError::Lagged(_)) => false, - }) - }) - .map(|r| { - // Since the previous transform stops at the first error, this - // should always hold. - // - // See also . - r.expect("after filtering, only `Ok` messages should remain") - }) - } - - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - fn sender(&self, channel: &channel::Id) -> Sender { - self.senders()[channel].clone() - } - - fn senders(&self) -> MutexGuard>> { - self.senders.lock().unwrap() // propagate panics when mutex is poisoned - } - - fn make_sender() -> Sender { - // Queue depth of 16 chosen entirely arbitrarily. Don't read too much - // into it. - let (tx, _) = channel(16); - tx - } -} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index f67ea04..9f79dbb 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,5 +1,4 @@ pub mod app; -pub mod repo; mod routes; pub use self::routes::router; diff --git a/src/channel/repo/broadcast.rs b/src/channel/repo/broadcast.rs deleted file mode 100644 index 182203a..0000000 --- a/src/channel/repo/broadcast.rs +++ /dev/null @@ -1,121 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - repo::{ - channel::Channel, - login::{self, Login}, - message, - }, -}; - -pub trait Provider { - fn broadcast(&mut self) -> Broadcast; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn broadcast(&mut self) -> Broadcast { - Broadcast(self) - } -} - -pub struct Broadcast<'t>(&'t mut SqliteConnection); - -#[derive(Clone, Debug, serde::Serialize)] -pub struct Message { - pub id: message::Id, - pub sender: Login, - pub body: String, - pub sent_at: DateTime, -} - -impl<'c> Broadcast<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - body: &str, - sent_at: &DateTime, - ) -> Result { - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, sender, channel, body, sent_at) - values ($1, $2, $3, $4, $5) - returning - id as "id: message::Id", - sender as "sender: login::Id", - body, - sent_at as "sent_at: DateTime" - "#, - id, - sender.id, - channel.id, - body, - sent_at, - ) - .map(|row| Message { - id: row.id, - sender: sender.clone(), - body: row.body, - sent_at: row.sent_at, - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - delete from message - where sent_at < $1 - "#, - expire_at, - ) - .execute(&mut *self.0) - .await?; - - Ok(()) - } - - pub async fn replay( - &mut self, - channel: &Channel, - resume_at: Option<&DateTime>, - ) -> Result, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - message.id as "id: message::Id", - login.id as "sender_id: login::Id", - login.name as sender_name, - message.body, - message.sent_at as "sent_at: DateTime" - from message - join login on message.sender = login.id - where channel = $1 - and coalesce(sent_at > $2, true) - order by sent_at asc - "#, - channel.id, - resume_at, - ) - .map(|row| Message { - id: row.id, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - body: row.body, - sent_at: row.sent_at, - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } -} diff --git a/src/channel/repo/mod.rs b/src/channel/repo/mod.rs deleted file mode 100644 index 2ed3062..0000000 --- a/src/channel/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod broadcast; diff --git a/src/events.rs b/src/events.rs deleted file mode 100644 index 9cbb0a3..0000000 --- a/src/events.rs +++ /dev/null @@ -1,132 +0,0 @@ -use axum::{ - extract::State, - http::StatusCode, - response::{ - sse::{self, Sse}, - IntoResponse, Response, - }, - routing::get, - Router, -}; -use axum_extra::extract::Query; -use chrono::{self, format::SecondsFormat, DateTime}; -use futures::stream::{self, Stream, StreamExt as _, TryStreamExt as _}; - -use crate::{ - app::App, - channel::{app::EventsError, repo::broadcast}, - clock::RequestedAt, - error::InternalError, - header::LastEventId, - repo::{channel, login::Login}, -}; - -pub fn router() -> Router { - Router::new().route("/api/events", get(on_events)) -} - -#[derive(serde::Deserialize)] -struct EventsQuery { - #[serde(default, rename = "channel")] - channels: Vec, -} - -async fn on_events( - State(app): State, - RequestedAt(now): RequestedAt, - _: Login, // requires auth, but doesn't actually care who you are - last_event_id: Option, - Query(query): Query, -) -> Result>>, ErrorResponse> { - let resume_at = last_event_id - .map(|LastEventId(header)| header) - .map(|header| DateTime::parse_from_rfc3339(&header)) - .transpose() - // impl From would take more code; this is used once. - .map_err(ErrorResponse::LastEventIdError)? - .map(|ts| ts.to_utc()); - - let streams = stream::iter(query.channels) - .then(|channel| { - let app = app.clone(); - async move { - let events = app - .channels() - .events(&channel, &now, resume_at.as_ref()) - .await? - .map(ChannelEvent::wrap(channel)); - - Ok::<_, EventsError>(events) - } - }) - .try_collect::>() - .await - // impl From would take more code; this is used once. - .map_err(ErrorResponse::EventsError)?; - - let stream = stream::select_all(streams); - - Ok(Events(stream)) -} - -struct Events(S); - -impl IntoResponse for Events -where - S: Stream> + Send + 'static, -{ - fn into_response(self) -> Response { - let Self(stream) = self; - let stream = stream.map(to_sse_event); - Sse::new(stream) - .keep_alive(sse::KeepAlive::default()) - .into_response() - } -} - -enum ErrorResponse { - EventsError(EventsError), - LastEventIdError(chrono::ParseError), -} - -impl IntoResponse for ErrorResponse { - fn into_response(self) -> Response { - match self { - Self::EventsError(not_found @ EventsError::ChannelNotFound(_)) => { - (StatusCode::NOT_FOUND, not_found.to_string()).into_response() - } - Self::EventsError(other) => InternalError::from(other).into_response(), - Self::LastEventIdError(other) => { - (StatusCode::BAD_REQUEST, other.to_string()).into_response() - } - } - } -} - -fn to_sse_event(event: ChannelEvent) -> Result { - let data = serde_json::to_string_pretty(&event)?; - let event = sse::Event::default() - .id(event - .message - .sent_at - .to_rfc3339_opts(SecondsFormat::AutoSi, /* use_z */ true)) - .data(&data); - - Ok(event) -} - -#[derive(serde::Serialize)] -struct ChannelEvent { - channel: channel::Id, - #[serde(flatten)] - message: M, -} - -impl ChannelEvent { - fn wrap(channel: channel::Id) -> impl Fn(M) -> Self { - move |message| Self { - channel: channel.clone(), - message, - } - } -} diff --git a/src/events/app.rs b/src/events/app.rs new file mode 100644 index 0000000..dfb23d7 --- /dev/null +++ b/src/events/app.rs @@ -0,0 +1,111 @@ +use std::collections::{hash_map::Entry, HashMap}; +use std::sync::{Arc, Mutex, MutexGuard}; + +use futures::{future, stream::StreamExt as _, Stream}; +use sqlx::sqlite::SqlitePool; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; + +use super::repo::broadcast; +use crate::repo::channel::{self, Provider as _}; + +// Clones will share the same senders collection. +#[derive(Clone)] +pub struct Broadcaster { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: . Methods that + // lock it must be sync. + senders: Arc>>>, +} + +impl Broadcaster { + pub async fn from_database(db: &SqlitePool) -> Result { + let mut tx = db.begin().await?; + let channels = tx.channels().all().await?; + tx.commit().await?; + + let channels = channels.iter().map(|c| &c.id); + let broadcaster = Self::new(channels); + Ok(broadcaster) + } + + fn new<'i>(channels: impl IntoIterator) -> Self { + let senders: HashMap<_, _> = channels + .into_iter() + .cloned() + .map(|id| (id, Self::make_sender())) + .collect(); + + Self { + senders: Arc::new(Mutex::new(senders)), + } + } + + // panic: if ``channel`` is already registered. + pub fn register_channel(&self, channel: &channel::Id) { + match self.senders().entry(channel.clone()) { + // This ever happening indicates a serious logic error. + Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"), + Entry::Vacant(entry) => { + entry.insert(Self::make_sender()); + } + } + } + + // panic: if ``channel`` has not been previously registered, and was not + // part of the initial set of channels. + pub fn broadcast(&self, channel: &channel::Id, message: broadcast::Message) { + let tx = self.sender(channel); + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + // + // The successful return value, which includes the number of active + // receivers, also isn't that interesting to us. + let _ = tx.send(message); + } + + // panic: if ``channel`` has not been previously registered, and was not + // part of the initial set of channels. + pub fn listen(&self, channel: &channel::Id) -> impl Stream { + let rx = self.sender(channel).subscribe(); + + BroadcastStream::from(rx) + .take_while(|r| { + future::ready(match r { + Ok(_) => true, + // Stop the stream here. This will disconnect SSE clients + // (see `routes.rs`), who will then resume from + // `Last-Event-ID`, allowing them to catch up by reading + // the skipped messages from the database. + Err(BroadcastStreamRecvError::Lagged(_)) => false, + }) + }) + .map(|r| { + // Since the previous transform stops at the first error, this + // should always hold. + // + // See also . + r.expect("after filtering, only `Ok` messages should remain") + }) + } + + // panic: if ``channel`` has not been previously registered, and was not + // part of the initial set of channels. + fn sender(&self, channel: &channel::Id) -> Sender { + self.senders()[channel].clone() + } + + fn senders(&self) -> MutexGuard>> { + self.senders.lock().unwrap() // propagate panics when mutex is poisoned + } + + fn make_sender() -> Sender { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} diff --git a/src/events/mod.rs b/src/events/mod.rs new file mode 100644 index 0000000..f67ea04 --- /dev/null +++ b/src/events/mod.rs @@ -0,0 +1,5 @@ +pub mod app; +pub mod repo; +mod routes; + +pub use self::routes::router; diff --git a/src/events/repo/broadcast.rs b/src/events/repo/broadcast.rs new file mode 100644 index 0000000..182203a --- /dev/null +++ b/src/events/repo/broadcast.rs @@ -0,0 +1,121 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + clock::DateTime, + repo::{ + channel::Channel, + login::{self, Login}, + message, + }, +}; + +pub trait Provider { + fn broadcast(&mut self) -> Broadcast; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn broadcast(&mut self) -> Broadcast { + Broadcast(self) + } +} + +pub struct Broadcast<'t>(&'t mut SqliteConnection); + +#[derive(Clone, Debug, serde::Serialize)] +pub struct Message { + pub id: message::Id, + pub sender: Login, + pub body: String, + pub sent_at: DateTime, +} + +impl<'c> Broadcast<'c> { + pub async fn create( + &mut self, + sender: &Login, + channel: &Channel, + body: &str, + sent_at: &DateTime, + ) -> Result { + let id = message::Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, sender, channel, body, sent_at) + values ($1, $2, $3, $4, $5) + returning + id as "id: message::Id", + sender as "sender: login::Id", + body, + sent_at as "sent_at: DateTime" + "#, + id, + sender.id, + channel.id, + body, + sent_at, + ) + .map(|row| Message { + id: row.id, + sender: sender.clone(), + body: row.body, + sent_at: row.sent_at, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + delete from message + where sent_at < $1 + "#, + expire_at, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn replay( + &mut self, + channel: &Channel, + resume_at: Option<&DateTime>, + ) -> Result, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + message.id as "id: message::Id", + login.id as "sender_id: login::Id", + login.name as sender_name, + message.body, + message.sent_at as "sent_at: DateTime" + from message + join login on message.sender = login.id + where channel = $1 + and coalesce(sent_at > $2, true) + order by sent_at asc + "#, + channel.id, + resume_at, + ) + .map(|row| Message { + id: row.id, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + body: row.body, + sent_at: row.sent_at, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs new file mode 100644 index 0000000..2ed3062 --- /dev/null +++ b/src/events/repo/mod.rs @@ -0,0 +1 @@ +pub mod broadcast; diff --git a/src/events/routes.rs b/src/events/routes.rs new file mode 100644 index 0000000..f880c70 --- /dev/null +++ b/src/events/routes.rs @@ -0,0 +1,133 @@ +use axum::{ + extract::State, + http::StatusCode, + response::{ + sse::{self, Sse}, + IntoResponse, Response, + }, + routing::get, + Router, +}; +use axum_extra::extract::Query; +use chrono::{self, format::SecondsFormat, DateTime}; +use futures::stream::{self, Stream, StreamExt as _, TryStreamExt as _}; + +use super::repo::broadcast; +use crate::{ + app::App, + channel::app::EventsError, + clock::RequestedAt, + error::InternalError, + header::LastEventId, + repo::{channel, login::Login}, +}; + +pub fn router() -> Router { + Router::new().route("/api/events", get(on_events)) +} + +#[derive(serde::Deserialize)] +struct EventsQuery { + #[serde(default, rename = "channel")] + channels: Vec, +} + +async fn on_events( + State(app): State, + RequestedAt(now): RequestedAt, + _: Login, // requires auth, but doesn't actually care who you are + last_event_id: Option, + Query(query): Query, +) -> Result>>, ErrorResponse> { + let resume_at = last_event_id + .map(|LastEventId(header)| header) + .map(|header| DateTime::parse_from_rfc3339(&header)) + .transpose() + // impl From would take more code; this is used once. + .map_err(ErrorResponse::LastEventIdError)? + .map(|ts| ts.to_utc()); + + let streams = stream::iter(query.channels) + .then(|channel| { + let app = app.clone(); + async move { + let events = app + .channels() + .events(&channel, &now, resume_at.as_ref()) + .await? + .map(ChannelEvent::wrap(channel)); + + Ok::<_, EventsError>(events) + } + }) + .try_collect::>() + .await + // impl From would take more code; this is used once. + .map_err(ErrorResponse::EventsError)?; + + let stream = stream::select_all(streams); + + Ok(Events(stream)) +} + +struct Events(S); + +impl IntoResponse for Events +where + S: Stream> + Send + 'static, +{ + fn into_response(self) -> Response { + let Self(stream) = self; + let stream = stream.map(to_sse_event); + Sse::new(stream) + .keep_alive(sse::KeepAlive::default()) + .into_response() + } +} + +enum ErrorResponse { + EventsError(EventsError), + LastEventIdError(chrono::ParseError), +} + +impl IntoResponse for ErrorResponse { + fn into_response(self) -> Response { + match self { + Self::EventsError(not_found @ EventsError::ChannelNotFound(_)) => { + (StatusCode::NOT_FOUND, not_found.to_string()).into_response() + } + Self::EventsError(other) => InternalError::from(other).into_response(), + Self::LastEventIdError(other) => { + (StatusCode::BAD_REQUEST, other.to_string()).into_response() + } + } + } +} + +fn to_sse_event(event: ChannelEvent) -> Result { + let data = serde_json::to_string_pretty(&event)?; + let event = sse::Event::default() + .id(event + .message + .sent_at + .to_rfc3339_opts(SecondsFormat::AutoSi, /* use_z */ true)) + .data(&data); + + Ok(event) +} + +#[derive(serde::Serialize)] +struct ChannelEvent { + channel: channel::Id, + #[serde(flatten)] + message: M, +} + +impl ChannelEvent { + fn wrap(channel: channel::Id) -> impl Fn(M) -> Self { + move |message| Self { + channel: channel.clone(), + message, + } + } +} -- cgit v1.2.3