From b8392a5fe824eff46f912a58885546e7b0f37e6f Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 22:30:04 -0400 Subject: Track event sequences globally, not per channel. Per-channel event sequences were a cute idea, but it made reasoning about event resumption much, much harder (case in point: recovering the order of events in a partially-ordered collection is quadratic, since it's basically graph sort). The minor overhead of a global sequence number is likely tolerable, and this simplifies both the API and the internals. --- src/channel/app.rs | 14 ++++--- src/channel/routes/test/on_create.rs | 5 +-- src/channel/routes/test/on_send.rs | 4 +- src/events/app.rs | 76 ++++++++++++---------------------- src/events/repo/message.rs | 79 ++++++++++++++++-------------------- src/events/routes.rs | 26 +++++------- src/events/routes/test.rs | 47 ++++++++------------- src/events/types.rs | 79 ++---------------------------------- src/repo/channel.rs | 53 ++++++++++++++++++------ src/repo/mod.rs | 1 + src/repo/sequence.rs | 45 ++++++++++++++++++++ src/test/fixtures/filter.rs | 12 ++---- 12 files changed, 198 insertions(+), 243 deletions(-) create mode 100644 src/repo/sequence.rs (limited to 'src') diff --git a/src/channel/app.rs b/src/channel/app.rs index 70cda47..88f4170 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -3,8 +3,11 @@ use sqlx::sqlite::SqlitePool; use crate::{ clock::DateTime, - events::{broadcaster::Broadcaster, repo::message::Provider as _, types::ChannelEvent}, - repo::channel::{Channel, Provider as _}, + events::{broadcaster::Broadcaster, types::ChannelEvent}, + repo::{ + channel::{Channel, Provider as _}, + sequence::Provider as _, + }, }; pub struct Channels<'a> { @@ -19,9 +22,10 @@ impl<'a> Channels<'a> { pub async fn create(&self, name: &str, created_at: &DateTime) -> Result { let mut tx = self.db.begin().await?; + let created_sequence = tx.sequence().next().await?; let channel = tx .channels() - .create(name, created_at) + .create(name, created_at, created_sequence) .await .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; @@ -49,10 +53,10 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; + let deleted_sequence = tx.sequence().next().await?; let event = tx .channels() - .delete_expired(&channel, sequence, relative_to) + .delete(&channel, relative_to, deleted_sequence) .await?; events.push(event); } diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index e2610a5..5deb88a 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -38,18 +38,17 @@ async fn new_channel() { let mut events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing never fails") .filter(fixtures::filter::created()); - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await .expect("creation event published"); - assert_eq!(types::Sequence::default(), event.sequence); assert!(matches!( event.data, types::ChannelEventData::Created(event) diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 233518b..d37ed21 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -43,7 +43,7 @@ async fn messages_in_order() { let events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing to a valid channel") .filter(fixtures::filter::messages()) @@ -51,7 +51,7 @@ async fn messages_in_order() { let events = events.collect::>().immediately().await; - for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) { + for ((sent_at, message), event) in requests.into_iter().zip(events) { assert_eq!(*sent_at, event.at); assert!(matches!( event.data, diff --git a/src/events/app.rs b/src/events/app.rs index db7f430..c15f11e 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap; - use chrono::TimeDelta; use futures::{ future, @@ -11,7 +9,7 @@ use sqlx::sqlite::SqlitePool; use super::{ broadcaster::Broadcaster, repo::message::Provider as _, - types::{self, ChannelEvent, ResumePoint}, + types::{self, ChannelEvent}, }; use crate::{ clock::DateTime, @@ -19,6 +17,7 @@ use crate::{ channel::{self, Provider as _}, error::NotFound as _, login::Login, + sequence::{Provider as _, Sequence}, }, }; @@ -45,9 +44,10 @@ impl<'a> Events<'a> { .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let sent_sequence = tx.sequence().next().await?; let event = tx .message_events() - .create(login, &channel, body, sent_at) + .create(login, &channel, sent_at, sent_sequence, body) .await?; tx.commit().await?; @@ -64,10 +64,10 @@ impl<'a> Events<'a> { let mut events = Vec::with_capacity(expired.len()); for (channel, message) in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; + let deleted_sequence = tx.sequence().next().await?; let event = tx .message_events() - .delete_expired(&channel, &message, sequence, relative_to) + .delete(&channel, &message, relative_to, deleted_sequence) .await?; events.push(event); } @@ -83,42 +83,30 @@ impl<'a> Events<'a> { pub async fn subscribe( &self, - resume_at: ResumePoint, - ) -> Result + std::fmt::Debug, sqlx::Error> { - let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; - - let created_events = { - let resume_at = resume_at.clone(); - let channels = channels.clone(); - stream::iter( - channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| resume_at.not_after(event)), - ) - }; - + resume_at: Option, + ) -> Result + std::fmt::Debug, sqlx::Error> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. let live_messages = self.events.subscribe(); - let mut replays = BTreeMap::new(); - let mut resume_live_at = resume_at.clone(); - for channel in channels { - let replay = tx - .message_events() - .replay(&channel, resume_at.get(&channel.id)) - .await?; + let mut tx = self.db.begin().await?; + let channels = tx.channels().replay(resume_at).await?; - if let Some(last) = replay.last() { - resume_live_at.advance(last); - } + let channel_events = channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); - replays.insert(channel.id.clone(), replay); - } + let message_events = tx.message_events().replay(resume_at).await?; + + let mut replay_events = channel_events + .into_iter() + .chain(message_events.into_iter()) + .collect::>(); + replay_events.sort_by_key(|event| event.sequence); + let resume_live_at = replay_events.last().map(|event| event.sequence); - let replay = stream::select_all(replays.into_values().map(stream::iter)); + let replay = stream::iter(replay_events); // no skip_expired or resume transforms for stored_messages, as it's // constructed not to contain messages meeting either criterion. @@ -132,25 +120,13 @@ impl<'a> Events<'a> { // stored_messages. .filter(Self::resume(resume_live_at)); - Ok(created_events.chain(replay).chain(live_messages).scan( - resume_at, - |resume_point, event| { - match event.data { - types::ChannelEventData::Deleted(_) => resume_point.forget(&event), - _ => resume_point.advance(&event), - } - - let event = types::ResumableEvent(resume_point.clone(), event); - - future::ready(Some(event)) - }, - )) + Ok(replay.chain(live_messages)) } fn resume( - resume_at: ResumePoint, + resume_at: Option, ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { - move |event| future::ready(resume_at.not_after(event)) + move |event| future::ready(resume_at < Some(event.sequence)) } } diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs index f8bae2b..3237553 100644 --- a/src/events/repo/message.rs +++ b/src/events/repo/message.rs @@ -2,11 +2,12 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, - events::types::{self, Sequence}, + events::types, repo::{ channel::{self, Channel}, login::{self, Login}, message::{self, Message}, + sequence::Sequence, }, }; @@ -27,34 +28,33 @@ impl<'c> Events<'c> { &mut self, sender: &Login, channel: &Channel, - body: &str, sent_at: &DateTime, + sent_sequence: Sequence, + body: &str, ) -> Result { - let sequence = self.assign_sequence(channel).await?; - let id = message::Id::generate(); let message = sqlx::query!( r#" insert into message - (id, channel, sequence, sender, body, sent_at) + (id, channel, sender, sent_at, sent_sequence, body) values ($1, $2, $3, $4, $5, $6) returning id as "id: message::Id", - sequence as "sequence: Sequence", sender as "sender: login::Id", - body, - sent_at as "sent_at: DateTime" + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence", + body "#, id, channel.id, - sequence, sender.id, - body, sent_at, + sent_sequence, + body, ) .map(|row| types::ChannelEvent { - sequence: row.sequence, + sequence: row.sent_sequence, at: row.sent_at, data: types::MessageEvent { channel: channel.clone(), @@ -72,28 +72,12 @@ impl<'c> Events<'c> { Ok(message) } - pub async fn assign_sequence(&mut self, channel: &Channel) -> Result { - let next = sqlx::query_scalar!( - r#" - update channel - set last_sequence = last_sequence + 1 - where id = $1 - returning last_sequence as "next_sequence: Sequence" - "#, - channel.id, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } - - pub async fn delete_expired( + pub async fn delete( &mut self, channel: &Channel, message: &message::Id, - sequence: Sequence, deleted_at: &DateTime, + deleted_sequence: Sequence, ) -> Result { sqlx::query_scalar!( r#" @@ -107,7 +91,7 @@ impl<'c> Events<'c> { .await?; Ok(types::ChannelEvent { - sequence, + sequence: deleted_sequence, at: *deleted_at, data: types::MessageDeletedEvent { channel: channel.clone(), @@ -127,6 +111,7 @@ impl<'c> Events<'c> { channel.id as "channel_id: channel::Id", channel.name as "channel_name", channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", message.id as "message: message::Id" from message join channel on message.channel = channel.id @@ -141,6 +126,7 @@ impl<'c> Events<'c> { id: row.channel_id, name: row.channel_name, created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, }, row.message, ) @@ -153,32 +139,39 @@ impl<'c> Events<'c> { pub async fn replay( &mut self, - channel: &Channel, resume_at: Option, ) -> Result, sqlx::Error> { let events = sqlx::query!( r#" select message.id as "id: message::Id", - sequence as "sequence: Sequence", - login.id as "sender_id: login::Id", - login.name as sender_name, - message.body, - message.sent_at as "sent_at: DateTime" + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + sender.id as "sender_id: login::Id", + sender.name as sender_name, + message.sent_at as "sent_at: DateTime", + message.sent_sequence as "sent_sequence: Sequence", + message.body from message - join login on message.sender = login.id - where channel = $1 - and coalesce(sequence > $2, true) - order by sequence asc + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + order by sent_sequence asc "#, - channel.id, resume_at, ) .map(|row| types::ChannelEvent { - sequence: row.sequence, + sequence: row.sent_sequence, at: row.sent_at, data: types::MessageEvent { - channel: channel.clone(), + channel: Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, sender: login::Login { id: row.sender_id, name: row.sender_name, diff --git a/src/events/routes.rs b/src/events/routes.rs index f09474c..e3a959f 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -9,14 +9,12 @@ use axum::{ }; use futures::stream::{Stream, StreamExt as _}; -use super::{ - extract::LastEventId, - types::{self, ResumePoint}, -}; +use super::{extract::LastEventId, types}; use crate::{ app::App, error::{Internal, Unauthorized}, login::{app::ValidateError, extract::Identity}, + repo::sequence::Sequence, }; #[cfg(test)] @@ -29,11 +27,9 @@ pub fn router() -> Router { async fn events( State(app): State, identity: Identity, - last_event_id: Option>, -) -> Result + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id - .map(LastEventId::into_inner) - .unwrap_or_default(); + last_event_id: Option>, +) -> Result + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id.map(LastEventId::into_inner); let stream = app.events().subscribe(resume_at).await?; let stream = app.logins().limit_stream(identity.token, stream).await?; @@ -46,7 +42,7 @@ struct Events(S); impl IntoResponse for Events where - S: Stream + Send + 'static, + S: Stream + Send + 'static, { fn into_response(self) -> Response { let Self(stream) = self; @@ -57,14 +53,12 @@ where } } -impl TryFrom for sse::Event { +impl TryFrom for sse::Event { type Error = serde_json::Error; - fn try_from(value: types::ResumableEvent) -> Result { - let types::ResumableEvent(resume_at, data) = value; - - let id = serde_json::to_string(&resume_at)?; - let data = serde_json::to_string_pretty(&data)?; + fn try_from(event: types::ChannelEvent) -> Result { + let id = serde_json::to_string(&event.sequence)?; + let data = serde_json::to_string_pretty(&event)?; let event = Self::default().id(id).data(data); diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 820192d..1cfca4f 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -5,7 +5,7 @@ use futures::{ }; use crate::{ - events::{routes, types}, + events::routes, test::fixtures::{self, future::Immediately as _}, }; @@ -28,7 +28,7 @@ async fn includes_historical_message() { // Verify the structure of the response. - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -58,7 +58,7 @@ async fn includes_live_message() { let sender = fixtures::login::create(&app).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -108,9 +108,7 @@ async fn includes_multiple_channels() { .await; for message in &messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| { event == message })); + assert!(events.iter().any(|event| { event == message })); } } @@ -138,12 +136,11 @@ async fn sequential_messages() { // Verify the structure of the response. - let mut events = - events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))); + let mut events = events.filter(|event| future::ready(messages.contains(event))); // Verify delivery in order for message in &messages { - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await @@ -179,7 +176,7 @@ async fn resumes_from() { .await .expect("subscribe never fails"); - let types::ResumableEvent(last_event_id, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -188,7 +185,7 @@ async fn resumes_from() { assert_eq!(initial_message, event); - last_event_id + event.sequence }; // Resume after disconnect @@ -205,9 +202,7 @@ async fn resumes_from() { .await; for message in &later_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } } @@ -259,14 +254,12 @@ async fn serial_resume() { .await; for message in &initial_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence }; // Resume after disconnect @@ -296,14 +289,12 @@ async fn serial_resume() { .await; for message in &resume_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence }; // Resume after disconnect a second time @@ -335,9 +326,7 @@ async fn serial_resume() { // This set of messages, in particular, _should not_ include any prior // messages from `initial_messages` or `resume_messages`. for message in &final_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } }; } @@ -375,7 +364,7 @@ async fn terminates_on_token_expiry() { ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready(messages.contains(event))) .next() .immediately() .await @@ -417,7 +406,7 @@ async fn terminates_on_logout() { ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready(messages.contains(event))) .next() .immediately() .await diff --git a/src/events/types.rs b/src/events/types.rs index d954512..aca3af4 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -1,84 +1,13 @@ -use std::collections::BTreeMap; - use crate::{ clock::DateTime, repo::{ channel::{self, Channel}, login::Login, message, + sequence::Sequence, }, }; -#[derive( - Debug, - Default, - Eq, - Ord, - PartialEq, - PartialOrd, - Clone, - Copy, - serde::Serialize, - serde::Deserialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); - -impl Sequence { - pub fn next(self) -> Self { - let Self(current) = self; - Self(current + 1) - } -} - -// For the purposes of event replay, a resume point is a vector of resume -// elements. A resume element associates a channel (by ID) with the latest event -// seen in that channel so far. Replaying the event stream can restart at a -// predictable point - hence the name. These values can be serialized and sent -// to the client as JSON dicts, then rehydrated to recover the resume point at a -// later time. -// -// Using a sorted map ensures that there is a canonical representation for -// each resume point. -#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)] -#[serde(transparent)] -pub struct ResumePoint(BTreeMap); - -impl ResumePoint { - pub fn advance<'e>(&mut self, event: impl Into>) { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - elements.insert(channel.clone(), sequence); - } - - pub fn forget<'e>(&mut self, event: impl Into>) { - let Self(elements) = self; - let ResumeElement(channel, _) = event.into(); - elements.remove(channel); - } - - pub fn get(&self, channel: &channel::Id) -> Option { - let Self(elements) = self; - elements.get(channel).copied() - } - - pub fn not_after<'e>(&self, event: impl Into>) -> bool { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - - elements - .get(channel) - .map_or(true, |resume_at| resume_at < &sequence) - } -} - -pub struct ResumeElement<'i>(&'i channel::Id, Sequence); - -#[derive(Clone, Debug)] -pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent); - #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct ChannelEvent { #[serde(skip)] @@ -92,7 +21,7 @@ impl ChannelEvent { pub fn created(channel: Channel) -> Self { Self { at: channel.created_at, - sequence: Sequence::default(), + sequence: channel.created_sequence, data: CreatedEvent { channel }.into(), } } @@ -107,9 +36,9 @@ impl ChannelEvent { } } -impl<'c> From<&'c ChannelEvent> for ResumeElement<'c> { +impl<'c> From<&'c ChannelEvent> for Sequence { fn from(event: &'c ChannelEvent) -> Self { - Self(event.channel_id(), event.sequence) + event.sequence } } diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 3c7468f..efc2ced 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -2,9 +2,10 @@ use std::fmt; use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; +use super::sequence::Sequence; use crate::{ clock::DateTime, - events::types::{self, Sequence}, + events::types::{self}, id::Id as BaseId, }; @@ -26,6 +27,8 @@ pub struct Channel { pub name: String, #[serde(skip)] pub created_at: DateTime, + #[serde(skip)] + pub created_sequence: Sequence, } impl<'c> Channels<'c> { @@ -33,25 +36,25 @@ impl<'c> Channels<'c> { &mut self, name: &str, created_at: &DateTime, + created_sequence: Sequence, ) -> Result { let id = Id::generate(); - let sequence = Sequence::default(); - let channel = sqlx::query_as!( Channel, r#" insert - into channel (id, name, created_at, last_sequence) + into channel (id, name, created_at, created_sequence) values ($1, $2, $3, $4) returning id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" "#, id, name, created_at, - sequence, + created_sequence, ) .fetch_one(&mut *self.0) .await?; @@ -66,7 +69,8 @@ impl<'c> Channels<'c> { select id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" from channel where id = $1 "#, @@ -85,7 +89,8 @@ impl<'c> Channels<'c> { select id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" from channel order by channel.name "#, @@ -96,11 +101,34 @@ impl<'c> Channels<'c> { Ok(channels) } - pub async fn delete_expired( + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence > $1, true) + "#, + resume_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn delete( &mut self, channel: &Channel, - sequence: Sequence, deleted_at: &DateTime, + deleted_sequence: Sequence, ) -> Result { let channel = channel.id.clone(); sqlx::query_scalar!( @@ -115,7 +143,7 @@ impl<'c> Channels<'c> { .await?; Ok(types::ChannelEvent { - sequence, + sequence: deleted_sequence, at: *deleted_at, data: types::DeletedEvent { channel }.into(), }) @@ -128,7 +156,8 @@ impl<'c> Channels<'c> { select channel.id as "id: Id", channel.name, - channel.created_at as "created_at: DateTime" + channel.created_at as "created_at: DateTime", + channel.created_sequence as "created_sequence: Sequence" from channel left join message where created_at < $1 diff --git a/src/repo/mod.rs b/src/repo/mod.rs index cb9d7c8..8f271f4 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -3,4 +3,5 @@ pub mod error; pub mod login; pub mod message; pub mod pool; +pub mod sequence; pub mod token; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs new file mode 100644 index 0000000..8fe9dab --- /dev/null +++ b/src/repo/sequence.rs @@ -0,0 +1,45 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index fbebced..c31fa58 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -2,14 +2,10 @@ use futures::future; use crate::events::types; -pub fn messages() -> impl FnMut(&types::ResumableEvent) -> future::Ready { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Message(_))) - } +pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready { + |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_))) } -pub fn created() -> impl FnMut(&types::ResumableEvent) -> future::Ready { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Created(_))) - } +pub fn created() -> impl FnMut(&types::ChannelEvent) -> future::Ready { + |event| future::ready(matches!(event.data, types::ChannelEventData::Created(_))) } -- cgit v1.2.3