summaryrefslogtreecommitdiff
path: root/src/channel
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-13 22:30:02 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-13 23:12:31 -0400
commit407ca8df6284ce1a4c649b018c7326fd195bbd26 (patch)
tree876091c17efbd765a4c7ef339548c0ff4dfb96d5 /src/channel
parent388a3d5a925aef7ff39339454ae0d720e05f038e (diff)
Support Last-Event-Id as a method of resuming channel events after a disconnect
Diffstat (limited to 'src/channel')
-rw-r--r--src/channel/app.rs30
-rw-r--r--src/channel/header.rs34
-rw-r--r--src/channel/mod.rs1
-rw-r--r--src/channel/repo/messages.rs3
-rw-r--r--src/channel/routes.rs31
5 files changed, 90 insertions, 9 deletions
diff --git a/src/channel/app.rs b/src/channel/app.rs
index e242c2f..c0a6d60 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -2,8 +2,9 @@ use std::collections::{hash_map::Entry, HashMap};
use std::sync::{Arc, Mutex, MutexGuard};
use futures::{
+ future,
stream::{self, StreamExt as _, TryStreamExt as _},
- Stream,
+ TryStream,
};
use sqlx::sqlite::SqlitePool;
use tokio::sync::broadcast::{channel, Sender};
@@ -55,14 +56,33 @@ impl<'a> Channels<'a> {
pub async fn events(
&self,
channel: &ChannelId,
- ) -> Result<impl Stream<Item = Result<BroadcastMessage, BoxedError>>, BoxedError> {
- let live_messages = self.broadcaster.listen(channel).map_err(BoxedError::from);
+ resume_at: Option<&DateTime>,
+ ) -> Result<impl TryStream<Ok = BroadcastMessage, Error = BoxedError>, BoxedError> {
+ fn skip_stale<E>(
+ resume_at: Option<&DateTime>,
+ ) -> impl for<'m> FnMut(&'m BroadcastMessage) -> future::Ready<Result<bool, E>> {
+ let resume_at = resume_at.cloned();
+ move |msg| {
+ future::ready(Ok(match resume_at {
+ None => false,
+ Some(resume_at) => msg.sent_at <= resume_at,
+ }))
+ }
+ }
+
+ let live_messages = self
+ .broadcaster
+ .listen(channel)
+ .map_err(BoxedError::from)
+ .try_skip_while(skip_stale(resume_at));
let mut tx = self.db.begin().await?;
- let stored_messages = tx.messages().for_replay(channel).await?;
+ let stored_messages = tx.messages().for_replay(channel, resume_at).await?;
tx.commit().await?;
- Ok(stream::iter(stored_messages).map(Ok).chain(live_messages))
+ let stored_messages = stream::iter(stored_messages).map(Ok);
+
+ Ok(stored_messages.chain(live_messages))
}
}
diff --git a/src/channel/header.rs b/src/channel/header.rs
new file mode 100644
index 0000000..eda8214
--- /dev/null
+++ b/src/channel/header.rs
@@ -0,0 +1,34 @@
+use axum::http::{HeaderName, HeaderValue};
+
+pub struct LastEventId(pub String);
+
+static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id");
+
+impl headers::Header for LastEventId {
+ fn name() -> &'static HeaderName {
+ &LAST_EVENT_ID
+ }
+
+ fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
+ where
+ I: Iterator<Item = &'i HeaderValue>,
+ {
+ let value = values.next().ok_or_else(headers::Error::invalid)?;
+ if let Ok(value) = value.to_str() {
+ Ok(Self(value.into()))
+ } else {
+ Err(headers::Error::invalid())
+ }
+ }
+
+ fn encode<E>(&self, values: &mut E)
+ where
+ E: Extend<HeaderValue>,
+ {
+ let Self(value) = self;
+ // Must panic or suppress; the trait provides no other options.
+ let value = HeaderValue::from_str(value).expect("LastEventId is a valid header value");
+
+ values.extend(std::iter::once(value));
+ }
+}
diff --git a/src/channel/mod.rs b/src/channel/mod.rs
index f67ea04..bc2cc6c 100644
--- a/src/channel/mod.rs
+++ b/src/channel/mod.rs
@@ -1,4 +1,5 @@
pub mod app;
+mod header;
pub mod repo;
mod routes;
diff --git a/src/channel/repo/messages.rs b/src/channel/repo/messages.rs
index fe833b6..b465f61 100644
--- a/src/channel/repo/messages.rs
+++ b/src/channel/repo/messages.rs
@@ -73,6 +73,7 @@ impl<'c> Messages<'c> {
pub async fn for_replay(
&mut self,
channel: &ChannelId,
+ resume_at: Option<&DateTime>,
) -> Result<Vec<BroadcastMessage>, BoxedError> {
let messages = sqlx::query!(
r#"
@@ -85,9 +86,11 @@ impl<'c> Messages<'c> {
from message
join login on message.sender = login.id
where channel = $1
+ and coalesce(sent_at > $2, true)
order by sent_at asc
"#,
channel,
+ resume_at,
)
.map(|row| BroadcastMessage {
id: row.id,
diff --git a/src/channel/routes.rs b/src/channel/routes.rs
index 0f95c69..4f83a8b 100644
--- a/src/channel/routes.rs
+++ b/src/channel/routes.rs
@@ -8,9 +8,14 @@ use axum::{
routing::{get, post},
Router,
};
+use axum_extra::TypedHeader;
+use chrono::{format::SecondsFormat, DateTime};
use futures::{future, stream::TryStreamExt as _};
-use super::repo::channels::Id as ChannelId;
+use super::{
+ header::LastEventId,
+ repo::{channels::Id as ChannelId, messages::BroadcastMessage},
+};
use crate::{
app::App, clock::RequestedAt, error::BoxedError, error::InternalError,
login::repo::logins::Login,
@@ -61,13 +66,31 @@ async fn on_events(
Path(channel): Path<ChannelId>,
State(app): State<App>,
_: Login, // requires auth, but doesn't actually care who you are
+ last_event_id: Option<TypedHeader<LastEventId>>,
) -> Result<impl IntoResponse, InternalError> {
+ let resume_at = last_event_id
+ .map(|TypedHeader(header)| header)
+ .map(|LastEventId(header)| header)
+ .map(|header| DateTime::parse_from_rfc3339(&header))
+ .transpose()?
+ .map(|ts| ts.to_utc());
+
let stream = app
.channels()
- .events(&channel)
+ .events(&channel, resume_at.as_ref())
.await?
- .and_then(|msg| future::ready(serde_json::to_string(&msg).map_err(BoxedError::from)))
- .map_ok(|msg| sse::Event::default().data(&msg));
+ .and_then(|msg| future::ready(to_event(msg)));
Ok(Sse::new(stream).keep_alive(sse::KeepAlive::default()))
}
+
+fn to_event(msg: BroadcastMessage) -> Result<sse::Event, BoxedError> {
+ let data = serde_json::to_string(&msg)?;
+ let event = sse::Event::default()
+ .id(msg
+ .sent_at
+ .to_rfc3339_opts(SecondsFormat::AutoSi, /* use_z */ true))
+ .data(&data);
+
+ Ok(event)
+}