From d171a258ad2119e39cb715f8800031fff16967dc Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 22:43:18 -0400 Subject: Provide a resume point to bridge clients from state snapshots to the event sequence. --- src/channel/app.rs | 6 ++-- src/channel/routes.rs | 15 ++++++-- src/channel/routes/test/list.rs | 7 ++-- src/channel/routes/test/on_create.rs | 2 +- src/events/routes.rs | 11 +++++- src/events/routes/test.rs | 67 +++++++++++++++++++++++++----------- src/login/app.rs | 9 +++++ src/login/routes.rs | 9 +++-- src/login/routes/test/boot.rs | 7 +++- src/repo/channel.rs | 7 +++- src/repo/sequence.rs | 52 ++++++++++++++++++++-------- 11 files changed, 142 insertions(+), 50 deletions(-) (limited to 'src') diff --git a/src/channel/app.rs b/src/channel/app.rs index 88f4170..d89e733 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -6,7 +6,7 @@ use crate::{ events::{broadcaster::Broadcaster, types::ChannelEvent}, repo::{ channel::{Channel, Provider as _}, - sequence::Provider as _, + sequence::{Provider as _, Sequence}, }, }; @@ -36,9 +36,9 @@ impl<'a> Channels<'a> { Ok(channel) } - pub async fn all(&self) -> Result, InternalError> { + pub async fn all(&self, resume_point: Option) -> Result, InternalError> { let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; + let channels = tx.channels().all(resume_point).await?; tx.commit().await?; Ok(channels) diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 1f8db5a..067d213 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -5,6 +5,7 @@ use axum::{ routing::{get, post}, Router, }; +use axum_extra::extract::Query; use super::app; use crate::{ @@ -15,6 +16,7 @@ use crate::{ repo::{ channel::{self, Channel}, login::Login, + sequence::Sequence, }, }; @@ -28,8 +30,17 @@ pub fn router() -> Router { .route("/api/channels/:channel", post(on_send)) } -async fn list(State(app): State, _: Login) -> Result { - let channels = app.channels().all().await?; +#[derive(Default, serde::Deserialize)] +struct ListQuery { + resume_point: Option, +} + +async fn list( + State(app): State, + _: Login, + Query(query): Query, +) -> Result { + let channels = app.channels().all(query.resume_point).await?; let response = Channels(channels); Ok(response) diff --git a/src/channel/routes/test/list.rs b/src/channel/routes/test/list.rs index bc94024..f15a53c 100644 --- a/src/channel/routes/test/list.rs +++ b/src/channel/routes/test/list.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use axum_extra::extract::Query; use crate::{channel::routes, test::fixtures}; @@ -11,7 +12,7 @@ async fn empty_list() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -30,7 +31,7 @@ async fn one_channel() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -52,7 +53,7 @@ async fn multiple_channels() { // Call the endpoint - let routes::Channels(response_channels) = routes::list(State(app), viewer) + let routes::Channels(response_channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 5deb88a..72980ac 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -33,7 +33,7 @@ async fn new_channel() { // Verify the semantics - let channels = app.channels().all().await.expect("always succeeds"); + let channels = app.channels().all(None).await.expect("always succeeds"); assert!(channels.contains(&response_channel)); let mut events = app diff --git a/src/events/routes.rs b/src/events/routes.rs index e3a959f..d81c7fb 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -7,6 +7,7 @@ use axum::{ routing::get, Router, }; +use axum_extra::extract::Query; use futures::stream::{Stream, StreamExt as _}; use super::{extract::LastEventId, types}; @@ -24,12 +25,20 @@ pub fn router() -> Router { Router::new().route("/api/events", get(events)) } +#[derive(Default, serde::Deserialize)] +struct EventsQuery { + resume_point: Option, +} + async fn events( State(app): State, identity: Identity, last_event_id: Option>, + Query(query): Query, ) -> Result + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id.map(LastEventId::into_inner); + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); let stream = app.events().subscribe(resume_at).await?; let stream = app.logins().limit_stream(identity.token, stream).await?; diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 1cfca4f..11f01b8 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use axum_extra::extract::Query; use futures::{ future, stream::{self, StreamExt as _}, @@ -22,7 +23,7 @@ async fn includes_historical_message() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -49,9 +50,10 @@ async fn includes_live_message() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the semantics @@ -94,7 +96,7 @@ async fn includes_multiple_channels() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -130,7 +132,7 @@ async fn sequential_messages() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -172,9 +174,14 @@ async fn resumes_from() { let resume_at = { // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); let event = events .filter(fixtures::filter::messages()) @@ -189,9 +196,14 @@ async fn resumes_from() { }; // Resume after disconnect - let routes::Events(resumed) = routes::events(State(app), subscriber, Some(resume_at.into())) - .await - .expect("subscribe never fails"); + let routes::Events(resumed) = routes::events( + State(app), + subscriber, + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -242,9 +254,14 @@ async fn serial_resume() { ]; // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); let events = events .filter(fixtures::filter::messages()) @@ -277,6 +294,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -312,6 +330,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -345,9 +364,10 @@ async fn terminates_on_token_expiry() { let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour @@ -387,9 +407,14 @@ async fn terminates_on_logout() { let subscriber = fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour diff --git a/src/login/app.rs b/src/login/app.rs index 95f0a07..f1dffb9 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -13,6 +13,7 @@ use crate::{ repo::{ error::NotFound as _, login::{Login, Provider as _}, + sequence::{Provider as _, Sequence}, token::{self, Provider as _}, }, }; @@ -27,6 +28,14 @@ impl<'a> Logins<'a> { Self { db, logins } } + pub async fn boot_point(&self) -> Result { + let mut tx = self.db.begin().await?; + let sequence = tx.sequence().current().await?; + tx.commit().await?; + + Ok(sequence) + } + pub async fn login( &self, name: &str, diff --git a/src/login/routes.rs b/src/login/routes.rs index d7cb9b1..ef75871 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -26,13 +26,18 @@ pub fn router() -> Router { .route("/api/auth/logout", post(on_logout)) } -async fn boot(login: Login) -> Boot { - Boot { login } +async fn boot(State(app): State, login: Login) -> Result { + let resume_point = app.logins().boot_point().await?; + Ok(Boot { + login, + resume_point: resume_point.to_string(), + }) } #[derive(serde::Serialize)] struct Boot { login: Login, + resume_point: String, } impl IntoResponse for Boot { diff --git a/src/login/routes/test/boot.rs b/src/login/routes/test/boot.rs index dee554f..9655354 100644 --- a/src/login/routes/test/boot.rs +++ b/src/login/routes/test/boot.rs @@ -1,9 +1,14 @@ +use axum::extract::State; + use crate::{login::routes, test::fixtures}; #[tokio::test] async fn returns_identity() { + let app = fixtures::scratch_app().await; let login = fixtures::login::fictitious(); - let response = routes::boot(login.clone()).await; + let response = routes::boot(State(app), login.clone()) + .await + .expect("boot always succeeds"); assert_eq!(login, response.login); } diff --git a/src/repo/channel.rs b/src/repo/channel.rs index efc2ced..ad42710 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -82,7 +82,10 @@ impl<'c> Channels<'c> { Ok(channel) } - pub async fn all(&mut self) -> Result, sqlx::Error> { + pub async fn all( + &mut self, + resume_point: Option, + ) -> Result, sqlx::Error> { let channels = sqlx::query_as!( Channel, r#" @@ -92,8 +95,10 @@ impl<'c> Channels<'c> { created_at as "created_at: DateTime", created_sequence as "created_sequence: Sequence" from channel + where coalesce(created_sequence <= $1, true) order by channel.name "#, + resume_point, ) .fetch_all(&mut *self.0) .await?; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs index 8fe9dab..c47b41c 100644 --- a/src/repo/sequence.rs +++ b/src/repo/sequence.rs @@ -1,3 +1,5 @@ +use std::fmt; + use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; pub trait Provider { @@ -10,6 +12,37 @@ impl<'c> Provider for Transaction<'c, Sqlite> { } } +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } + + pub async fn current(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} + #[derive( Clone, Copy, @@ -26,20 +59,9 @@ impl<'c> Provider for Transaction<'c, Sqlite> { #[sqlx(transparent)] pub struct Sequence(i64); -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) } } -- cgit v1.2.3