diff options
Diffstat (limited to 'src/events')
| -rw-r--r-- | src/events/app.rs | 12 | ||||
| -rw-r--r-- | src/events/broadcaster.rs | 75 | ||||
| -rw-r--r-- | src/events/routes.rs | 5 | ||||
| -rw-r--r-- | src/events/routes/test.rs | 57 |
4 files changed, 62 insertions, 87 deletions
diff --git a/src/events/app.rs b/src/events/app.rs index 0cdc641..db7f430 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -24,12 +24,12 @@ use crate::{ pub struct Events<'a> { db: &'a SqlitePool, - broadcaster: &'a Broadcaster, + events: &'a Broadcaster, } impl<'a> Events<'a> { - pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { - Self { db, broadcaster } + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } } pub async fn send( @@ -51,7 +51,7 @@ impl<'a> Events<'a> { .await?; tx.commit().await?; - self.broadcaster.broadcast(&event); + self.events.broadcast(&event); Ok(event) } @@ -75,7 +75,7 @@ impl<'a> Events<'a> { tx.commit().await?; for event in events { - self.broadcaster.broadcast(&event); + self.events.broadcast(&event); } Ok(()) @@ -101,7 +101,7 @@ impl<'a> Events<'a> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. - let live_messages = self.broadcaster.subscribe(); + let live_messages = self.events.subscribe(); let mut replays = BTreeMap::new(); let mut resume_live_at = resume_at.clone(); diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs index 9697c0a..6b664cb 100644 --- a/src/events/broadcaster.rs +++ b/src/events/broadcaster.rs @@ -1,74 +1,3 @@ -use std::sync::{Arc, Mutex}; +use crate::{broadcast, events::types}; -use futures::{future, stream::StreamExt as _, Stream}; -use tokio::sync::broadcast::{channel, Sender}; -use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; - -use crate::events::types; - -// Clones will share the same sender. -#[derive(Clone)] -pub struct Broadcaster { - // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's - // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that - // lock it must be sync. - senders: Arc<Mutex<Sender<types::ChannelEvent>>>, -} - -impl Default for Broadcaster { - fn default() -> Self { - let sender = Self::make_sender(); - - Self { - senders: Arc::new(Mutex::new(sender)), - } - } -} - -impl Broadcaster { - // panic: if ``message.channel.id`` has not been previously registered, - // and was not part of the initial set of channels. - pub fn broadcast(&self, message: &types::ChannelEvent) { - let tx = self.sender(); - - // Per the Tokio docs, the returned error is only used to indicate that - // there are no receivers. In this use case, that's fine; a lack of - // listening consumers (chat clients) when a message is sent isn't an - // error. - // - // The successful return value, which includes the number of active - // receivers, also isn't that interesting to us. - let _ = tx.send(message.clone()); - } - - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug { - let rx = self.sender().subscribe(); - - BroadcastStream::from(rx).scan((), |(), r| { - future::ready(match r { - Ok(event) => Some(event), - // Stop the stream here. This will disconnect SSE clients - // (see `routes.rs`), who will then resume from - // `Last-Event-ID`, allowing them to catch up by reading - // the skipped messages from the database. - // - // See also: - // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854> - Err(BroadcastStreamRecvError::Lagged(_)) => None, - }) - }) - } - - fn sender(&self) -> Sender<types::ChannelEvent> { - self.senders.lock().unwrap().clone() - } - - fn make_sender() -> Sender<types::ChannelEvent> { - // Queue depth of 16 chosen entirely arbitrarily. Don't read too much - // into it. - let (tx, _) = channel(16); - tx - } -} +pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/events/routes.rs b/src/events/routes.rs index 89c942c..ec9dae2 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -13,7 +13,7 @@ use super::{ extract::LastEventId, types::{self, ResumePoint}, }; -use crate::{app::App, error::Internal, repo::login::Login}; +use crate::{app::App, error::Internal, login::extract::Identity}; #[cfg(test)] mod test; @@ -24,7 +24,7 @@ pub fn router() -> Router<App> { async fn events( State(app): State<App>, - _: Login, // requires auth, but doesn't actually care who you are + identity: Identity, last_event_id: Option<LastEventId<ResumePoint>>, ) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { let resume_at = last_event_id @@ -32,6 +32,7 @@ async fn events( .unwrap_or_default(); let stream = app.events().subscribe(resume_at).await?; + let stream = app.logins().limit_stream(identity.token, stream); Ok(Events(stream)) } diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index a6e2275..0b62b5b 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -20,7 +20,8 @@ async fn includes_historical_message() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -46,7 +47,8 @@ async fn includes_live_message() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) .await .expect("subscribe never fails"); @@ -90,7 +92,8 @@ async fn includes_multiple_channels() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -127,7 +130,8 @@ async fn sequential_messages() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -166,7 +170,8 @@ async fn resumes_from() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let resume_at = { // First subscription @@ -232,7 +237,8 @@ async fn serial_resume() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let resume_at = { let initial_messages = [ @@ -335,3 +341,42 @@ async fn serial_resume() { } }; } + +#[tokio::test] +async fn terminates_on_token_expiry() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = + fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; + let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .expire(&fixtures::now()) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} |
