summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/api.md32
-rw-r--r--src/app.rs8
-rw-r--r--src/channel/app.rs11
-rw-r--r--src/channel/routes/test/on_send.rs89
-rw-r--r--src/cli.rs2
-rw-r--r--src/events/app.rs93
-rw-r--r--src/events/broadcaster.rs77
-rw-r--r--src/events/mod.rs1
-rw-r--r--src/events/repo/message.rs (renamed from src/events/repo/broadcast.rs)83
-rw-r--r--src/events/repo/mod.rs2
-rw-r--r--src/events/routes.rs124
-rw-r--r--src/events/routes/test.rs276
-rw-r--r--src/events/types.rs99
-rw-r--r--src/repo/channel.rs2
-rw-r--r--src/test/fixtures/message.rs4
-rw-r--r--src/test/fixtures/mod.rs2
16 files changed, 321 insertions, 584 deletions
diff --git a/docs/api.md b/docs/api.md
index 8bb3c0b..8b31941 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -152,18 +152,6 @@ Subscribes to events. This endpoint returns an `application/event-stream` respon
The returned stream may terminate, to limit the number of outstanding messages held by the server. Clients can and should repeat the request, using the `Last-Event-Id` header to resume from where they left off. Events will be replayed from that point, and the stream will resume.
-#### Query parameters
-
-This endpoint accepts the following query parameters:
-
-* `channel`: a channel ID to subscribe to. Events for this channel will be included in the response. This parameter may be provided multiple times. Clients should not subscribe to the same channel more than once in a single request.
-
-Browsers generally limit the number of open connections, often to embarrassingly low limits. Clients should subscribe to multiple streams in a single request, and should not subscribe to each stream individually.
-
-Requests without a subscription return an empty stream.
-
-(If you're wondering: it has to be query parameters or something equivalent to it, since `EventSource` can only issue `GET` requests.)
-
#### Request headers
This endpoint accepts an optional `Last-Event-Id` header for resuming an interrupted stream. If this header is provided, it must be set to the `id` field sent with the last event the client has processed. When `Last-Event-Id` is sent, the response will resume immediately after the corresponding event. If this header is omitted, then the stream will start from the beginning.
@@ -179,13 +167,17 @@ The returned event stream is a sequence of events:
```json
id: 1234
data: {
-data: "channel": "C9876cyyz",
-data: "id": "Mabcd1234",
-data: "sender": {
-data: "id": "L1234abcd",
-data: "name": "example username"
-data: },
-data: "body": "my amazing thoughts, by bob",
-data: "sent_at": "2024-09-19T02:30:50.915462Z"
+data: "type": "message",
+data: "at": "2024-09-27T23:19:10.208147Z",
+data: "id": "Mxnjcf3y41prfry9",
+data: "channel": {
+data: "id": "C9876cyyz",
+data: "name": "example channel 2"
+data: },
+data: "sender": {
+data: "id": "L1234abcd",
+data: "name": "example username"
+data: },
+data: "body": "beep"
data: }
```
diff --git a/src/app.rs b/src/app.rs
index b2f861c..07b932a 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -13,9 +13,9 @@ pub struct App {
}
impl App {
- pub async fn from(db: SqlitePool) -> Result<Self, sqlx::Error> {
- let broadcaster = Broadcaster::from_database(&db).await?;
- Ok(Self { db, broadcaster })
+ pub fn from(db: SqlitePool) -> Self {
+ let broadcaster = Broadcaster::default();
+ Self { db, broadcaster }
}
}
@@ -29,6 +29,6 @@ impl App {
}
pub const fn channels(&self) -> Channels {
- Channels::new(&self.db, &self.broadcaster)
+ Channels::new(&self.db)
}
}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 793fa35..6bad158 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,18 +1,14 @@
use sqlx::sqlite::SqlitePool;
-use crate::{
- events::broadcaster::Broadcaster,
- repo::channel::{Channel, Provider as _},
-};
+use crate::repo::channel::{Channel, Provider as _};
pub struct Channels<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool) -> Self {
+ Self { db }
}
pub async fn create(&self, name: &str) -> Result<Channel, CreateError> {
@@ -22,7 +18,6 @@ impl<'a> Channels<'a> {
.create(name)
.await
.map_err(|err| CreateError::from_duplicate_name(err, name))?;
- self.broadcaster.register_channel(&channel.id);
tx.commit().await?;
Ok(channel)
diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs
index 93a5480..5d87bdc 100644
--- a/src/channel/routes/test/on_send.rs
+++ b/src/channel/routes/test/on_send.rs
@@ -1,65 +1,14 @@
-use axum::{
- extract::{Json, Path, State},
- http::StatusCode,
-};
+use axum::extract::{Json, Path, State};
use futures::stream::StreamExt;
use crate::{
channel::routes,
- events::app,
+ events::{app, types},
repo::channel,
test::fixtures::{self, future::Immediately as _},
};
#[tokio::test]
-async fn channel_exists() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let sender = fixtures::login::create(&app).await;
- let channel = fixtures::channel::create(&app).await;
-
- // Call the endpoint
-
- let sent_at = fixtures::now();
- let request = routes::SendRequest {
- message: fixtures::message::propose(),
- };
- let status = routes::on_send(
- State(app.clone()),
- Path(channel.id.clone()),
- sent_at.clone(),
- sender.clone(),
- Json(request.clone()),
- )
- .await
- .expect("sending to a valid channel");
-
- // Verify the structure of the response
-
- assert_eq!(StatusCode::ACCEPTED, status);
-
- // Verify the semantics
-
- let subscribed_at = fixtures::now();
- let mut events = app
- .events()
- .subscribe(&channel.id, &subscribed_at, None)
- .await
- .expect("subscribing to a valid channel");
-
- let event = events
- .next()
- .immediately()
- .await
- .expect("event received by subscribers");
-
- assert_eq!(request.message, event.body);
- assert_eq!(sender, event.sender);
- assert_eq!(*sent_at, event.sent_at);
-}
-
-#[tokio::test]
async fn messages_in_order() {
// Set up the environment
@@ -70,21 +19,15 @@ async fn messages_in_order() {
// Call the endpoint (twice)
let requests = vec![
- (
- fixtures::now(),
- routes::SendRequest {
- message: fixtures::message::propose(),
- },
- ),
- (
- fixtures::now(),
- routes::SendRequest {
- message: fixtures::message::propose(),
- },
- ),
+ (fixtures::now(), fixtures::message::propose()),
+ (fixtures::now(), fixtures::message::propose()),
];
- for (sent_at, request) in &requests {
+ for (sent_at, message) in &requests {
+ let request = routes::SendRequest {
+ message: message.clone(),
+ };
+
routes::on_send(
State(app.clone()),
Path(channel.id.clone()),
@@ -101,17 +44,21 @@ async fn messages_in_order() {
let subscribed_at = fixtures::now();
let events = app
.events()
- .subscribe(&channel.id, &subscribed_at, None)
+ .subscribe(&subscribed_at, types::ResumePoint::default())
.await
.expect("subscribing to a valid channel")
.take(requests.len());
let events = events.collect::<Vec<_>>().immediately().await;
- for ((sent_at, request), event) in requests.into_iter().zip(events) {
- assert_eq!(request.message, event.body);
- assert_eq!(sender, event.sender);
- assert_eq!(*sent_at, event.sent_at);
+ for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) {
+ assert_eq!(*sent_at, event.at);
+ assert!(matches!(
+ event.data,
+ types::ChannelEventData::Message(event_message)
+ if event_message.sender == sender
+ && event_message.body == message
+ ));
}
}
diff --git a/src/cli.rs b/src/cli.rs
index b147f7d..a6d752c 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -70,7 +70,7 @@ impl Args {
pub async fn run(self) -> Result<(), Error> {
let pool = self.pool().await?;
- let app = App::from(pool).await?;
+ let app = App::from(pool);
let app = routers()
.route_layer(middleware::from_fn(clock::middleware))
.with_state(app);
diff --git a/src/events/app.rs b/src/events/app.rs
index 7229551..043a29b 100644
--- a/src/events/app.rs
+++ b/src/events/app.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeMap;
+
use chrono::TimeDelta;
use futures::{
future,
@@ -8,7 +10,8 @@ use sqlx::sqlite::SqlitePool;
use super::{
broadcaster::Broadcaster,
- repo::broadcast::{self, Provider as _},
+ repo::message::Provider as _,
+ types::{self, ResumePoint},
};
use crate::{
clock::DateTime,
@@ -35,64 +38,56 @@ impl<'a> Events<'a> {
channel: &channel::Id,
body: &str,
sent_at: &DateTime,
- ) -> Result<broadcast::Message, EventsError> {
+ ) -> Result<types::ChannelEvent, EventsError> {
let mut tx = self.db.begin().await?;
let channel = tx
.channels()
.by_id(channel)
.await
.not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
- let message = tx
- .broadcast()
+ let event = tx
+ .message_events()
.create(login, &channel, body, sent_at)
.await?;
tx.commit().await?;
- self.broadcaster.broadcast(&channel.id, &message);
- Ok(message)
+ self.broadcaster.broadcast(&event);
+ Ok(event)
}
pub async fn subscribe(
&self,
- channel: &channel::Id,
subscribed_at: &DateTime,
- resume_at: Option<broadcast::Sequence>,
- ) -> Result<impl Stream<Item = broadcast::Message> + std::fmt::Debug, EventsError> {
+ resume_at: ResumePoint,
+ ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> {
// Somewhat arbitrarily, expire after 90 days.
let expire_at = subscribed_at.to_owned() - TimeDelta::days(90);
let mut tx = self.db.begin().await?;
- let channel = tx
- .channels()
- .by_id(channel)
- .await
- .not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
+ let channels = tx.channels().all().await?;
// Subscribe before retrieving, to catch messages broadcast while we're
// querying the DB. We'll prune out duplicates later.
- let live_messages = self.broadcaster.subscribe(&channel.id);
+ let live_messages = self.broadcaster.subscribe();
- tx.broadcast().expire(&expire_at).await?;
- let stored_messages = tx.broadcast().replay(&channel, resume_at).await?;
- tx.commit().await?;
+ tx.message_events().expire(&expire_at).await?;
- let resume_broadcast_at = stored_messages
- .last()
- .map(|message| message.sequence)
- .or(resume_at);
+ let mut replays = BTreeMap::new();
+ let mut resume_live_at = resume_at.clone();
+ for channel in channels {
+ let replay = tx
+ .message_events()
+ .replay(&channel, resume_at.get(&channel.id))
+ .await?;
- // This should always be the case, up to integer rollover, primarily
- // because every message in stored_messages has a sequence not less
- // than `resume_at`, or `resume_at` is None. We use the last message
- // (if any) to decide when to resume the `live_messages` stream.
- //
- // It probably simplifies to assert!(resume_at <= resume_broadcast_at), but
- // this form captures more of the reasoning.
- assert!(
- (resume_at.is_none() && resume_broadcast_at.is_none())
- || (stored_messages.is_empty() && resume_at == resume_broadcast_at)
- || resume_at < resume_broadcast_at
- );
+ if let Some(last) = replay.last() {
+ resume_live_at.advance(&channel.id, last.sequence);
+ }
+
+ replays.insert(channel.id.clone(), replay);
+ }
+
+ let replay = stream::select_all(replays.into_values().map(stream::iter));
// no skip_expired or resume transforms for stored_messages, as it's
// constructed not to contain messages meeting either criterion.
@@ -100,7 +95,6 @@ impl<'a> Events<'a> {
// * skip_expired is redundant with the `tx.broadcasts().expire(…)` call;
// * resume is redundant with the resume_at argument to
// `tx.broadcasts().replay(…)`.
- let stored_messages = stream::iter(stored_messages);
let live_messages = live_messages
// Sure, it's temporally improbable that we'll ever skip a message
// that's 90 days old, but there's no reason not to be thorough.
@@ -108,26 +102,31 @@ impl<'a> Events<'a> {
// Filtering on the broadcast resume point filters out messages
// before resume_at, and filters out messages duplicated from
// stored_messages.
- .filter(Self::resume(resume_broadcast_at));
+ .filter(Self::resume(resume_live_at));
- Ok(stored_messages.chain(live_messages))
+ Ok(replay
+ .chain(live_messages)
+ .scan(resume_at, |resume_point, event| {
+ let channel = &event.channel.id;
+ let sequence = event.sequence;
+ resume_point.advance(channel, sequence);
+
+ let event = types::ResumableEvent(resume_point.clone(), event);
+
+ future::ready(Some(event))
+ }))
}
fn resume(
- resume_at: Option<broadcast::Sequence>,
- ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> {
- move |msg| {
- future::ready(match resume_at {
- None => true,
- Some(resume_at) => msg.sequence > resume_at,
- })
- }
+ resume_at: ResumePoint,
+ ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> {
+ move |event| future::ready(resume_at < event.sequence())
}
fn skip_expired(
expire_at: &DateTime,
- ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> {
+ ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> {
let expire_at = expire_at.to_owned();
- move |msg| future::ready(msg.sent_at > expire_at)
+ move |event| future::ready(expire_at < event.at)
}
}
diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs
index dcaba91..9697c0a 100644
--- a/src/events/broadcaster.rs
+++ b/src/events/broadcaster.rs
@@ -1,63 +1,35 @@
-use std::collections::{hash_map::Entry, HashMap};
-use std::sync::{Arc, Mutex, MutexGuard};
+use std::sync::{Arc, Mutex};
use futures::{future, stream::StreamExt as _, Stream};
-use sqlx::sqlite::SqlitePool;
use tokio::sync::broadcast::{channel, Sender};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
-use crate::{
- events::repo::broadcast,
- repo::channel::{self, Provider as _},
-};
+use crate::events::types;
-// Clones will share the same senders collection.
+// Clones will share the same sender.
#[derive(Clone)]
pub struct Broadcaster {
// The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
// own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
// lock it must be sync.
- senders: Arc<Mutex<HashMap<channel::Id, Sender<broadcast::Message>>>>,
+ senders: Arc<Mutex<Sender<types::ChannelEvent>>>,
}
-impl Broadcaster {
- pub async fn from_database(db: &SqlitePool) -> Result<Self, sqlx::Error> {
- let mut tx = db.begin().await?;
- let channels = tx.channels().all().await?;
- tx.commit().await?;
-
- let channels = channels.iter().map(|c| &c.id);
- let broadcaster = Self::new(channels);
- Ok(broadcaster)
- }
-
- fn new<'i>(channels: impl IntoIterator<Item = &'i channel::Id>) -> Self {
- let senders: HashMap<_, _> = channels
- .into_iter()
- .cloned()
- .map(|id| (id, Self::make_sender()))
- .collect();
+impl Default for Broadcaster {
+ fn default() -> Self {
+ let sender = Self::make_sender();
Self {
- senders: Arc::new(Mutex::new(senders)),
- }
- }
-
- // panic: if ``channel`` is already registered.
- pub fn register_channel(&self, channel: &channel::Id) {
- match self.senders().entry(channel.clone()) {
- // This ever happening indicates a serious logic error.
- Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"),
- Entry::Vacant(entry) => {
- entry.insert(Self::make_sender());
- }
+ senders: Arc::new(Mutex::new(sender)),
}
}
+}
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- pub fn broadcast(&self, channel: &channel::Id, message: &broadcast::Message) {
- let tx = self.sender(channel);
+impl Broadcaster {
+ // panic: if ``message.channel.id`` has not been previously registered,
+ // and was not part of the initial set of channels.
+ pub fn broadcast(&self, message: &types::ChannelEvent) {
+ let tx = self.sender();
// Per the Tokio docs, the returned error is only used to indicate that
// there are no receivers. In this use case, that's fine; a lack of
@@ -71,15 +43,12 @@ impl Broadcaster {
// panic: if ``channel`` has not been previously registered, and was not
// part of the initial set of channels.
- pub fn subscribe(
- &self,
- channel: &channel::Id,
- ) -> impl Stream<Item = broadcast::Message> + std::fmt::Debug {
- let rx = self.sender(channel).subscribe();
+ pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug {
+ let rx = self.sender().subscribe();
BroadcastStream::from(rx).scan((), |(), r| {
future::ready(match r {
- Ok(message) => Some(message),
+ Ok(event) => Some(event),
// Stop the stream here. This will disconnect SSE clients
// (see `routes.rs`), who will then resume from
// `Last-Event-ID`, allowing them to catch up by reading
@@ -92,17 +61,11 @@ impl Broadcaster {
})
}
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- fn sender(&self, channel: &channel::Id) -> Sender<broadcast::Message> {
- self.senders()[channel].clone()
- }
-
- fn senders(&self) -> MutexGuard<HashMap<channel::Id, Sender<broadcast::Message>>> {
- self.senders.lock().unwrap() // propagate panics when mutex is poisoned
+ fn sender(&self) -> Sender<types::ChannelEvent> {
+ self.senders.lock().unwrap().clone()
}
- fn make_sender() -> Sender<broadcast::Message> {
+ fn make_sender() -> Sender<types::ChannelEvent> {
// Queue depth of 16 chosen entirely arbitrarily. Don't read too much
// into it.
let (tx, _) = channel(16);
diff --git a/src/events/mod.rs b/src/events/mod.rs
index b9f3f5b..711ae64 100644
--- a/src/events/mod.rs
+++ b/src/events/mod.rs
@@ -3,5 +3,6 @@ pub mod broadcaster;
mod extract;
pub mod repo;
mod routes;
+pub mod types;
pub use self::routes::router;
diff --git a/src/events/repo/broadcast.rs b/src/events/repo/message.rs
index 6914573..b4724ea 100644
--- a/src/events/repo/broadcast.rs
+++ b/src/events/repo/message.rs
@@ -2,6 +2,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use crate::{
clock::DateTime,
+ events::types::{self, Sequence},
repo::{
channel::Channel,
login::{self, Login},
@@ -10,35 +11,25 @@ use crate::{
};
pub trait Provider {
- fn broadcast(&mut self) -> Broadcast;
+ fn message_events(&mut self) -> Events;
}
impl<'c> Provider for Transaction<'c, Sqlite> {
- fn broadcast(&mut self) -> Broadcast {
- Broadcast(self)
+ fn message_events(&mut self) -> Events {
+ Events(self)
}
}
-pub struct Broadcast<'t>(&'t mut SqliteConnection);
+pub struct Events<'t>(&'t mut SqliteConnection);
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Message {
- pub id: message::Id,
- #[serde(skip)]
- pub sequence: Sequence,
- pub sender: Login,
- pub body: String,
- pub sent_at: DateTime,
-}
-
-impl<'c> Broadcast<'c> {
+impl<'c> Events<'c> {
pub async fn create(
&mut self,
sender: &Login,
channel: &Channel,
body: &str,
sent_at: &DateTime,
- ) -> Result<Message, sqlx::Error> {
+ ) -> Result<types::ChannelEvent, sqlx::Error> {
let sequence = self.next_sequence_for(channel).await?;
let id = message::Id::generate();
@@ -62,12 +53,16 @@ impl<'c> Broadcast<'c> {
body,
sent_at,
)
- .map(|row| Message {
- id: row.id,
+ .map(|row| types::ChannelEvent {
sequence: row.sequence,
- sender: sender.clone(),
- body: row.body,
- sent_at: row.sent_at,
+ at: row.sent_at,
+ channel: channel.clone(),
+ data: types::MessageEvent {
+ id: row.id,
+ sender: sender.clone(),
+ body: row.body,
+ }
+ .into(),
})
.fetch_one(&mut *self.0)
.await?;
@@ -76,7 +71,7 @@ impl<'c> Broadcast<'c> {
}
async fn next_sequence_for(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> {
- let Sequence(current) = sqlx::query_scalar!(
+ let current = sqlx::query_scalar!(
r#"
-- `max` never returns null, but sqlx can't detect that
select max(sequence) as "sequence!: Sequence"
@@ -88,7 +83,7 @@ impl<'c> Broadcast<'c> {
.fetch_one(&mut *self.0)
.await?;
- Ok(Sequence(current + 1))
+ Ok(current.next())
}
pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> {
@@ -109,8 +104,8 @@ impl<'c> Broadcast<'c> {
&mut self,
channel: &Channel,
resume_at: Option<Sequence>,
- ) -> Result<Vec<Message>, sqlx::Error> {
- let messages = sqlx::query!(
+ ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> {
+ let events = sqlx::query!(
r#"
select
message.id as "id: message::Id",
@@ -128,35 +123,23 @@ impl<'c> Broadcast<'c> {
channel.id,
resume_at,
)
- .map(|row| Message {
- id: row.id,
+ .map(|row| types::ChannelEvent {
sequence: row.sequence,
- sender: Login {
- id: row.sender_id,
- name: row.sender_name,
- },
- body: row.body,
- sent_at: row.sent_at,
+ at: row.sent_at,
+ channel: channel.clone(),
+ data: types::MessageEvent {
+ id: row.id,
+ sender: login::Login {
+ id: row.sender_id,
+ name: row.sender_name,
+ },
+ body: row.body,
+ }
+ .into(),
})
.fetch_all(&mut *self.0)
.await?;
- Ok(messages)
+ Ok(events)
}
}
-
-#[derive(
- Debug,
- Eq,
- Ord,
- PartialEq,
- PartialOrd,
- Clone,
- Copy,
- serde::Serialize,
- serde::Deserialize,
- sqlx::Type,
-)]
-#[serde(transparent)]
-#[sqlx(transparent)]
-pub struct Sequence(i64);
diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs
index 2ed3062..e216a50 100644
--- a/src/events/repo/mod.rs
+++ b/src/events/repo/mod.rs
@@ -1 +1 @@
-pub mod broadcast;
+pub mod message;
diff --git a/src/events/routes.rs b/src/events/routes.rs
index d901f9b..3f70dcd 100644
--- a/src/events/routes.rs
+++ b/src/events/routes.rs
@@ -1,8 +1,5 @@
-use std::collections::{BTreeMap, HashSet};
-
use axum::{
extract::State,
- http::StatusCode,
response::{
sse::{self, Sse},
IntoResponse, Response,
@@ -10,87 +7,32 @@ use axum::{
routing::get,
Router,
};
-use axum_extra::extract::Query;
-use futures::{
- future,
- stream::{self, Stream, StreamExt as _, TryStreamExt as _},
-};
+use futures::stream::{Stream, StreamExt as _};
-use super::{extract::LastEventId, repo::broadcast};
-use crate::{
- app::App,
- clock::RequestedAt,
- error::Internal,
- events::app::EventsError,
- repo::{channel, login::Login},
+use super::{
+ extract::LastEventId,
+ types::{self, ResumePoint},
};
+use crate::{app::App, clock::RequestedAt, error::Internal, repo::login::Login};
#[cfg(test)]
mod test;
-// For the purposes of event replay, an "event ID" is a vector of per-channel
-// sequence numbers. Replay will start with messages whose sequence number in
-// its channel is higher than the sequence in the event ID, or if the channel
-// is not listed in the event ID, then at the beginning.
-//
-// Using a sorted map ensures that there is a canonical representation for
-// each event ID.
-type EventId = BTreeMap<channel::Id, broadcast::Sequence>;
-
pub fn router() -> Router<App> {
Router::new().route("/api/events", get(events))
}
-#[derive(Clone, serde::Deserialize)]
-struct EventsQuery {
- #[serde(default, rename = "channel")]
- channels: HashSet<channel::Id>,
-}
-
async fn events(
State(app): State<App>,
- RequestedAt(now): RequestedAt,
+ RequestedAt(subscribed_at): RequestedAt,
_: Login, // requires auth, but doesn't actually care who you are
- last_event_id: Option<LastEventId<EventId>>,
- Query(query): Query<EventsQuery>,
-) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> {
+ last_event_id: Option<LastEventId<ResumePoint>>,
+) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
let resume_at = last_event_id
.map(LastEventId::into_inner)
.unwrap_or_default();
- let streams = stream::iter(query.channels)
- .then(|channel| {
- let app = app.clone();
- let resume_at = resume_at.clone();
- async move {
- let resume_at = resume_at.get(&channel).copied();
-
- let events = app
- .events()
- .subscribe(&channel, &now, resume_at)
- .await?
- .map(ChannelEvent::wrap(channel));
-
- Ok::<_, EventsError>(events)
- }
- })
- .try_collect::<Vec<_>>()
- .await
- // impl From would take more code; this is used once.
- .map_err(ErrorResponse)?;
-
- // We resume counting from the provided last-event-id mapping, rather than
- // starting from scratch, so that the events in a resumed stream contain
- // the full vector of channel IDs for their event IDs right off the bat,
- // even before any events are actually delivered.
- let stream = stream::select_all(streams).scan(resume_at, |sequences, event| {
- let (channel, sequence) = event.event_id();
- sequences.insert(channel, sequence);
-
- let event = ReplayableEvent(sequences.clone(), event);
-
- future::ready(Some(event))
- });
+ let stream = app.events().subscribe(&subscribed_at, resume_at).await?;
Ok(Events(stream))
}
@@ -100,7 +42,7 @@ struct Events<S>(S);
impl<S> IntoResponse for Events<S>
where
- S: Stream<Item = ReplayableEvent> + Send + 'static,
+ S: Stream<Item = types::ResumableEvent> + Send + 'static,
{
fn into_response(self) -> Response {
let Self(stream) = self;
@@ -111,51 +53,13 @@ where
}
}
-#[derive(Debug)]
-struct ErrorResponse(EventsError);
-
-impl IntoResponse for ErrorResponse {
- fn into_response(self) -> Response {
- let Self(error) = self;
- match error {
- not_found @ EventsError::ChannelNotFound(_) => {
- (StatusCode::NOT_FOUND, not_found.to_string()).into_response()
- }
- other => Internal::from(other).into_response(),
- }
- }
-}
-
-#[derive(Debug)]
-struct ReplayableEvent(EventId, ChannelEvent);
-
-#[derive(Debug, serde::Serialize)]
-struct ChannelEvent {
- channel: channel::Id,
- #[serde(flatten)]
- message: broadcast::Message,
-}
-
-impl ChannelEvent {
- fn wrap(channel: channel::Id) -> impl Fn(broadcast::Message) -> Self {
- move |message| Self {
- channel: channel.clone(),
- message,
- }
- }
-
- fn event_id(&self) -> (channel::Id, broadcast::Sequence) {
- (self.channel.clone(), self.message.sequence)
- }
-}
-
-impl TryFrom<ReplayableEvent> for sse::Event {
+impl TryFrom<types::ResumableEvent> for sse::Event {
type Error = serde_json::Error;
- fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> {
- let ReplayableEvent(id, data) = value;
+ fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> {
+ let types::ResumableEvent(resume_at, data) = value;
- let id = serde_json::to_string(&id)?;
+ let id = serde_json::to_string(&resume_at)?;
let data = serde_json::to_string_pretty(&data)?;
let event = Self::default().id(id).data(data);
diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs
index 4412938..f289225 100644
--- a/src/events/routes/test.rs
+++ b/src/events/routes/test.rs
@@ -1,40 +1,15 @@
use axum::extract::State;
-use axum_extra::extract::Query;
use futures::{
future,
stream::{self, StreamExt as _},
};
use crate::{
- events::{app, routes},
- repo::channel::{self},
+ events::{routes, types},
test::fixtures::{self, future::Immediately as _},
};
#[tokio::test]
-async fn no_subscriptions() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let subscriber = fixtures::login::create(&app).await;
-
- // Call the endpoint
-
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("empty subscription");
-
- // Verify the structure of the response.
-
- assert!(events.next().immediately().await.is_none());
-}
-
-#[tokio::test]
async fn includes_historical_message() {
// Set up the environment
@@ -47,24 +22,19 @@ async fn includes_historical_message() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channel");
+ let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered stored message");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
#[tokio::test]
@@ -78,68 +48,23 @@ async fn includes_live_message() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber,
- None,
- Query(query),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(mut events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the semantics
let sender = fixtures::login::create(&app).await;
let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered live message");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
-}
-
-#[tokio::test]
-async fn excludes_other_channels() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let subscribed_channel = fixtures::channel::create(&app).await;
- let unsubscribed_channel = fixtures::channel::create(&app).await;
- let sender = fixtures::login::create(&app).await;
- let message =
- fixtures::message::send(&app, &sender, &subscribed_channel, &fixtures::now()).await;
- fixtures::message::send(&app, &sender, &unsubscribed_channel, &fixtures::now()).await;
-
- // Call the endpoint
-
- let subscriber = fixtures::login::create(&app).await;
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [subscribed_channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to a valid channel");
-
- // Verify the semantics
-
- let routes::ReplayableEvent(_, event) = events
- .next()
- .immediately()
- .await
- .expect("delivered at least one message");
-
- assert_eq!(subscribed_channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
#[tokio::test]
@@ -155,10 +80,11 @@ async fn includes_multiple_channels() {
];
let messages = stream::iter(channels)
- .then(|channel| async {
- let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
-
- (channel, message)
+ .then(|channel| {
+ let app = app.clone();
+ let sender = sender.clone();
+ let channel = channel.clone();
+ async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await }
})
.collect::<Vec<_>>()
.await;
@@ -167,17 +93,9 @@ async fn includes_multiple_channels() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: messages
- .iter()
- .map(|(channel, _)| &channel.id)
- .cloned()
- .collect(),
- };
- let routes::Events(events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channels");
+ let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
@@ -187,41 +105,14 @@ async fn includes_multiple_channels() {
.immediately()
.await;
- for (channel, message) in messages {
- assert!(events.iter().any(|routes::ReplayableEvent(_, event)| {
- event.channel == channel.id && event.message == message
- }));
+ for message in &messages {
+ assert!(events
+ .iter()
+ .any(|types::ResumableEvent(_, event)| { event == message }));
}
}
#[tokio::test]
-async fn nonexistent_channel() {
- // Set up the environment
-
- let app = fixtures::scratch_app().await;
- let channel = channel::Id::generate();
-
- // Call the endpoint
-
- let subscriber = fixtures::login::create(&app).await;
- let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.clone()].into(),
- };
- let routes::ErrorResponse(error) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect_err("subscribed to nonexistent channel");
-
- // Verify the structure of the response.
-
- assert!(matches!(
- error,
- app::EventsError::ChannelNotFound(error_channel) if error_channel == channel
- ));
-}
-
-#[tokio::test]
async fn sequential_messages() {
// Set up the environment
@@ -239,30 +130,24 @@ async fn sequential_messages() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
- let mut events = events.filter(|routes::ReplayableEvent(_, event)| {
- future::ready(messages.contains(&event.message))
- });
+ let mut events =
+ events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)));
// Verify delivery in order
for message in &messages {
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("undelivered messages remaining");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, &event.message);
+ assert_eq!(message, &event);
}
}
@@ -285,42 +170,28 @@ async fn resumes_from() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
let resume_at = {
// First subscription
- let routes::Events(mut events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber.clone(),
- None,
- Query(query.clone()),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(mut events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None)
+ .await
+ .expect("subscribe never fails");
- let routes::ReplayableEvent(id, event) =
+ let types::ResumableEvent(last_event_id, event) =
events.next().immediately().await.expect("delivered events");
- assert_eq!(channel.id, event.channel);
- assert_eq!(initial_message, event.message);
+ assert_eq!(initial_message, event);
- id
+ last_event_id
};
// Resume after disconnect
let reconnect_at = fixtures::now();
- let routes::Events(resumed) = routes::events(
- State(app),
- reconnect_at,
- subscriber,
- Some(resume_at.into()),
- Query(query),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(resumed) =
+ routes::events(State(app), reconnect_at, subscriber, Some(resume_at.into()))
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
@@ -330,11 +201,10 @@ async fn resumes_from() {
.immediately()
.await;
- for message in later_messages {
- assert!(events.iter().any(
- |routes::ReplayableEvent(_, event)| event.channel == channel.id
- && event.message == message
- ));
+ for message in &later_messages {
+ assert!(events
+ .iter()
+ .any(|types::ResumableEvent(_, event)| event == message));
}
}
@@ -365,9 +235,6 @@ async fn serial_resume() {
// Call the endpoint
let subscriber = fixtures::login::create(&app).await;
- let query = routes::EventsQuery {
- channels: [channel_a.id.clone(), channel_b.id.clone()].into(),
- };
let resume_at = {
let initial_messages = [
@@ -377,15 +244,10 @@ async fn serial_resume() {
// First subscription
let subscribed_at = fixtures::now();
- let routes::Events(events) = routes::events(
- State(app.clone()),
- subscribed_at,
- subscriber.clone(),
- None,
- Query(query.clone()),
- )
- .await
- .expect("subscribed to a valid channel");
+ let routes::Events(events) =
+ routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None)
+ .await
+ .expect("subscribe never fails");
let events = events
.take(initial_messages.len())
@@ -393,13 +255,13 @@ async fn serial_resume() {
.immediately()
.await;
- for message in initial_messages {
+ for message in &initial_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
- let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
id.to_owned()
};
@@ -421,10 +283,9 @@ async fn serial_resume() {
resubscribed_at,
subscriber.clone(),
Some(resume_at.into()),
- Query(query.clone()),
)
.await
- .expect("subscribed to a valid channel");
+ .expect("subscribe never fails");
let events = events
.take(resume_messages.len())
@@ -432,13 +293,13 @@ async fn serial_resume() {
.immediately()
.await;
- for message in resume_messages {
+ for message in &resume_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
- let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
id.to_owned()
};
@@ -460,10 +321,9 @@ async fn serial_resume() {
resubscribed_at,
subscriber.clone(),
Some(resume_at.into()),
- Query(query.clone()),
)
.await
- .expect("subscribed to a valid channel");
+ .expect("subscribe never fails");
let events = events
.take(final_messages.len())
@@ -473,10 +333,10 @@ async fn serial_resume() {
// This set of messages, in particular, _should not_ include any prior
// messages from `initial_messages` or `resume_messages`.
- for message in final_messages {
+ for message in &final_messages {
assert!(events
.iter()
- .any(|routes::ReplayableEvent(_, event)| event.message == message));
+ .any(|types::ResumableEvent(_, event)| event == message));
}
};
}
@@ -495,22 +355,18 @@ async fn removes_expired_messages() {
let subscriber = fixtures::login::create(&app).await;
let subscribed_at = fixtures::now();
- let query = routes::EventsQuery {
- channels: [channel.id.clone()].into(),
- };
- let routes::Events(mut events) =
- routes::events(State(app), subscribed_at, subscriber, None, Query(query))
- .await
- .expect("subscribed to valid channel");
+
+ let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None)
+ .await
+ .expect("subscribe never fails");
// Verify the semantics
- let routes::ReplayableEvent(_, event) = events
+ let types::ResumableEvent(_, event) = events
.next()
.immediately()
.await
.expect("delivered messages");
- assert_eq!(channel.id, event.channel);
- assert_eq!(message, event.message);
+ assert_eq!(message, event);
}
diff --git a/src/events/types.rs b/src/events/types.rs
new file mode 100644
index 0000000..6747afc
--- /dev/null
+++ b/src/events/types.rs
@@ -0,0 +1,99 @@
+use std::collections::BTreeMap;
+
+use crate::{
+ clock::DateTime,
+ repo::{
+ channel::{self, Channel},
+ login::Login,
+ message,
+ },
+};
+
+#[derive(
+ Debug,
+ Eq,
+ Ord,
+ PartialEq,
+ PartialOrd,
+ Clone,
+ Copy,
+ serde::Serialize,
+ serde::Deserialize,
+ sqlx::Type,
+)]
+#[serde(transparent)]
+#[sqlx(transparent)]
+pub struct Sequence(i64);
+
+impl Sequence {
+ pub fn next(self) -> Self {
+ let Self(current) = self;
+ Self(current + 1)
+ }
+}
+
+// For the purposes of event replay, an "event ID" is a vector of per-channel
+// sequence numbers. Replay will start with messages whose sequence number in
+// its channel is higher than the sequence in the event ID, or if the channel
+// is not listed in the event ID, then at the beginning.
+//
+// Using a sorted map ensures that there is a canonical representation for
+// each event ID.
+#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)]
+#[serde(transparent)]
+pub struct ResumePoint(BTreeMap<channel::Id, Sequence>);
+
+impl ResumePoint {
+ pub fn singleton(channel: &channel::Id, sequence: Sequence) -> Self {
+ let mut vector = Self::default();
+ vector.advance(channel, sequence);
+ vector
+ }
+
+ pub fn advance(&mut self, channel: &channel::Id, sequence: Sequence) {
+ let Self(elements) = self;
+ elements.insert(channel.clone(), sequence);
+ }
+
+ pub fn get(&self, channel: &channel::Id) -> Option<Sequence> {
+ let Self(elements) = self;
+ elements.get(channel).copied()
+ }
+}
+#[derive(Clone, Debug)]
+pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent);
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct ChannelEvent {
+ #[serde(skip)]
+ pub sequence: Sequence,
+ pub at: DateTime,
+ pub channel: Channel,
+ #[serde(flatten)]
+ pub data: ChannelEventData,
+}
+
+impl ChannelEvent {
+ pub fn sequence(&self) -> ResumePoint {
+ ResumePoint::singleton(&self.channel.id, self.sequence)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ChannelEventData {
+ Message(MessageEvent),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct MessageEvent {
+ pub id: message::Id,
+ pub sender: Login,
+ pub body: String,
+}
+
+impl From<MessageEvent> for ChannelEventData {
+ fn from(message: MessageEvent) -> Self {
+ Self::Message(message)
+ }
+}
diff --git a/src/repo/channel.rs b/src/repo/channel.rs
index 0186413..d223dab 100644
--- a/src/repo/channel.rs
+++ b/src/repo/channel.rs
@@ -16,7 +16,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> {
pub struct Channels<'t>(&'t mut SqliteConnection);
-#[derive(Debug, Eq, PartialEq, serde::Serialize)]
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct Channel {
pub id: Id,
pub name: String,
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index 33feeae..bfca8cd 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -3,7 +3,7 @@ use faker_rand::lorem::Paragraphs;
use crate::{
app::App,
clock::RequestedAt,
- events::repo::broadcast,
+ events::types,
repo::{channel::Channel, login::Login},
};
@@ -12,7 +12,7 @@ pub async fn send(
login: &Login,
channel: &Channel,
sent_at: &RequestedAt,
-) -> broadcast::Message {
+) -> types::ChannelEvent {
let body = propose();
app.events()
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index a42dba5..450fbec 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -13,8 +13,6 @@ pub async fn scratch_app() -> App {
.await
.expect("setting up in-memory sqlite database");
App::from(pool)
- .await
- .expect("creating an app from a fresh, in-memory database")
}
pub fn now() -> RequestedAt {