diff options
| author | Owen Jacobson <owen@grimoire.ca> | 2024-09-13 22:30:02 -0400 |
|---|---|---|
| committer | Owen Jacobson <owen@grimoire.ca> | 2024-09-13 23:12:31 -0400 |
| commit | 407ca8df6284ce1a4c649b018c7326fd195bbd26 (patch) | |
| tree | 876091c17efbd765a4c7ef339548c0ff4dfb96d5 /src/channel | |
| parent | 388a3d5a925aef7ff39339454ae0d720e05f038e (diff) | |
Support Last-Event-Id as a method of resuming channel events after a disconnect
Diffstat (limited to 'src/channel')
| -rw-r--r-- | src/channel/app.rs | 30 | ||||
| -rw-r--r-- | src/channel/header.rs | 34 | ||||
| -rw-r--r-- | src/channel/mod.rs | 1 | ||||
| -rw-r--r-- | src/channel/repo/messages.rs | 3 | ||||
| -rw-r--r-- | src/channel/routes.rs | 31 |
5 files changed, 90 insertions, 9 deletions
diff --git a/src/channel/app.rs b/src/channel/app.rs index e242c2f..c0a6d60 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,8 +2,9 @@ use std::collections::{hash_map::Entry, HashMap}; use std::sync::{Arc, Mutex, MutexGuard}; use futures::{ + future, stream::{self, StreamExt as _, TryStreamExt as _}, - Stream, + TryStream, }; use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast::{channel, Sender}; @@ -55,14 +56,33 @@ impl<'a> Channels<'a> { pub async fn events( &self, channel: &ChannelId, - ) -> Result<impl Stream<Item = Result<BroadcastMessage, BoxedError>>, BoxedError> { - let live_messages = self.broadcaster.listen(channel).map_err(BoxedError::from); + resume_at: Option<&DateTime>, + ) -> Result<impl TryStream<Ok = BroadcastMessage, Error = BoxedError>, BoxedError> { + fn skip_stale<E>( + resume_at: Option<&DateTime>, + ) -> impl for<'m> FnMut(&'m BroadcastMessage) -> future::Ready<Result<bool, E>> { + let resume_at = resume_at.cloned(); + move |msg| { + future::ready(Ok(match resume_at { + None => false, + Some(resume_at) => msg.sent_at <= resume_at, + })) + } + } + + let live_messages = self + .broadcaster + .listen(channel) + .map_err(BoxedError::from) + .try_skip_while(skip_stale(resume_at)); let mut tx = self.db.begin().await?; - let stored_messages = tx.messages().for_replay(channel).await?; + let stored_messages = tx.messages().for_replay(channel, resume_at).await?; tx.commit().await?; - Ok(stream::iter(stored_messages).map(Ok).chain(live_messages)) + let stored_messages = stream::iter(stored_messages).map(Ok); + + Ok(stored_messages.chain(live_messages)) } } diff --git a/src/channel/header.rs b/src/channel/header.rs new file mode 100644 index 0000000..eda8214 --- /dev/null +++ b/src/channel/header.rs @@ -0,0 +1,34 @@ +use axum::http::{HeaderName, HeaderValue}; + +pub struct LastEventId(pub String); + +static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); + +impl headers::Header for LastEventId { + fn name() -> &'static HeaderName { + &LAST_EVENT_ID + } + + fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error> + where + I: Iterator<Item = &'i HeaderValue>, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + if let Ok(value) = value.to_str() { + Ok(Self(value.into())) + } else { + Err(headers::Error::invalid()) + } + } + + fn encode<E>(&self, values: &mut E) + where + E: Extend<HeaderValue>, + { + let Self(value) = self; + // Must panic or suppress; the trait provides no other options. + let value = HeaderValue::from_str(value).expect("LastEventId is a valid header value"); + + values.extend(std::iter::once(value)); + } +} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index f67ea04..bc2cc6c 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,5 @@ pub mod app; +mod header; pub mod repo; mod routes; diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs index fe833b6..b465f61 100644 --- a/src/channel/repo/messages.rs +++ b/src/channel/repo/messages.rs @@ -73,6 +73,7 @@ impl<'c> Messages<'c> { pub async fn for_replay( &mut self, channel: &ChannelId, + resume_at: Option<&DateTime>, ) -> Result<Vec<BroadcastMessage>, BoxedError> { let messages = sqlx::query!( r#" @@ -85,9 +86,11 @@ impl<'c> Messages<'c> { from message join login on message.sender = login.id where channel = $1 + and coalesce(sent_at > $2, true) order by sent_at asc "#, channel, + resume_at, ) .map(|row| BroadcastMessage { id: row.id, diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 0f95c69..4f83a8b 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -8,9 +8,14 @@ use axum::{ routing::{get, post}, Router, }; +use axum_extra::TypedHeader; +use chrono::{format::SecondsFormat, DateTime}; use futures::{future, stream::TryStreamExt as _}; -use super::repo::channels::Id as ChannelId; +use super::{ + header::LastEventId, + repo::{channels::Id as ChannelId, messages::BroadcastMessage}, +}; use crate::{ app::App, clock::RequestedAt, error::BoxedError, error::InternalError, login::repo::logins::Login, @@ -61,13 +66,31 @@ async fn on_events( Path(channel): Path<ChannelId>, State(app): State<App>, _: Login, // requires auth, but doesn't actually care who you are + last_event_id: Option<TypedHeader<LastEventId>>, ) -> Result<impl IntoResponse, InternalError> { + let resume_at = last_event_id + .map(|TypedHeader(header)| header) + .map(|LastEventId(header)| header) + .map(|header| DateTime::parse_from_rfc3339(&header)) + .transpose()? + .map(|ts| ts.to_utc()); + let stream = app .channels() - .events(&channel) + .events(&channel, resume_at.as_ref()) .await? - .and_then(|msg| future::ready(serde_json::to_string(&msg).map_err(BoxedError::from))) - .map_ok(|msg| sse::Event::default().data(&msg)); + .and_then(|msg| future::ready(to_event(msg))); Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default())) } + +fn to_event(msg: BroadcastMessage) -> Result<sse::Event, BoxedError> { + let data = serde_json::to_string(&msg)?; + let event = sse::Event::default() + .id(msg + .sent_at + .to_rfc3339_opts(SecondsFormat::AutoSi, /* use_z */ true)) + .data(&data); + + Ok(event) +} |
