summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-27 18:17:02 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-27 19:59:22 -0400
commiteff129bc1f29bcb1b2b9d10c6b49ab886edc83d6 (patch)
treeb82892a6cf40f771998a85e5530012bab80157dc /src
parent68e3dce3c2e588376c6510783e908941360ac80e (diff)
Make `/api/events` a firehose endpoint.
It now includes events for all channels. Clients are responsible for filtering. The schema for channel events has changed; each event now includes a channel name and ID, in the same format as the sender's name and ID. Channel events also now include a `"type"` field, whose only valid value (as of this writing) is `"message"`. This is groundwork for delivering message deletion (expiry) events to clients, and notifying clients of channel lifecycle events.
Diffstat (limited to 'src')
-rw-r--r--src/app.rs8
-rw-r--r--src/channel/app.rs11
-rw-r--r--src/channel/routes/test/on_send.rs89
-rw-r--r--src/cli.rs2
-rw-r--r--src/events/app.rs93
-rw-r--r--src/events/broadcaster.rs77
-rw-r--r--src/events/mod.rs1
-rw-r--r--src/events/repo/message.rs (renamed from src/events/repo/broadcast.rs)83
-rw-r--r--src/events/repo/mod.rs2
-rw-r--r--src/events/routes.rs124
-rw-r--r--src/events/routes/test.rs276
-rw-r--r--src/events/types.rs99
-rw-r--r--src/repo/channel.rs2
-rw-r--r--src/test/fixtures/message.rs4
-rw-r--r--src/test/fixtures/mod.rs2
15 files changed, 309 insertions, 564 deletions
diff --git a/src/app.rs b/src/app.rs
index b2f861c..07b932a 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -13,9 +13,9 @@ pub struct App {
}
impl App {
- pub async fn from(db: SqlitePool) -> Result<Self, sqlx::Error> {
- let broadcaster = Broadcaster::from_database(&db).await?;
- Ok(Self { db, broadcaster })
+ pub fn from(db: SqlitePool) -> Self {
+ let broadcaster = Broadcaster::default();
+ Self { db, broadcaster }
}
}
@@ -29,6 +29,6 @@ impl App {
}
pub const fn channels(&self) -> Channels {
- Channels::new(&self.db, &self.broadcaster)
+ Channels::new(&self.db)
}
}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 793fa35..6bad158 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,18 +1,14 @@
use sqlx::sqlite::SqlitePool;
-use crate::{
- events::broadcaster::Broadcaster,
- repo::channel::{Channel, Provider as _},
-};
+use crate::repo::channel::{Channel, Provider as _};
pub struct Channels<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool) -> Self {
+ Self { db }
}
pub async fn create(&self, name: &str) -> Result<Channel, CreateError> {
@@ -22,7 +18,6 @@ impl<'a> Channels<'a> {
.create(name)
.await
.map_err(|err| CreateError::from_duplicate_name(err, name))?;
- self.broadcaster.register_channel(&channel.id);
tx.commit().await?;
Ok(channel)
diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs
index 93a5480..5d87bdc 100644
--- a/src/channel/routes/test/on_send.rs
+++ b/src/channel/routes/test/on_send.rs
@@ -1,65 +1,14 @@
-use axum::{
- extract::{Json, Path, State},
- http::StatusCode,
-};
+use axum::extract::{Json, Path, State};
use futures::stream::StreamExt;
use crate::{
channel::routes,
- events::app,
+ events::{app, types},
repo::channel,
test::fixtures::{self, future::Immediately as _},
};
#[tokio::test]
-async fn channel_exists() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let sender = fixtures::login::create(&app).await;
- let channel = fixtures::channel::create(&app).await;
-
- // Call the endpoint
-
- let sent_at = fixtures::now();
- let request = routes::SendRequest {
- message: fixtures::message::propose(),
- };
- let status = routes::on_send(
- State(app.clone()),
- Path(channel.id.clone()),
- sent_at.clone(),
- sender.clone(),
- Json(request.clone()),
- )
- .await
- .expect("sending to a valid channel");
-
- // Verify the structure of the response
-
- assert_eq!(StatusCode::ACCEPTED, status);
-
- // Verify the semantics
-
- let subscribed_at = fixtures::now();
- let mut events = app
- .events()
- .subscribe(&channel.id, &subscribed_at, None)
- .await
- .expect("subscribing to a valid channel");
-
- let event = events
- .next()
- .immediately()
- .await
- .expect("event received by subscribers");
-
- assert_eq!(request.message, event.body);
- assert_eq!(sender, event.sender);
- assert_eq!(*sent_at, event.sent_at);
-}
-
-#[tokio::test]
async fn messages_in_order() {
// Set up the environment
@@ -70,21 +19,15 @@ async fn messages_in_order() {
// Call the endpoint (twice)
let requests = vec![
- (
- fixtures::now(),
- routes::SendRequest {
- message: fixtures::message::propose(),
- },
- ),
- (
- fixtures::now(),
- routes::SendRequest {
- message: fixtures::message::propose(),
- },
- ),
+ (fixtures::now(), fixtures::message::propose()),
+ (fixtures::now(), fixtures::message::propose()),
];
- for (sent_at, request) in &requests {
+ for (sent_at, message) in &requests {
+ let request = routes::SendRequest {
+ message: message.clone(),
+ };
+
routes::on_send(
State(app.clone()),
Path(channel.id.clone()),
@@ -101,17 +44,21 @@ async fn messages_in_order() {
let subscribed_at = fixtures::now();
let events = app
.events()
- .subscribe(&channel.id, &subscribed_at, None)
+ .subscribe(&subscribed_at, types::ResumePoint::default())
.await
.expect("subscribing to a valid channel")
.take(requests.len());
let events = events.collect::<Vec<_>>().immediately().await;
- for ((sent_at, request), event) in requests.into_iter().zip(events) {
- assert_eq!(request.message, event.body);
- assert_eq!(sender, event.sender);
- assert_eq!(*sent_at, event.sent_at);
+ for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) {
+ assert_eq!(*sent_at, event.at);
+ assert!(matches!(
+ event.data,
+ types::ChannelEventData::Message(event_message)
+ if event_message.sender == sender
+ && event_message.body == message
+ ));
}
}
diff --git a/src/cli.rs b/src/cli.rs
index b147f7d..a6d752c 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -70,7 +70,7 @@ impl Args {
pub async fn run(self) -> Result<(), Error> {
let pool = self.pool().await?;
- let app = App::from(pool).await?;
+ let app = App::from(pool);
let app = routers()
.route_layer(middleware::from_fn(clock::middleware))
.with_state(app);
diff --git a/src/events/app.rs b/src/events/app.rs
index 7229551..043a29b 100644
--- a/src/events/app.rs
+++ b/src/events/app.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
use chrono::TimeDelta;
use futures::{
future,
@@ -8,7 +10,8 @@ use sqlx::sqlite::SqlitePool;
use super::{
broadcaster::Broadcaster,
- repo::broadcast::{self, Provider as _},
+ repo::message::Provider as _,
+ types::{self, ResumePoint},
};
use crate::{
clock::DateTime,
@@ -35,64 +38,56 @@ impl<'a> Events<'a> {
channel: &channel::Id,
body: &str,
sent_at: &DateTime,
- ) -> Result<broadcast::Message, EventsError> {
+ ) -> Result<types::ChannelEvent, EventsError> {
let mut tx = self.db.begin().await?;
let channel = tx
.channels()
.by_id(channel)
.await
.not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
- let message = tx
- .broadcast()
+ let event = tx
+ .message_events()
.create(login, &channel, body, sent_at)
.await?;
tx.commit().await?;
- self.broadcaster.broadcast(&channel.id, &message);
- Ok(message)
+ self.broadcaster.broadcast(&event);
+ Ok(event)
}
pub async fn subscribe(
&self,
- channel: &channel::Id,
subscribed_at: &DateTime,
- resume_at: Option<broadcast::Sequence>,
- ) -> Result<impl Stream<Item = broadcast::Message> + std::fmt::Debug, EventsError> {
+ resume_at: ResumePoint,
+ ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> {
// Somewhat arbitrarily, expire after 90 days.
let expire_at = subscribed_at.to_owned() - TimeDelta::days(90);
let mut tx = self.db.begin().await?;
- let channel = tx
- .channels()
- .by_id(channel)
- .await
- .not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
+ let channels = tx.channels().all().await?;
// Subscribe before retrieving, to catch messages broadcast while we're
// querying the DB. We'll prune out duplicates later.
- let live_messages = self.broadcaster.subscribe(&channel.id);
+ let live_messages = self.broadcaster.subscribe();
- tx.broadcast().expire(&expire_at).await?;
- let stored_messages = tx.broadcast().replay(&channel, resume_at).await?;
- tx.commit().await?;
+ tx.message_events().expire(&expire_at).await?;
- let resume_broadcast_at = stored_messages
- .last()
- .map(|message| message.sequence)
- .or(resume_at);
+ let mut replays = BTreeMap::new();
+ let mut resume_live_at = resume_at.clone();
+ for channel in channels {
+ let replay = tx
+ .message_events()
+ .replay(&channel, resume_at.get(&channel.id))
+ .await?;
- // This should always be the case, up to integer rollover, primarily
- // because every message in stored_messages has a sequence not less
- // than `resume_at`, or `resume_at` is None. We use the last message
- // (if any) to decide when to resume the `live_messages` stream.
- //
- // It probably simplifies to assert!(resume_at <= resume_broadcast_at), but
- // this form captures more of the reasoning.
- assert!(
- (resume_at.is_none() && resume_broadcast_at.is_none())
- || (stored_messages.is_empty() && resume_at == resume_broadcast_at)
- || resume_at < resume_broadcast_at
- );
+ if let Some(last) = replay.last() {
+ resume_live_at.advance(&channel.id, last.sequence);
+ }
+
+ replays.insert(channel.id.clone(), replay);
+ }
+
+ let replay = stream::select_all(replays.into_values().map(stream::iter));
// no skip_expired or resume transforms for stored_messages, as it's
// constructed not to contain messages meeting either criterion.
@@ -100,7 +95,6 @@ impl<'a> Events<'a> {
// * skip_expired is redundant with the `tx.broadcasts().expire(…)` call;
// * resume is redundant with the resume_at argument to
// `tx.broadcasts().replay(…)`.
- let stored_messages = stream::iter(stored_messages);
let live_messages = live_messages
// Sure, it's temporally improbable that we'll ever skip a message
// that's 90 days old, but there's no reason not to be thorough.
@@ -108,26 +102,31 @@ impl<'a> Events<'a> {
// Filtering on the broadcast resume point filters out messages
// before resume_at, and filters out messages duplicated from
// stored_messages.
- .filter(Self::resume(resume_broadcast_at));
+ .filter(Self::resume(resume_live_at));
- Ok(stored_messages.chain(live_messages))
+ Ok(replay
+ .chain(live_messages)
+ .scan(resume_at, |resume_point, event| {
+ let channel = &event.channel.id;
+ let sequence = event.sequence;
+ resume_point.advance(channel, sequence);
+
+ let event = types::ResumableEvent(resume_point.clone(), event);
+
+ future::ready(Some(event))
+ }))
}
fn resume(
- resume_at: Option<broadcast::Sequence>,
- ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> {
- move |msg| {
- future::ready(match resume_at {
- None => true,
- Some(resume_at) => msg.sequence > resume_at,
- })
- }
+ resume_at: ResumePoint,
+ ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> {
+ move |event| future::ready(resume_at < event.sequence())
}
fn skip_expired(
expire_at: &DateTime,
- ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> {
+ ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> {
let expire_at = expire_at.to_owned();
- move |msg| future::ready(msg.sent_at > expire_at)
+ move |event| future::ready(expire_at < event.at)
}
}
diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs
index dcaba91..9697c0a 100644
--- a/src/events/broadcaster.rs
+++ b/src/events/broadcaster.rs
@@ -1,63 +1,35 @@
-use std::collections::{hash_map::Entry, HashMap};
-use std::sync::{Arc, Mutex, MutexGuard};
+use std::sync::{Arc, Mutex};
use futures::{future, stream::StreamExt as _, Stream};
-use sqlx::sqlite::SqlitePool;
use tokio::sync::broadcast::{channel, Sender};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
-use crate::{
- events::repo::broadcast,
- repo::channel::{self, Provider as _},
-};
+use crate::events::types;
-// Clones will share the same senders collection.
+// Clones will share the same sender.
#[derive(Clone)]
pub struct Broadcaster {
// The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
// own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
// lock it must be sync.
- senders: Arc<Mutex<HashMap<channel::Id, Sender<broadcast::Message>>>>,
+ senders: Arc<Mutex<Sender<types::ChannelEvent>>>,
}
-impl Broadcaster {
- pub async fn from_database(db: &SqlitePool) -> Result<Self, sqlx::Error> {
- let mut tx = db.begin().await?;
- let channels = tx.channels().all().await?;
- tx.commit().await?;
-
- let channels = channels.iter().map(|c| &c.id);
- let broadcaster = Self::new(channels);
- Ok(broadcaster)
- }
-
- fn new<'i>(channels: impl IntoIterator<Item = &'i channel::Id>) -> Self {
- let senders: HashMap<_, _> = channels
- .into_iter()
- .cloned()
- .map(|id| (id, Self::make_sender()))
- .collect();
+impl Default for Broadcaster {
+ fn default() -> Self {
+ let sender = Self::make_sender();
Self {
- senders: Arc::new(Mutex::new(senders)),
- }
- }
-
- // panic: if ``channel`` is already registered.
- pub fn register_channel(&self, channel: &channel::Id) {
- match self.senders().entry(channel.clone()) {
- // This ever happening indicates a serious logic error.
- Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"),
- Entry::Vacant(entry) => {
- entry.insert(Self::make_sender());
- }
+ senders: Arc::new(Mutex::new(sender)),
}
}
+}
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- pub fn broadcast(&self, channel: &channel::Id, message: &broadcast::Message) {
- let tx = self.sender(channel);
+impl Broadcaster {
+ // Broadcasts the event to all current subscribers via the single shared
+ // sender. Unlike the previous per-channel design there is no registration
+ // step, so this cannot panic for an unknown channel; a lack of receivers
+ // is deliberately ignored (see the comment below).
+ pub fn broadcast(&self, message: &types::ChannelEvent) {
+ let tx = self.sender();
// Per the Tokio docs, the returned error is only used to indicate that
// there are no receivers. In this use case, that's fine; a lack of
@@ -71,15 +43,12 @@ impl Broadcaster {
// Subscribes to the firehose of events for all channels; no per-channel
// registration is required.
- pub fn subscribe(
- &self,
- channel: &channel::Id,
- ) -> impl Stream<Item = broadcast::Message> + std::fmt::Debug {
- let rx = self.sender(channel).subscribe();
+ pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug {
+ let rx = self.sender().subscribe();
BroadcastStream::from(rx).scan((), |(), r| {
future::ready(match r {
- Ok(message) => Some(message),
+ Ok(event) => Some(event),
// Stop the stream here. This will disconnect SSE clients
// (see `routes.rs`), who will then resume from
// `Last-Event-ID`, allowing them to catch up by reading
@@ -92,17 +61,11 @@ impl Broadcaster {
})
}
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- fn sender(&self, channel: &channel::Id) -> Sender<broadcast::Message> {
- self.senders()[channel].clone()
- }
-
- fn senders(&self) -> MutexGuard<HashMap<channel::Id, Sender<broadcast::Message>>> {
- self.senders.lock().unwrap() // propagate panics when mutex is poisoned
+ fn sender(&self) -> Sender<types::ChannelEvent> {
+ self.senders.lock().unwrap().clone()
}
- fn make_sender() -> Sender<broadcast::Message> {
+ fn make_sender() -> Sender<types::ChannelEvent> {
// Queue depth of 16 chosen entirely arbitrarily. Don't read too much
// into it.
let (tx, _) = channel(16);
diff --git a/src/events/mod.rs b/src/events/mod.rs
index b9f3f5b..711ae64 100644
--- a/src/events/mod.rs
+++ b/src/events/mod.rs
@@ -3,5 +3,6 @@ pub mod broadcaster;
mod extract;
pub mod repo;
mod routes;
+pub mod types;
pub use self::routes::router;
diff --git a/src/events/repo/broadcast.rs b/src/events/repo/message.rs
index 6914573..b4724ea 100644
--- a/src/events/repo/broadcast.rs
+++ b/src/events/repo/message.rs
@@ -2,6 +2,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use crate::{
clock::DateTime,
+ events::types::{self, Sequence},
repo::{
channel::Channel,
login::{self, Login},
@@ -10,35 +11,25 @@ use crate::{
};
pub trait Provider {
- fn broadcast(&mut self) -> Broadcast;
+ fn message_events(&mut self) -> Events;
}
impl<'c> Provider for Transaction<'c, Sqlite> {
- fn broadcast(&mut self) -> Broadcast {
- Broadcast(self)
+ fn message_events(&mut self) -> Events {
+ Events(self)
}
}
-pub struct Broadcast<'t>(&'t mut SqliteConnection);
+pub struct Events<'t>(&'t mut SqliteConnection);
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Message {
- pub id: message::Id,
- #[serde(skip)]
- pub sequence: Sequence,
- pub sender: Login,
- pub body: String,
- pub sent_at: DateTime,
-}
-
-impl<'c> Broadcast<'c> {
+impl<'c> Events<'c> {
pub async fn create(
&mut self,
sender: &Login,
channel: &Channel,
body: &str,
sent_at: &DateTime,
- ) -> Result<Message, sqlx::Error> {
+ ) -> Result<types::ChannelEvent, sqlx::Error> {
let sequence = self.next_sequence_for(channel).await?;
let id = message::Id::generate();
@@ -62,12 +53,16 @@ impl<'c> Broadcast<'c> {
body,
sent_at,
)
- .map(|row| Message {
- id: row.id,
+ .map(|row| types::ChannelEvent {
sequence: row.sequence,
- sender: sender.clone(),
- body: row.body,
- sent_at: row.sent_at,
+ at: row.sent_at,
+ channel: channel.clone(),
+ data: types::MessageEvent {
+ id: row.id,
+ sender: sender.clone(),
+ body: row.body,
+ }
+ .into(),
})
.fetch_one(&mut *self.0)
.await?;
@@ -76,7 +71,7 @@ impl<'c> Broadcast<'c> {
}
async fn next_sequence_for(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> {
- let Sequence(current) = sqlx::query_scalar!(
+ let current = sqlx::query_scalar!(
r#"
-- `max` never returns null, but sqlx can't detect that
select max(sequence) as "sequence!: Sequence"
@@ -88,7 +83,7 @@ impl<'c> Broadcast<'c> {
.fetch_one(&mut *self.0)
.await?;
- Ok(Sequence(current + 1))
+ Ok(current.next())
}
pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> {
@@ -109,8 +104,8 @@ impl<'c> Broadcast<'c> {
&mut self,
channel: &Channel,
resume_at: Option<Sequence>,
- ) -> Result<Vec<Message>, sqlx::Error> {
- let messages = sqlx::query!(
+ ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> {
+ let events = sqlx::query!(
r#"
select
message.id as "id: message::Id",
@@ -128,35 +123,23 @@ impl<'c> Broadcast<'c> {
channel.id,
resume_at,
)
- .map(|row| Message {
- id: row.id,
+ .map(|row| types::ChannelEvent {
sequence: row.sequence,
- sender: Login {
- id: row.sender_id,
- name: row.sender_name,
- },
- body: row.body,
- sent_at: row.sent_at,
+ at: row.sent_at,
+ channel: channel.clone(),
+ data: types::MessageEvent {
+ id: row.id,
+ sender: login::Login {
+ id: row.sender_id,
+ name: row.sender_name,
+ },
+ body: row.body,
+ }
+ .into(),
})
.fetch_all(&mut *self.0)
.await?;
- Ok(messages)
+ Ok(events)
}
}
-
-#[derive(
- Debug,
- Eq,
- Ord,
- PartialEq,
- PartialOrd,
- Clone,
- Copy,
- serde::Serialize,
- serde::Deserialize,
- sqlx::Type,
-)]
-#[serde(transparent)]
-#[sqlx(transparent)]
-pub struct Sequence(i64);
diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs
index 2ed3062..e216a50 100644
--- a/src/events/repo/mod.rs
+++ b/src/events/repo/mod.rs
@@ -1 +1 @@
-pub mod broadcast;
+pub mod message;
diff --git a/src/events/routes.rs b/src/events/routes.rs
index d901f9b..3f70dcd 100644
--- a/src/events/routes.rs
+++ b/src/events/routes.rs
@@ -1,8 +1,5 @@
-use std::collections::{BTreeMap, HashSet};
-
use axum::{
extract::State,
- http::StatusCode,
response::{
sse::{self, Sse},
IntoResponse, Response,
@@ -10,87 +7,32 @@ use axum::{
routing::get,
Router,
};
-use axum_extra::extract::Query;
-use futures::{
- future,
- stream::{self, Stream, StreamExt as _, TryStreamExt as _},
-};
+use futures::stream::{Stream, StreamExt as _};
-use super::{extract::LastEventId, repo::broadcast};
-use crate::{
- app::App,
- clock::RequestedAt,
- error::Internal,
- events::app::EventsError,
- repo::{channel, login::Login},
+use super::{
+ extract::LastEventId,
+ types::{self, ResumePoint},
};
+use crate::{app::App, clock::RequestedAt, error::Internal, repo::login::Login};
#[cfg(test)]
mod test;
-// For the purposes of event replay, an "event ID" is a vector of per-channel
-// sequence numbers. Replay will start with messages whose sequence number in
-// its channel is higher than the sequence in the event ID, or if the channel
-// is not listed in the event ID, then at the beginning.
-//
-// Using a sorted map ensures that there is a canonical representation for
-// each event ID.
-type EventId = BTreeMap<channel::Id, broadcast::Sequence>;
-
pub fn router() -> Router<App> {
Router::new().route("/api/events", get(events))
}
-#[derive(Clone, serde::Deserialize)]
-struct EventsQuery {
- #[serde(default, rename = "channel")]
- channels: HashSet<channel::Id>,
-}
-
async fn events(
State(app): State<App>,
- RequestedAt(now): RequestedAt,
+ RequestedAt(subscribed_at): RequestedAt,
_: Login, // requires auth, but doesn't actually care who you are
- last_event_id: Option<LastEventId<EventId>>,
- Query(query): Query<EventsQuery>,
-) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> {
+ last_event_id: Option<LastEventId<ResumePoint>>,
+) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
let resume_at = last_event_id
.map(LastEventId::into_inner)
.unwrap_or_default();
- let streams = stream::iter(query.channels)
- .then(|channel| {
- let app = app.clone();
- let resume_at = resume_at.clone();
- async move {
- let resume_at = resume_at.get(&channel).copied();
-
- let events = app
- .events()
- .subscribe(&channel, &now, resume_at)
- .await?
- .map(ChannelEvent::wrap(channel));
-
- Ok::<_, EventsError>(events)
- }
- })
- .try_collect::<Vec<_>>()
- .await
- // impl From would take more code; this is used once.
- .map_err(ErrorResponse)?;
-
- // We resume counting from the provided last-event-id mapping, rather than
- // starting from scratch, so that the events in a resumed stream contain
- // the full vector of channel IDs for their event IDs right off the bat,
- // even before any events are actually delivered.
- let stream = stream::select_all(streams).scan(resume_at, |sequences, event| {
- let (channel, sequence) = event.event_id();
- sequences.insert(channel, sequence);
-
- let event = ReplayableEvent(sequences.clone(), event);
-
- future::ready(Some(event))
- });
+ let stream = app.events().subscribe(&subscribed_at, resume_at).await?;
Ok(Events(stream))
}
@@ -100,7 +42,7 @@ struct Events<S>(S);
impl<S> IntoResponse for Events<S>
where
- S: Stream<Item = ReplayableEvent> + Send + 'static,
+ S: Stream<Item = types::ResumableEvent> + Send + 'static,
{
fn into_response(self) -> Response {
let Self(stream) = self;
@@ -111,51 +53,13 @@ where
}
}
-#[derive(Debug)]
-struct ErrorResponse(EventsError);
-
-impl IntoResponse for ErrorResponse {
- fn into_response(self) -> Response {
- let Self(error) = self;
- match error {
- not_found @ EventsError::ChannelNotFound(_) => {
- (StatusCode::NOT_FOUND, not_found.to_string()).into_response()
- }
- other => Internal::from(other).into_response(),
- }
- }
-}
-
-#[derive(Debug)]
-struct ReplayableEvent(EventId, ChannelEvent);
-
-#[derive(Debug, serde::Serialize)]
-struct ChannelEvent {
- channel: channel::Id,
- #[serde(flatten)]
- message: broadcast::Message,
-}
-
-impl ChannelEvent {
- fn wrap(channel: channel::Id) -> impl Fn(broadcast::Message) -> Self {
- move |message| Self {
- channel: channel.clone(),
- message,
- }
- }
-
- fn event_id(&self) -> (channel::Id, broadcast::Sequence) {
- (self.channel.clone(), self.message.sequence)
- }
-}
-
-impl TryFrom<ReplayableEvent> for sse::Event {
+impl TryFrom<types::ResumableEvent> for sse::Event {
type Error = serde_json::Error;
- fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> {
- let ReplayableEvent(id, data) = value;
+ fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> {
+ let types::ResumableEvent(resume_at, data) = value;
- let id = serde_json::to_string(&id)?;
+ let id = serde_json::to_string(&resume_at)?;
let data = serde_json::to_string_pretty(&data)?;
let event = Self::default().id(id).data(data);
diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs
index 4412938..f289225 100644
--- a/src/events/routes/test.rs
+++ b/src/events/routes/test.rs
@@ -1,40 +1,15 @@
use axum::extract::State;
-use axum_extra::extract::Query;
use futures::{
future,
stream::{self, StreamExt as _},
};
use crate::{
- events::{app, routes},
- repo::channel::{self},
+ events::{routes, types},
test::fixtures::{self, future::Immediately as _},
};
#[tokio::test]
-async fn no_subscriptions() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let subscriber = fixtures::login::create(&app).await;
-
- // Call the endpoint
-
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("empty subscription");
-
- // Verify the structure of the response.
-
- assert!(events.next().immediately().await.is_none());
-}
-
-#[tokio::test]
async fn includes_historical_message() {
// Set up the environment
@@ -47,24 +22,19 @@ async fn includes_historical_message() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channel");
+ let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered stored message");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
#[tokio::test]
@@ -78,68 +48,23 @@ async fn includes_live_message() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber,
- None,
- Query(query),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(mut events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the semantics
let sender = fixtures::login::create(&app).await;
let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered live message");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
-}
-
-#[tokio::test]
-async fn excludes_other_channels() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let subscribed_channel = fixtures::channel::create(&app).await;
- let unsubscribed_channel = fixtures::channel::create(&app).await;
- let sender = fixtures::login::create(&app).await;
- let message =
- fixtures::message::send(&app, &sender, &subscribed_channel, &fixtures::now()).await;
- fixtures::message::send(&app, &sender, &unsubscribed_channel, &fixtures::now()).await;
-
- // Call the endpoint
-
- let subscriber = fixtures::login::create(&app).await;
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [subscribed_channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to a valid channel");
-
- // Verify the semantics
-
- let routes::ReplayableEvent(_, event) = events
- .next()
- .immediately()
- .await
- .expect("delivered at least one message");
-
- assert_eq!(subscribed_channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
#[tokio::test]
@@ -155,10 +80,11 @@ async fn includes_multiple_channels() {
];
let messages = stream::iter(channels)
- .then(|channel| async {
- let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
-
- (channel, message)
+ .then(|channel| {
+ let app = app.clone();
+ let sender = sender.clone();
+ let channel = channel.clone();
+ async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await }
})
.collect::<Vec<_>>()
.await;
@@ -167,17 +93,9 @@ async fn includes_multiple_channels() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: messages
- .iter()
- .map(|(channel, _)| &channel.id)
- .cloned()
- .collect(),
- };
- let routes::Events(events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channels");
+ let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
@@ -187,41 +105,14 @@ async fn includes_multiple_channels() {
.immediately()
.await;
- for (channel, message) in messages {
- assert!(events.iter().any(|routes::ReplayableEvent(_, event)| {
- event.channel == channel.id && event.message == message
- }));
+ for message in &messages {
+ assert!(events
+ .iter()
+ .any(|types::ResumableEvent(_, event)| { event == message }));
}
}
#[tokio::test]
-async fn nonexistent_channel() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let channel = channel::Id::generate();
-
- // Call the endpoint
-
- let subscriber = fixtures::login::create(&app).await;
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.clone()].into(),
- };
- let routes::ErrorResponse(error) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect_err("subscribed to nonexistent channel");
-
- // Verify the structure of the response.
-
- assert!(matches!(
- error,
- app::EventsError::ChannelNotFound(error_channel) if error_channel == channel
- ));
-}
-
-#[tokio::test]
async fn sequential_messages() {
// Set up the environment
@@ -239,30 +130,24 @@ async fn sequential_messages() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
- let mut events = events.filter(|routes::ReplayableEvent(_, event)| {
- future::ready(messages.contains(&event.message))
- });
+ let mut events =
+ events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)));
// Verify delivery in order
for message in &messages {
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("undelivered messages remaining");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, &event.message);
+ assert_eq!(message, &event);
}
}
@@ -285,42 +170,28 @@ async fn resumes_from() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
let resume_at = {
// First subscription
- let routes::Events(mut events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber.clone(),
- None,
- Query(query.clone()),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(mut events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None)
+ .await
+ .expect("subscribe never fails");
- let routes::ReplayableEvent(id, event) =
+ let types::ResumableEvent(last_event_id, event) =
events.next().immediately().await.expect("delivered events");
- assert_eq!(channel.id, event.channel);
- assert_eq!(initial_message, event.message);
+ assert_eq!(initial_message, event);
- id
+ last_event_id
};
// Resume after disconnect
let reconnect_at = fixtures::now();
- let routes::Events(resumed) = routes::events(
- State(app),
- reconnect_at,
- subscriber,
- Some(resume_at.into()),
- Query(query),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(resumed) =
+ routes::events(State(app), reconnect_at, subscriber, Some(resume_at.into()))
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
@@ -330,11 +201,10 @@ async fn resumes_from() {
.immediately()
.await;
- for message in later_messages {
- assert!(events.iter().any(
- |routes::ReplayableEvent(_, event)| event.channel == channel.id
- && event.message == message
- ));
+ for message in &later_messages {
+ assert!(events
+ .iter()
+ .any(|types::ResumableEvent(_, event)| event == message));
}
}
@@ -365,9 +235,6 @@ async fn serial_resume() {
// Call the endpoint
let subscriber = fixtures::login::create(&app).await;
- let query = routes::EventsQuery {
- channels: [channel_a.id.clone(), channel_b.id.clone()].into(),
- };
let resume_at = {
let initial_messages = [
@@ -377,15 +244,10 @@ async fn serial_resume() {
// First subscription
let subscribed_at = fixtures::now();
- let routes::Events(events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber.clone(),
- None,
- Query(query.clone()),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None)
+ .await
+ .expect("subscribe never fails");
let events = events
.take(initial_messages.len())
@@ -393,13 +255,13 @@ async fn serial_resume() {
.immediately()
.await;
- for message in initial_messages {
+ for message in &initial_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
- let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
id.to_owned()
};
@@ -421,10 +283,9 @@ async fn serial_resume() {
resubscribed_at,
subscriber.clone(),
Some(resume_at.into()),
- Query(query.clone()),
)
.await
- .expect("subscribed to a valid channel");
+ .expect("subscribe never fails");
let events = events
.take(resume_messages.len())
@@ -432,13 +293,13 @@ async fn serial_resume() {
.immediately()
.await;
- for message in resume_messages {
+ for message in &resume_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
- let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
id.to_owned()
};
@@ -460,10 +321,9 @@ async fn serial_resume() {
resubscribed_at,
subscriber.clone(),
Some(resume_at.into()),
- Query(query.clone()),
)
.await
- .expect("subscribed to a valid channel");
+ .expect("subscribe never fails");
let events = events
.take(final_messages.len())
@@ -473,10 +333,10 @@ async fn serial_resume() {
// This set of messages, in particular, _should not_ include any prior
// messages from `initial_messages` or `resume_messages`.
- for message in final_messages {
+ for message in &final_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
};
}
@@ -495,22 +355,18 @@ async fn removes_expired_messages() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channel");
+
+ let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the semantics
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered messages");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
diff --git a/src/events/types.rs b/src/events/types.rs
new file mode 100644
index 0000000..6747afc
--- /dev/null
+++ b/src/events/types.rs
@@ -0,0 +1,99 @@
+use std::collections::BTreeMap;
+
+use crate::{
+ clock::DateTime,
+ repo::{
+ channel::{self, Channel},
+ login::Login,
+ message,
+ },
+};
+
+#[derive(
+ Debug,
+ Eq,
+ Ord,
+ PartialEq,
+ PartialOrd,
+ Clone,
+ Copy,
+ serde::Serialize,
+ serde::Deserialize,
+ sqlx::Type,
+)]
+#[serde(transparent)]
+#[sqlx(transparent)]
+pub struct Sequence(i64);
+
+impl Sequence {
+ pub fn next(self) -> Self {
+ let Self(current) = self;
+ Self(current + 1)
+ }
+}
+
+// For the purposes of event replay, an "event ID" is a vector of per-channel
+// sequence numbers. Replay will start with messages whose sequence number in
+// their channel is higher than the sequence in the event ID, or if the channel
+// is not listed in the event ID, then at the beginning.
+//
+// Using a sorted map ensures that there is a canonical representation for
+// each event ID.
+#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)]
+#[serde(transparent)]
+pub struct ResumePoint(BTreeMap<channel::Id, Sequence>);
+
+impl ResumePoint {
+ pub fn singleton(channel: &channel::Id, sequence: Sequence) -> Self {
+ let mut vector = Self::default();
+ vector.advance(channel, sequence);
+ vector
+ }
+
+ pub fn advance(&mut self, channel: &channel::Id, sequence: Sequence) {
+ let Self(elements) = self;
+ elements.insert(channel.clone(), sequence);
+ }
+
+ pub fn get(&self, channel: &channel::Id) -> Option<Sequence> {
+ let Self(elements) = self;
+ elements.get(channel).copied()
+ }
+}
+#[derive(Clone, Debug)]
+pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent);
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct ChannelEvent {
+ #[serde(skip)]
+ pub sequence: Sequence,
+ pub at: DateTime,
+ pub channel: Channel,
+ #[serde(flatten)]
+ pub data: ChannelEventData,
+}
+
+impl ChannelEvent {
+ pub fn sequence(&self) -> ResumePoint {
+ ResumePoint::singleton(&self.channel.id, self.sequence)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ChannelEventData {
+ Message(MessageEvent),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct MessageEvent {
+ pub id: message::Id,
+ pub sender: Login,
+ pub body: String,
+}
+
+impl From<MessageEvent> for ChannelEventData {
+ fn from(message: MessageEvent) -> Self {
+ Self::Message(message)
+ }
+}
diff --git a/src/repo/channel.rs b/src/repo/channel.rs
index 0186413..d223dab 100644
--- a/src/repo/channel.rs
+++ b/src/repo/channel.rs
@@ -16,7 +16,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> {
pub struct Channels<'t>(&'t mut SqliteConnection);
-#[derive(Debug, Eq, PartialEq, serde::Serialize)]
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct Channel {
pub id: Id,
pub name: String,
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index 33feeae..bfca8cd 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -3,7 +3,7 @@ use faker_rand::lorem::Paragraphs;
use crate::{
app::App,
clock::RequestedAt,
- events::repo::broadcast,
+ events::types,
repo::{channel::Channel, login::Login},
};
@@ -12,7 +12,7 @@ pub async fn send(
login: &Login,
channel: &Channel,
sent_at: &RequestedAt,
-) -> broadcast::Message {
+) -> types::ChannelEvent {
let body = propose();
app.events()
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index a42dba5..450fbec 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -13,8 +13,6 @@ pub async fn scratch_app() -> App {
.await
.expect("setting up in-memory sqlite database");
App::from(pool)
- .await
- .expect("creating an app from a fresh, in-memory database")
}
pub fn now() -> RequestedAt {