From 7645411bcf7201e3a4927566da78080dc6a84ccf Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 20:32:57 -0400 Subject: Prevent racing between `limit_stream` and logging out. --- ...89c4147c994df46347a9ce2030ae04a52ccfc0c40c.json | 20 ++++++++++ src/error.rs | 8 ++++ src/events/routes.rs | 26 +++++++++++-- src/login/app.rs | 43 ++++++++++++++++++---- src/login/extract.rs | 6 +-- src/login/routes.rs | 11 ++++-- src/repo/token.rs | 15 ++++++++ 7 files changed, 111 insertions(+), 18 deletions(-) create mode 100644 .sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json diff --git a/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json new file mode 100644 index 0000000..e07ad25 --- /dev/null +++ b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select id as \"id: Id\"\n from token\n where id = $1\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c" +} diff --git a/src/error.rs b/src/error.rs index 6e797b4..8792a1d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,3 +61,11 @@ impl fmt::Display for Id { self.0.fmt(f) } } + +pub struct Unauthorized; + +impl IntoResponse for Unauthorized { + fn into_response(self) -> Response { + (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + } +} diff --git a/src/events/routes.rs b/src/events/routes.rs index ec9dae2..f09474c 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -13,7 +13,11 @@ use super::{ extract::LastEventId, types::{self, ResumePoint}, }; -use crate::{app::App, error::Internal, login::extract::Identity}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + login::{app::ValidateError, extract::Identity}, +}; #[cfg(test)] mod test; @@ -26,13 +30,13 @@ async fn events( State(app): State, identity: Identity, last_event_id: Option>, -) -> Result + std::fmt::Debug>, Internal> { +) -> Result + std::fmt::Debug>, EventsError> { let resume_at = last_event_id .map(LastEventId::into_inner) .unwrap_or_default(); let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream); + let stream = app.logins().limit_stream(identity.token, stream).await?; Ok(Events(stream)) } @@ -67,3 +71,19 @@ impl TryFrom for sse::Event { Ok(event) } } + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/login/app.rs b/src/login/app.rs index 182c62c..95f0a07 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -81,28 +81,55 @@ impl<'a> Logins<'a> { Ok(login) } - pub fn limit_stream( + pub async fn limit_stream( &self, token: token::Id, events: impl Stream + std::fmt::Debug, - ) -> impl Stream + std::fmt::Debug + ) -> Result + std::fmt::Debug, ValidateError> where E: std::fmt::Debug, { - let token_events = self - .logins - .subscribe() + // Subscribe, first. 
+        let token_events = self.logins.subscribe();
+
+        // Check that the token is valid at this point in time, second. If it is, then
+        // any future revocations will appear in the subscription. If not, bail now.
+        //
+        // It's possible, otherwise, to get to this point with a token that _was_ valid
+        // at the start of the request, but which was invalidated _before_ the
+        // `subscribe()` call. In that case, the corresponding revocation event will
+        // simply be missed, since the `token_events` stream subscribed after the fact.
+        // This check catches that case and bails out, rather than guarding the stream
+        // with a subscription that will never see the revocation.
+        //
+        // Yes, this is a weird niche edge case. Most things don't double-check, because
+        // they aren't expected to run long enough for the token's revocation to
+        // matter. Supervising a stream, on the other hand, will run for a
+        // _long_ time; if we miss the race here, we'll never actually carry out the
+        // supervision.
+        let mut tx = self.db.begin().await?;
+        tx.tokens()
+            .require(&token)
+            .await
+            .not_found(|| ValidateError::InvalidToken)?;
+        tx.commit().await?;
+
+        // Then construct the guarded stream. First, project both streams into
+        // `GuardedEvent`.
+        let token_events = token_events
             .filter(move |event| future::ready(event.token == token))
             .map(|_| GuardedEvent::TokenRevoked);
         let events = events.map(|event| GuardedEvent::Event(event));
 
-        stream::select(token_events, events).scan((), |(), event| {
+        // Merge the two streams, then unproject them, stopping at
+        // `GuardedEvent::TokenRevoked`.
+        let stream = stream::select(token_events, events).scan((), |(), event| {
             future::ready(match event {
                 GuardedEvent::Event(event) => Some(event),
                 GuardedEvent::TokenRevoked => None,
             })
-        })
+        });
+
+        Ok(stream)
     }
 
     pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
diff --git a/src/login/extract.rs b/src/login/extract.rs
index b585565..bfdbe8d 100644
--- a/src/login/extract.rs
+++ b/src/login/extract.rs
@@ -2,7 +2,7 @@ use std::fmt;
 
 use axum::{
     extract::{FromRequestParts, State},
-    http::{request::Parts, StatusCode},
+    http::request::Parts,
     response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
 };
 use axum_extra::extract::cookie::{Cookie, CookieJar};
@@ -10,7 +10,7 @@ use axum_extra::extract::cookie::{Cookie, CookieJar};
 use crate::{
     app::App,
     clock::RequestedAt,
-    error::Internal,
+    error::{Internal, Unauthorized},
     login::app::ValidateError,
     repo::{login::Login, token},
 };
@@ -166,7 +166,7 @@ where
 {
     fn into_response(self) -> Response {
         match self {
-            Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
+            Self::Unauthorized => Unauthorized.into_response(),
             Self::Failure(e) => e.into_response(),
         }
     }
diff --git a/src/login/routes.rs b/src/login/routes.rs
index 8d9e938..d7cb9b1 100644
--- a/src/login/routes.rs
+++ b/src/login/routes.rs
@@ -7,7 +7,11 @@ use axum::{
 };
 
 use crate::{
-    app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login,
+    app::App,
+    clock::RequestedAt,
+    error::{Internal, Unauthorized},
+    password::Password,
+    repo::login::Login,
 };
 
 use super::{app, extract::IdentityToken};
@@ -66,6 +70,7 @@ impl IntoResponse for LoginError {
         let Self(error) = self;
         match error {
             app::LoginError::Rejected => {
+                // not error::Unauthorized due to differing messaging
                 (StatusCode::UNAUTHORIZED, "invalid name or password").into_response()
             }
             other => Internal::from(other).into_response(),
@@ -103,9 +108,7 @@ enum LogoutError {
 impl IntoResponse for LogoutError {
     fn into_response(self) -> Response {
         match self {
-            error @ 
Self::ValidateError(app::ValidateError::InvalidToken) => { - (StatusCode::UNAUTHORIZED, error.to_string()).into_response() - } + Self::ValidateError(app::ValidateError::InvalidToken) => Unauthorized.into_response(), other => Internal::from(other).into_response(), } } diff --git a/src/repo/token.rs b/src/repo/token.rs index d96c094..1663f5e 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -47,6 +47,21 @@ impl<'c> Tokens<'c> { Ok(secret) } + pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> { + sqlx::query_scalar!( + r#" + select id as "id: Id" + from token + where id = $1 + "#, + token, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(()) + } + // Revoke a token by its secret. pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> { sqlx::query_scalar!( -- cgit v1.2.3 From b8392a5fe824eff46f912a58885546e7b0f37e6f Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 22:30:04 -0400 Subject: Track event sequences globally, not per channel. Per-channel event sequences were a cute idea, but it made reasoning about event resumption much, much harder (case in point: recovering the order of events in a partially-ordered collection is quadratic, since it's basically graph sort). The minor overhead of a global sequence number is likely tolerable, and this simplifies both the API and the internals. --- ...f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323.json | 38 +++++++ ...8c64a38a3f73b112e74b7318ee8e52e475866d8cfd.json | 32 ------ ...8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf.json | 44 ------- ...53cf075037c794cae08f79a689c7a037aa68d7c00c.json | 20 ---- ...05a144db4901c302bbcd3e76da6c61742ac44345c9.json | 44 +++++++ ...6854398e2151ba2dba10c03a9d2d93184141f1425c.json | 44 +++++++ ...aaf423b1fd14ed9e252d7d9c5323feafb0b9159259.json | 32 ------ ...b8df33083c2da765dfda3023c78c25c06735670457.json | 38 +++++++ ...8cb132479c6e7a2301d576af298da570f3effdc106.json | 50 -------- ...926096285d50afb88a326cff0ecab96058a2f6d93a.json | 32 ------ ...72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json | 74 ++++++++++++ ...8a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json | 38 +++++++ ...0aace35b06db0071c5f257b7f71349966bcdadfcb5.json | 20 ++++ ...fc8c187294a4661c18c2d820f4379dfd82138a8f77.json | 38 +++++++ ...e61c4b10c6f90b1221e963db69c8e6d23e99012ecf.json | 32 ------ ...a2d368676f279af866d0840d6c2c093b87b1eadd8c.json | 38 ------- ...96a0e7b12a920f9827aff2b05ee0364ff7688a38ae.json | 38 +++++++ migrations/20241002003606_global_sequence.sql | 126 +++++++++++++++++++++ src/channel/app.rs | 14 ++- src/channel/routes/test/on_create.rs | 5 +- src/channel/routes/test/on_send.rs | 4 +- src/events/app.rs | 76 +++++-------- src/events/repo/message.rs | 79 ++++++------- src/events/routes.rs | 26 ++--- src/events/routes/test.rs | 47 +++----- src/events/types.rs | 79 +------------ src/repo/channel.rs | 53 +++++++-- src/repo/mod.rs | 1 + src/repo/sequence.rs | 45 ++++++++ src/test/fixtures/filter.rs | 12 +- 30 files changed, 696 insertions(+), 523 deletions(-) create mode 100644 .sqlx/query-023b1e263b68a483704ae5f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323.json delete mode 100644 .sqlx/query-22f313d9afcdd02df74a8b8c64a38a3f73b112e74b7318ee8e52e475866d8cfd.json delete mode 100644 .sqlx/query-2310fe5b8e88e314eb200d8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf.json delete mode 100644 .sqlx/query-397bdfdb77651e3e65e9ec53cf075037c794cae08f79a689c7a037aa68d7c00c.json create mode 100644 .sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json create mode 100644 
.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json delete mode 100644 .sqlx/query-6a782686e163e65f5e03e4aaf423b1fd14ed9e252d7d9c5323feafb0b9159259.json create mode 100644 .sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json delete mode 100644 .sqlx/query-760d3532e1613fd9f79ac98cb132479c6e7a2301d576af298da570f3effdc106.json delete mode 100644 .sqlx/query-7ccae3dde1aba5f22cf9e3926096285d50afb88a326cff0ecab96058a2f6d93a.json create mode 100644 .sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json create mode 100644 .sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json create mode 100644 .sqlx/query-7fc3094944d5133fd8b2d80aace35b06db0071c5f257b7f71349966bcdadfcb5.json create mode 100644 .sqlx/query-9386cdaa2cb41f5a7e19d2fc8c187294a4661c18c2d820f4379dfd82138a8f77.json delete mode 100644 .sqlx/query-aeafe536f36593bfd1080ee61c4b10c6f90b1221e963db69c8e6d23e99012ecf.json delete mode 100644 .sqlx/query-df3656771c3cb6851e0c54a2d368676f279af866d0840d6c2c093b87b1eadd8c.json create mode 100644 .sqlx/query-f6909336ab05b7ad423c7b96a0e7b12a920f9827aff2b05ee0364ff7688a38ae.json create mode 100644 migrations/20241002003606_global_sequence.sql create mode 100644 src/repo/sequence.rs diff --git a/.sqlx/query-023b1e263b68a483704ae5f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323.json b/.sqlx/query-023b1e263b68a483704ae5f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323.json new file mode 100644 index 0000000..cc23359 --- /dev/null +++ b/.sqlx/query-023b1e263b68a483704ae5f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n insert\n into channel (id, name, created_at, created_sequence)\n values ($1, $2, $3, $4)\n returning\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 4 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "023b1e263b68a483704ae5f6b07e69ab792b7365f2eb2831f7a2ac13e2ecf323" +} diff --git a/.sqlx/query-22f313d9afcdd02df74a8b8c64a38a3f73b112e74b7318ee8e52e475866d8cfd.json b/.sqlx/query-22f313d9afcdd02df74a8b8c64a38a3f73b112e74b7318ee8e52e475866d8cfd.json deleted file mode 100644 index 3d5d06c..0000000 --- a/.sqlx/query-22f313d9afcdd02df74a8b8c64a38a3f73b112e74b7318ee8e52e475866d8cfd.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\"\n from channel\n where id = $1\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "22f313d9afcdd02df74a8b8c64a38a3f73b112e74b7318ee8e52e475866d8cfd" -} diff --git a/.sqlx/query-2310fe5b8e88e314eb200d8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf.json b/.sqlx/query-2310fe5b8e88e314eb200d8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf.json deleted file mode 100644 index 
1bd4116..0000000 --- a/.sqlx/query-2310fe5b8e88e314eb200d8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sequence, sender, body, sent_at)\n\t\t\t\tvalues ($1, $2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: message::Id\",\n sequence as \"sequence: Sequence\",\n\t\t\t\t\tsender as \"sender: login::Id\",\n\t\t\t\t\tbody,\n\t\t\t\t\tsent_at as \"sent_at: DateTime\"\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "sequence: Sequence", - "ordinal": 1, - "type_info": "Integer" - }, - { - "name": "sender: login::Id", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "body", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 4, - "type_info": "Text" - } - ], - "parameters": { - "Right": 6 - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "2310fe5b8e88e314eb200d8f227b09c3e4b0c9c0202c7cbe3fba93213ea100cf" -} diff --git a/.sqlx/query-397bdfdb77651e3e65e9ec53cf075037c794cae08f79a689c7a037aa68d7c00c.json b/.sqlx/query-397bdfdb77651e3e65e9ec53cf075037c794cae08f79a689c7a037aa68d7c00c.json deleted file mode 100644 index 5cb7282..0000000 --- a/.sqlx/query-397bdfdb77651e3e65e9ec53cf075037c794cae08f79a689c7a037aa68d7c00c.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n update channel\n set last_sequence = last_sequence + 1\n where id = $1\n returning last_sequence as \"next_sequence: Sequence\"\n ", - "describe": { - "columns": [ - { - "name": "next_sequence: Sequence", - "ordinal": 0, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false - ] - }, - "hash": "397bdfdb77651e3e65e9ec53cf075037c794cae08f79a689c7a037aa68d7c00c" -} diff --git a/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json b/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json new file mode 100644 index 0000000..494e1db --- /dev/null +++ b/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json @@ -0,0 +1,44 @@ +{ + "db_name": "SQLite", + "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sender, sent_at, sent_sequence, body)\n\t\t\t\tvalues ($1, $2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: message::Id\",\n\t\t\t\t\tsender as \"sender: login::Id\",\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\",\n\t\t\t\t\tbody\n\t\t\t", + "describe": { + "columns": [ + { + "name": "id: message::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "sender: login::Id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + }, + { + "name": "body", + "ordinal": 4, + "type_info": "Text" + } + ], + "parameters": { + "Right": 6 + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9" +} diff --git a/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json b/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json new file mode 100644 index 0000000..820b43f --- /dev/null +++ 
b/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json @@ -0,0 +1,44 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n message.id as \"message: message::Id\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where sent_at < $1\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "channel_created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "channel_created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + }, + { + "name": "message: message::Id", + "ordinal": 4, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c" +} diff --git a/.sqlx/query-6a782686e163e65f5e03e4aaf423b1fd14ed9e252d7d9c5323feafb0b9159259.json b/.sqlx/query-6a782686e163e65f5e03e4aaf423b1fd14ed9e252d7d9c5323feafb0b9159259.json deleted file mode 100644 index ae298d6..0000000 --- a/.sqlx/query-6a782686e163e65f5e03e4aaf423b1fd14ed9e252d7d9c5323feafb0b9159259.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"id: Id\",\n channel.name,\n channel.created_at as \"created_at: DateTime\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "6a782686e163e65f5e03e4aaf423b1fd14ed9e252d7d9c5323feafb0b9159259" -} diff --git a/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json b/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json new file mode 100644 index 0000000..b34443f --- /dev/null +++ b/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"id: Id\",\n channel.name,\n channel.created_at as \"created_at: DateTime\",\n channel.created_sequence as \"created_sequence: Sequence\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457" +} diff --git a/.sqlx/query-760d3532e1613fd9f79ac98cb132479c6e7a2301d576af298da570f3effdc106.json b/.sqlx/query-760d3532e1613fd9f79ac98cb132479c6e7a2301d576af298da570f3effdc106.json deleted file mode 100644 index beb9234..0000000 --- 
a/.sqlx/query-760d3532e1613fd9f79ac98cb132479c6e7a2301d576af298da570f3effdc106.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tselect\n\t\t\t\t\tmessage.id as \"id: message::Id\",\n sequence as \"sequence: Sequence\",\n\t\t\t\t\tlogin.id as \"sender_id: login::Id\",\n\t\t\t\t\tlogin.name as sender_name,\n\t\t\t\t\tmessage.body,\n\t\t\t\t\tmessage.sent_at as \"sent_at: DateTime\"\n\t\t\t\tfrom message\n\t\t\t\t\tjoin login on message.sender = login.id\n\t\t\t\twhere channel = $1\n\t\t\t\t\tand coalesce(sequence > $2, true)\n\t\t\t\torder by sequence asc\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "sequence: Sequence", - "ordinal": 1, - "type_info": "Integer" - }, - { - "name": "sender_id: login::Id", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "sender_name", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "body", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 5, - "type_info": "Text" - } - ], - "parameters": { - "Right": 2 - }, - "nullable": [ - false, - false, - false, - false, - false, - false - ] - }, - "hash": "760d3532e1613fd9f79ac98cb132479c6e7a2301d576af298da570f3effdc106" -} diff --git a/.sqlx/query-7ccae3dde1aba5f22cf9e3926096285d50afb88a326cff0ecab96058a2f6d93a.json b/.sqlx/query-7ccae3dde1aba5f22cf9e3926096285d50afb88a326cff0ecab96058a2f6d93a.json deleted file mode 100644 index 4ec7118..0000000 --- a/.sqlx/query-7ccae3dde1aba5f22cf9e3926096285d50afb88a326cff0ecab96058a2f6d93a.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\"\n from channel\n order by channel.name\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - } - ], - "parameters": { - "Right": 0 - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "7ccae3dde1aba5f22cf9e3926096285d50afb88a326cff0ecab96058a2f6d93a" -} diff --git a/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json b/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json new file mode 100644 index 0000000..f546438 --- /dev/null +++ b/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json @@ -0,0 +1,74 @@ +{ + "db_name": "SQLite", + "query": "\n\t\t\t\tselect\n\t\t\t\t\tmessage.id as \"id: message::Id\",\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n\t\t\t\t\tsender.id as \"sender_id: login::Id\",\n\t\t\t\t\tsender.name as sender_name,\n message.sent_at as \"sent_at: DateTime\",\n message.sent_sequence as \"sent_sequence: Sequence\",\n message.body\n\t\t\t\tfrom message\n join channel on message.channel = channel.id\n\t\t\t\t\tjoin login as sender on message.sender = sender.id\n\t\t\t\twhere coalesce(message.sent_sequence > $1, true)\n\t\t\t\torder by sent_sequence asc\n\t\t\t", + "describe": { + "columns": [ + { + "name": "id: message::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_id: channel::Id", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 2, + "type_info": 
"Text" + }, + { + "name": "channel_created_at: DateTime", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "channel_created_sequence: Sequence", + "ordinal": 4, + "type_info": "Integer" + }, + { + "name": "sender_id: login::Id", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 8, + "type_info": "Integer" + }, + { + "name": "body", + "ordinal": 9, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c" +} diff --git a/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json b/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json new file mode 100644 index 0000000..3cc33cf --- /dev/null +++ b/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n from channel\n order by channel.name\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 0 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88" +} diff --git a/.sqlx/query-7fc3094944d5133fd8b2d80aace35b06db0071c5f257b7f71349966bcdadfcb5.json b/.sqlx/query-7fc3094944d5133fd8b2d80aace35b06db0071c5f257b7f71349966bcdadfcb5.json new file mode 100644 index 0000000..b5bc371 --- /dev/null +++ b/.sqlx/query-7fc3094944d5133fd8b2d80aace35b06db0071c5f257b7f71349966bcdadfcb5.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n update event_sequence\n set last_value = last_value + 1\n returning last_value as \"next_value: Sequence\"\n ", + "describe": { + "columns": [ + { + "name": "next_value: Sequence", + "ordinal": 0, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 0 + }, + "nullable": [ + false + ] + }, + "hash": "7fc3094944d5133fd8b2d80aace35b06db0071c5f257b7f71349966bcdadfcb5" +} diff --git a/.sqlx/query-9386cdaa2cb41f5a7e19d2fc8c187294a4661c18c2d820f4379dfd82138a8f77.json b/.sqlx/query-9386cdaa2cb41f5a7e19d2fc8c187294a4661c18c2d820f4379dfd82138a8f77.json new file mode 100644 index 0000000..e9c3967 --- /dev/null +++ b/.sqlx/query-9386cdaa2cb41f5a7e19d2fc8c187294a4661c18c2d820f4379dfd82138a8f77.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n from channel\n where coalesce(created_sequence > $1, true)\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": 
"Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "9386cdaa2cb41f5a7e19d2fc8c187294a4661c18c2d820f4379dfd82138a8f77" +} diff --git a/.sqlx/query-aeafe536f36593bfd1080ee61c4b10c6f90b1221e963db69c8e6d23e99012ecf.json b/.sqlx/query-aeafe536f36593bfd1080ee61c4b10c6f90b1221e963db69c8e6d23e99012ecf.json deleted file mode 100644 index 5c27826..0000000 --- a/.sqlx/query-aeafe536f36593bfd1080ee61c4b10c6f90b1221e963db69c8e6d23e99012ecf.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n insert\n into channel (id, name, created_at, last_sequence)\n values ($1, $2, $3, $4)\n returning\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\"\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - } - ], - "parameters": { - "Right": 4 - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "aeafe536f36593bfd1080ee61c4b10c6f90b1221e963db69c8e6d23e99012ecf" -} diff --git a/.sqlx/query-df3656771c3cb6851e0c54a2d368676f279af866d0840d6c2c093b87b1eadd8c.json b/.sqlx/query-df3656771c3cb6851e0c54a2d368676f279af866d0840d6c2c093b87b1eadd8c.json deleted file mode 100644 index 87e478e..0000000 --- a/.sqlx/query-df3656771c3cb6851e0c54a2d368676f279af866d0840d6c2c093b87b1eadd8c.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n message.id as \"message: message::Id\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where sent_at < $1\n ", - "describe": { - "columns": [ - { - "name": "channel_id: channel::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "channel_created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "message: message::Id", - "ordinal": 3, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "df3656771c3cb6851e0c54a2d368676f279af866d0840d6c2c093b87b1eadd8c" -} diff --git a/.sqlx/query-f6909336ab05b7ad423c7b96a0e7b12a920f9827aff2b05ee0364ff7688a38ae.json b/.sqlx/query-f6909336ab05b7ad423c7b96a0e7b12a920f9827aff2b05ee0364ff7688a38ae.json new file mode 100644 index 0000000..ded48e1 --- /dev/null +++ b/.sqlx/query-f6909336ab05b7ad423c7b96a0e7b12a920f9827aff2b05ee0364ff7688a38ae.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n from channel\n where id = $1\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "f6909336ab05b7ad423c7b96a0e7b12a920f9827aff2b05ee0364ff7688a38ae" +} diff --git a/migrations/20241002003606_global_sequence.sql 
b/migrations/20241002003606_global_sequence.sql new file mode 100644 index 0000000..198b585 --- /dev/null +++ b/migrations/20241002003606_global_sequence.sql @@ -0,0 +1,126 @@ +create table event_sequence ( + last_value bigint + not null +); + +create unique index event_sequence_singleton +on event_sequence (0); + +-- Attempt to assign events sent so far a globally-unique sequence number, +-- maintaining an approximation of the order they were sent in. This can +-- introduce small ordering anomalies (where the resulting sequence differs +-- from the order they were sent in) for events that were sent close in time; +-- I've gone with chronological order here as it's the closest thing we have to +-- a global ordering, and because the results will be intuitive to most users. +create temporary table raw_event ( + type text + not null, + at text + not null, + channel text + unique, + message text + unique, + check ((channel is not null and message is null) or (message is not null and channel is null)) +); + +insert into raw_event (type, at, channel) +select + 'channel' as type, + created_at as at, + id as channel +from channel; + +insert into raw_event (type, at, message) +select + 'message' as type, + sent_at as at, + id as message +from message; + +create temporary table event ( + type text + not null, + sequence + unique + not null, + at text + not null, + channel text + unique, + message text + unique, + check ((channel is not null and message is null) or (message is not null and channel is null)) +); + +insert into event +select + type, + rank() over (order by at) - 1 as sequence, + at, + channel, + message +from raw_event; + +drop table raw_event; + +alter table channel rename to old_channel; +alter table message rename to old_message; + +create table channel ( + id text + not null + primary key, + name text + unique + not null, + created_sequence bigint + unique + not null, + created_at text + not null +); + +insert into channel +select + c.id, + c.name, + e.sequence, + c.created_at +from old_channel as c join event as e + on e.channel = c.id; + +create table message ( + id text + not null + primary key, + channel text + not null + references channel (id), + sender text + not null + references login (id), + sent_sequence bigint + unique + not null, + sent_at text + not null, + body text + not null +); + +insert into message +select + m.id, + m.channel, + m.sender, + e.sequence, + m.sent_at, + m.body +from old_message as m join event as e + on e.message = m.id; + +insert into event_sequence +select coalesce(max(sequence), 0) from event; + +drop table event; diff --git a/src/channel/app.rs b/src/channel/app.rs index 70cda47..88f4170 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -3,8 +3,11 @@ use sqlx::sqlite::SqlitePool; use crate::{ clock::DateTime, - events::{broadcaster::Broadcaster, repo::message::Provider as _, types::ChannelEvent}, - repo::channel::{Channel, Provider as _}, + events::{broadcaster::Broadcaster, types::ChannelEvent}, + repo::{ + channel::{Channel, Provider as _}, + sequence::Provider as _, + }, }; pub struct Channels<'a> { @@ -19,9 +22,10 @@ impl<'a> Channels<'a> { pub async fn create(&self, name: &str, created_at: &DateTime) -> Result { let mut tx = self.db.begin().await?; + let created_sequence = tx.sequence().next().await?; let channel = tx .channels() - .create(name, created_at) + .create(name, created_at, created_sequence) .await .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; @@ -49,10 +53,10 @@ impl<'a> 
Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; + let deleted_sequence = tx.sequence().next().await?; let event = tx .channels() - .delete_expired(&channel, sequence, relative_to) + .delete(&channel, relative_to, deleted_sequence) .await?; events.push(event); } diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index e2610a5..5deb88a 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -38,18 +38,17 @@ async fn new_channel() { let mut events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing never fails") .filter(fixtures::filter::created()); - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await .expect("creation event published"); - assert_eq!(types::Sequence::default(), event.sequence); assert!(matches!( event.data, types::ChannelEventData::Created(event) diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 233518b..d37ed21 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -43,7 +43,7 @@ async fn messages_in_order() { let events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing to a valid channel") .filter(fixtures::filter::messages()) @@ -51,7 +51,7 @@ async fn messages_in_order() { let events = events.collect::>().immediately().await; - for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) { + for ((sent_at, message), event) in requests.into_iter().zip(events) { assert_eq!(*sent_at, event.at); assert!(matches!( event.data, diff --git a/src/events/app.rs b/src/events/app.rs index db7f430..c15f11e 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap; - use chrono::TimeDelta; use futures::{ future, @@ -11,7 +9,7 @@ use sqlx::sqlite::SqlitePool; use super::{ broadcaster::Broadcaster, repo::message::Provider as _, - types::{self, ChannelEvent, ResumePoint}, + types::{self, ChannelEvent}, }; use crate::{ clock::DateTime, @@ -19,6 +17,7 @@ use crate::{ channel::{self, Provider as _}, error::NotFound as _, login::Login, + sequence::{Provider as _, Sequence}, }, }; @@ -45,9 +44,10 @@ impl<'a> Events<'a> { .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let sent_sequence = tx.sequence().next().await?; let event = tx .message_events() - .create(login, &channel, body, sent_at) + .create(login, &channel, sent_at, sent_sequence, body) .await?; tx.commit().await?; @@ -64,10 +64,10 @@ impl<'a> Events<'a> { let mut events = Vec::with_capacity(expired.len()); for (channel, message) in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; + let deleted_sequence = tx.sequence().next().await?; let event = tx .message_events() - .delete_expired(&channel, &message, sequence, relative_to) + .delete(&channel, &message, relative_to, deleted_sequence) .await?; events.push(event); } @@ -83,42 +83,30 @@ impl<'a> Events<'a> { pub async fn subscribe( &self, - resume_at: ResumePoint, - ) -> Result + std::fmt::Debug, sqlx::Error> { - let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; - - let created_events = { - let resume_at = resume_at.clone(); - let channels = channels.clone(); - stream::iter( - channels - 
.into_iter() - .map(ChannelEvent::created) - .filter(move |event| resume_at.not_after(event)), - ) - }; - + resume_at: Option, + ) -> Result + std::fmt::Debug, sqlx::Error> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. let live_messages = self.events.subscribe(); - let mut replays = BTreeMap::new(); - let mut resume_live_at = resume_at.clone(); - for channel in channels { - let replay = tx - .message_events() - .replay(&channel, resume_at.get(&channel.id)) - .await?; + let mut tx = self.db.begin().await?; + let channels = tx.channels().replay(resume_at).await?; - if let Some(last) = replay.last() { - resume_live_at.advance(last); - } + let channel_events = channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); - replays.insert(channel.id.clone(), replay); - } + let message_events = tx.message_events().replay(resume_at).await?; + + let mut replay_events = channel_events + .into_iter() + .chain(message_events.into_iter()) + .collect::>(); + replay_events.sort_by_key(|event| event.sequence); + let resume_live_at = replay_events.last().map(|event| event.sequence); - let replay = stream::select_all(replays.into_values().map(stream::iter)); + let replay = stream::iter(replay_events); // no skip_expired or resume transforms for stored_messages, as it's // constructed not to contain messages meeting either criterion. @@ -132,25 +120,13 @@ impl<'a> Events<'a> { // stored_messages. .filter(Self::resume(resume_live_at)); - Ok(created_events.chain(replay).chain(live_messages).scan( - resume_at, - |resume_point, event| { - match event.data { - types::ChannelEventData::Deleted(_) => resume_point.forget(&event), - _ => resume_point.advance(&event), - } - - let event = types::ResumableEvent(resume_point.clone(), event); - - future::ready(Some(event)) - }, - )) + Ok(replay.chain(live_messages)) } fn resume( - resume_at: ResumePoint, + resume_at: Option, ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { - move |event| future::ready(resume_at.not_after(event)) + move |event| future::ready(resume_at < Some(event.sequence)) } } diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs index f8bae2b..3237553 100644 --- a/src/events/repo/message.rs +++ b/src/events/repo/message.rs @@ -2,11 +2,12 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, - events::types::{self, Sequence}, + events::types, repo::{ channel::{self, Channel}, login::{self, Login}, message::{self, Message}, + sequence::Sequence, }, }; @@ -27,34 +28,33 @@ impl<'c> Events<'c> { &mut self, sender: &Login, channel: &Channel, - body: &str, sent_at: &DateTime, + sent_sequence: Sequence, + body: &str, ) -> Result { - let sequence = self.assign_sequence(channel).await?; - let id = message::Id::generate(); let message = sqlx::query!( r#" insert into message - (id, channel, sequence, sender, body, sent_at) + (id, channel, sender, sent_at, sent_sequence, body) values ($1, $2, $3, $4, $5, $6) returning id as "id: message::Id", - sequence as "sequence: Sequence", sender as "sender: login::Id", - body, - sent_at as "sent_at: DateTime" + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence", + body "#, id, channel.id, - sequence, sender.id, - body, sent_at, + sent_sequence, + body, ) .map(|row| types::ChannelEvent { - sequence: row.sequence, + sequence: row.sent_sequence, at: row.sent_at, data: 
types::MessageEvent { channel: channel.clone(), @@ -72,28 +72,12 @@ impl<'c> Events<'c> { Ok(message) } - pub async fn assign_sequence(&mut self, channel: &Channel) -> Result { - let next = sqlx::query_scalar!( - r#" - update channel - set last_sequence = last_sequence + 1 - where id = $1 - returning last_sequence as "next_sequence: Sequence" - "#, - channel.id, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } - - pub async fn delete_expired( + pub async fn delete( &mut self, channel: &Channel, message: &message::Id, - sequence: Sequence, deleted_at: &DateTime, + deleted_sequence: Sequence, ) -> Result { sqlx::query_scalar!( r#" @@ -107,7 +91,7 @@ impl<'c> Events<'c> { .await?; Ok(types::ChannelEvent { - sequence, + sequence: deleted_sequence, at: *deleted_at, data: types::MessageDeletedEvent { channel: channel.clone(), @@ -127,6 +111,7 @@ impl<'c> Events<'c> { channel.id as "channel_id: channel::Id", channel.name as "channel_name", channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", message.id as "message: message::Id" from message join channel on message.channel = channel.id @@ -141,6 +126,7 @@ impl<'c> Events<'c> { id: row.channel_id, name: row.channel_name, created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, }, row.message, ) @@ -153,32 +139,39 @@ impl<'c> Events<'c> { pub async fn replay( &mut self, - channel: &Channel, resume_at: Option, ) -> Result, sqlx::Error> { let events = sqlx::query!( r#" select message.id as "id: message::Id", - sequence as "sequence: Sequence", - login.id as "sender_id: login::Id", - login.name as sender_name, - message.body, - message.sent_at as "sent_at: DateTime" + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + sender.id as "sender_id: login::Id", + sender.name as sender_name, + message.sent_at as "sent_at: DateTime", + message.sent_sequence as "sent_sequence: Sequence", + message.body from message - join login on message.sender = login.id - where channel = $1 - and coalesce(sequence > $2, true) - order by sequence asc + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + order by sent_sequence asc "#, - channel.id, resume_at, ) .map(|row| types::ChannelEvent { - sequence: row.sequence, + sequence: row.sent_sequence, at: row.sent_at, data: types::MessageEvent { - channel: channel.clone(), + channel: Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, sender: login::Login { id: row.sender_id, name: row.sender_name, diff --git a/src/events/routes.rs b/src/events/routes.rs index f09474c..e3a959f 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -9,14 +9,12 @@ use axum::{ }; use futures::stream::{Stream, StreamExt as _}; -use super::{ - extract::LastEventId, - types::{self, ResumePoint}, -}; +use super::{extract::LastEventId, types}; use crate::{ app::App, error::{Internal, Unauthorized}, login::{app::ValidateError, extract::Identity}, + repo::sequence::Sequence, }; #[cfg(test)] @@ -29,11 +27,9 @@ pub fn router() -> Router { async fn events( State(app): State, identity: Identity, - last_event_id: Option>, -) -> Result + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id - 
.map(LastEventId::into_inner) - .unwrap_or_default(); + last_event_id: Option>, +) -> Result + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id.map(LastEventId::into_inner); let stream = app.events().subscribe(resume_at).await?; let stream = app.logins().limit_stream(identity.token, stream).await?; @@ -46,7 +42,7 @@ struct Events(S); impl IntoResponse for Events where - S: Stream + Send + 'static, + S: Stream + Send + 'static, { fn into_response(self) -> Response { let Self(stream) = self; @@ -57,14 +53,12 @@ where } } -impl TryFrom for sse::Event { +impl TryFrom for sse::Event { type Error = serde_json::Error; - fn try_from(value: types::ResumableEvent) -> Result { - let types::ResumableEvent(resume_at, data) = value; - - let id = serde_json::to_string(&resume_at)?; - let data = serde_json::to_string_pretty(&data)?; + fn try_from(event: types::ChannelEvent) -> Result { + let id = serde_json::to_string(&event.sequence)?; + let data = serde_json::to_string_pretty(&event)?; let event = Self::default().id(id).data(data); diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 820192d..1cfca4f 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -5,7 +5,7 @@ use futures::{ }; use crate::{ - events::{routes, types}, + events::routes, test::fixtures::{self, future::Immediately as _}, }; @@ -28,7 +28,7 @@ async fn includes_historical_message() { // Verify the structure of the response. - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -58,7 +58,7 @@ async fn includes_live_message() { let sender = fixtures::login::create(&app).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -108,9 +108,7 @@ async fn includes_multiple_channels() { .await; for message in &messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| { event == message })); + assert!(events.iter().any(|event| { event == message })); } } @@ -138,12 +136,11 @@ async fn sequential_messages() { // Verify the structure of the response. 
- let mut events = - events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))); + let mut events = events.filter(|event| future::ready(messages.contains(event))); // Verify delivery in order for message in &messages { - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await @@ -179,7 +176,7 @@ async fn resumes_from() { .await .expect("subscribe never fails"); - let types::ResumableEvent(last_event_id, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() @@ -188,7 +185,7 @@ async fn resumes_from() { assert_eq!(initial_message, event); - last_event_id + event.sequence }; // Resume after disconnect @@ -205,9 +202,7 @@ async fn resumes_from() { .await; for message in &later_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } } @@ -259,14 +254,12 @@ async fn serial_resume() { .await; for message in &initial_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence }; // Resume after disconnect @@ -296,14 +289,12 @@ async fn serial_resume() { .await; for message in &resume_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence }; // Resume after disconnect a second time @@ -335,9 +326,7 @@ async fn serial_resume() { // This set of messages, in particular, _should not_ include any prior // messages from `initial_messages` or `resume_messages`. for message in &final_messages { - assert!(events - .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + assert!(events.iter().any(|event| event == message)); } }; } @@ -375,7 +364,7 @@ async fn terminates_on_token_expiry() { ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready(messages.contains(event))) .next() .immediately() .await @@ -417,7 +406,7 @@ async fn terminates_on_logout() { ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready(messages.contains(event))) .next() .immediately() .await diff --git a/src/events/types.rs b/src/events/types.rs index d954512..aca3af4 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -1,84 +1,13 @@ -use std::collections::BTreeMap; - use crate::{ clock::DateTime, repo::{ channel::{self, Channel}, login::Login, message, + sequence::Sequence, }, }; -#[derive( - Debug, - Default, - Eq, - Ord, - PartialEq, - PartialOrd, - Clone, - Copy, - serde::Serialize, - serde::Deserialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); - -impl Sequence { - pub fn next(self) -> Self { - let Self(current) = self; - Self(current + 1) - } -} - -// For the purposes of event replay, a resume point is a vector of resume -// elements. A resume element associates a channel (by ID) with the latest event -// seen in that channel so far. 
Replaying the event stream can restart at a -// predictable point - hence the name. These values can be serialized and sent -// to the client as JSON dicts, then rehydrated to recover the resume point at a -// later time. -// -// Using a sorted map ensures that there is a canonical representation for -// each resume point. -#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)] -#[serde(transparent)] -pub struct ResumePoint(BTreeMap); - -impl ResumePoint { - pub fn advance<'e>(&mut self, event: impl Into>) { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - elements.insert(channel.clone(), sequence); - } - - pub fn forget<'e>(&mut self, event: impl Into>) { - let Self(elements) = self; - let ResumeElement(channel, _) = event.into(); - elements.remove(channel); - } - - pub fn get(&self, channel: &channel::Id) -> Option { - let Self(elements) = self; - elements.get(channel).copied() - } - - pub fn not_after<'e>(&self, event: impl Into>) -> bool { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - - elements - .get(channel) - .map_or(true, |resume_at| resume_at < &sequence) - } -} - -pub struct ResumeElement<'i>(&'i channel::Id, Sequence); - -#[derive(Clone, Debug)] -pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent); - #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct ChannelEvent { #[serde(skip)] @@ -92,7 +21,7 @@ impl ChannelEvent { pub fn created(channel: Channel) -> Self { Self { at: channel.created_at, - sequence: Sequence::default(), + sequence: channel.created_sequence, data: CreatedEvent { channel }.into(), } } @@ -107,9 +36,9 @@ impl ChannelEvent { } } -impl<'c> From<&'c ChannelEvent> for ResumeElement<'c> { +impl<'c> From<&'c ChannelEvent> for Sequence { fn from(event: &'c ChannelEvent) -> Self { - Self(event.channel_id(), event.sequence) + event.sequence } } diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 3c7468f..efc2ced 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -2,9 +2,10 @@ use std::fmt; use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; +use super::sequence::Sequence; use crate::{ clock::DateTime, - events::types::{self, Sequence}, + events::types::{self}, id::Id as BaseId, }; @@ -26,6 +27,8 @@ pub struct Channel { pub name: String, #[serde(skip)] pub created_at: DateTime, + #[serde(skip)] + pub created_sequence: Sequence, } impl<'c> Channels<'c> { @@ -33,25 +36,25 @@ impl<'c> Channels<'c> { &mut self, name: &str, created_at: &DateTime, + created_sequence: Sequence, ) -> Result { let id = Id::generate(); - let sequence = Sequence::default(); - let channel = sqlx::query_as!( Channel, r#" insert - into channel (id, name, created_at, last_sequence) + into channel (id, name, created_at, created_sequence) values ($1, $2, $3, $4) returning id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" "#, id, name, created_at, - sequence, + created_sequence, ) .fetch_one(&mut *self.0) .await?; @@ -66,7 +69,8 @@ impl<'c> Channels<'c> { select id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" from channel where id = $1 "#, @@ -85,7 +89,8 @@ impl<'c> Channels<'c> { select id as "id: Id", name, - created_at as "created_at: DateTime" + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" 
from channel order by channel.name "#, @@ -96,11 +101,34 @@ impl<'c> Channels<'c> { Ok(channels) } - pub async fn delete_expired( + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence > $1, true) + "#, + resume_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn delete( &mut self, channel: &Channel, - sequence: Sequence, deleted_at: &DateTime, + deleted_sequence: Sequence, ) -> Result { let channel = channel.id.clone(); sqlx::query_scalar!( @@ -115,7 +143,7 @@ impl<'c> Channels<'c> { .await?; Ok(types::ChannelEvent { - sequence, + sequence: deleted_sequence, at: *deleted_at, data: types::DeletedEvent { channel }.into(), }) @@ -128,7 +156,8 @@ impl<'c> Channels<'c> { select channel.id as "id: Id", channel.name, - channel.created_at as "created_at: DateTime" + channel.created_at as "created_at: DateTime", + channel.created_sequence as "created_sequence: Sequence" from channel left join message where created_at < $1 diff --git a/src/repo/mod.rs b/src/repo/mod.rs index cb9d7c8..8f271f4 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -3,4 +3,5 @@ pub mod error; pub mod login; pub mod message; pub mod pool; +pub mod sequence; pub mod token; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs new file mode 100644 index 0000000..8fe9dab --- /dev/null +++ b/src/repo/sequence.rs @@ -0,0 +1,45 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index fbebced..c31fa58 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -2,14 +2,10 @@ use futures::future; use crate::events::types; -pub fn messages() -> impl FnMut(&types::ResumableEvent) -> future::Ready { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Message(_))) - } +pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready { + |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_))) } -pub fn created() -> impl FnMut(&types::ResumableEvent) -> future::Ready { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Created(_))) - } +pub fn created() -> impl FnMut(&types::ChannelEvent) -> future::Ready { + |event| future::ready(matches!(event.data, types::ChannelEventData::Created(_))) } -- cgit v1.2.3 From d171a258ad2119e39cb715f8800031fff16967dc Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 22:43:18 -0400 Subject: Provide a resume point to bridge clients from state snapshots 
to the event sequence. --- ...a13fa4f719d82d465e4525557698914a661d39cdb4.json | 20 +++++++ ...8a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json | 38 ------------ ...8143f6b5d16dbeb19ad13ac36dcb40851f0af238e8.json | 38 ++++++++++++ docs/api.md | 13 ++++- src/channel/app.rs | 6 +- src/channel/routes.rs | 15 ++++- src/channel/routes/test/list.rs | 7 ++- src/channel/routes/test/on_create.rs | 2 +- src/events/routes.rs | 11 +++- src/events/routes/test.rs | 67 +++++++++++++++------- src/login/app.rs | 9 +++ src/login/routes.rs | 9 ++- src/login/routes/test/boot.rs | 7 ++- src/repo/channel.rs | 7 ++- src/repo/sequence.rs | 52 ++++++++++++----- 15 files changed, 211 insertions(+), 90 deletions(-) create mode 100644 .sqlx/query-566ee1b8e4e66e78b28675a13fa4f719d82d465e4525557698914a661d39cdb4.json delete mode 100644 .sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json create mode 100644 .sqlx/query-cda3a4a974eb986ebe26838143f6b5d16dbeb19ad13ac36dcb40851f0af238e8.json diff --git a/.sqlx/query-566ee1b8e4e66e78b28675a13fa4f719d82d465e4525557698914a661d39cdb4.json b/.sqlx/query-566ee1b8e4e66e78b28675a13fa4f719d82d465e4525557698914a661d39cdb4.json new file mode 100644 index 0000000..8d2fc72 --- /dev/null +++ b/.sqlx/query-566ee1b8e4e66e78b28675a13fa4f719d82d465e4525557698914a661d39cdb4.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select last_value as \"last_value: Sequence\"\n from event_sequence\n ", + "describe": { + "columns": [ + { + "name": "last_value: Sequence", + "ordinal": 0, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 0 + }, + "nullable": [ + false + ] + }, + "hash": "566ee1b8e4e66e78b28675a13fa4f719d82d465e4525557698914a661d39cdb4" +} diff --git a/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json b/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json deleted file mode 100644 index 3cc33cf..0000000 --- a/.sqlx/query-7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n from channel\n order by channel.name\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 0 - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "7f6b9c7d4ef3f540d594318a7a66fa8f9e3ddcf6d041be8d834db58f66a5aa88" -} diff --git a/.sqlx/query-cda3a4a974eb986ebe26838143f6b5d16dbeb19ad13ac36dcb40851f0af238e8.json b/.sqlx/query-cda3a4a974eb986ebe26838143f6b5d16dbeb19ad13ac36dcb40851f0af238e8.json new file mode 100644 index 0000000..bce6a88 --- /dev/null +++ b/.sqlx/query-cda3a4a974eb986ebe26838143f6b5d16dbeb19ad13ac36dcb40851f0af238e8.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n select\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n from channel\n where coalesce(created_sequence <= $1, true)\n order by channel.name\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": 
"Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "cda3a4a974eb986ebe26838143f6b5d16dbeb19ad13ac36dcb40851f0af238e8" +} diff --git a/docs/api.md b/docs/api.md index e18c6d5..5adf28d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -23,7 +23,8 @@ Returns information needed to boot the client. Also the recommended way to check "login": { "name": "example username", "id": "L1234abcd", - } + }, + "resume_point": "1312", } ``` @@ -80,6 +81,10 @@ Channels are the containers for conversations. The API supports listing channels Lists channels. +#### Query parameters + +This endpoint accepts an optional `resume_point` query parameter. If provided, the value must be the value obtained from the `/api/boot` method. This parameter will restrict the returned list to channels as they existed at a fixed point in time, with any later changes only appearing in the event stream. + #### On success Responds with a list of channel objects, one per channel: @@ -152,9 +157,13 @@ Subscribes to events. This endpoint returns an `application/event-stream` respon The returned stream may terminate, to limit the number of outstanding messages held by the server. Clients can and should repeat the request, using the `Last-Event-Id` header to resume from where they left off. Events will be replayed from that point, and the stream will resume. +#### Query parameters + +This endpoint accepts an optional `resume_point` query parameter. If provided, the value must be the value obtained from the `/api/boot` method. This parameter start the returned stream immediately after the `resume_point`. + #### Request headers -This endpoint accepts an optional `Last-Event-Id` header for resuming an interrupted stream. If this header is provided, it must be set to the `id` field sent with the last event the client has processed. When `Last-Event-Id` is sent, the response will resume immediately after the corresponding event. If this header is omitted, then the stream will start from the beginning. +This endpoint accepts an optional `Last-Event-Id` header for resuming an interrupted stream. If this header is provided, it must be set to the `id` field sent with the last event the client has processed. When `Last-Event-Id` is sent, the response will resume immediately after the corresponding event. This header takes precedence over the `resume_point` query parameter; if neither is provided, then event playback starts at the beginning of time (_you have been warned_). If you're using a browser's `EventSource` API, this is handled for you automatically. 
diff --git a/src/channel/app.rs b/src/channel/app.rs index 88f4170..d89e733 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -6,7 +6,7 @@ use crate::{ events::{broadcaster::Broadcaster, types::ChannelEvent}, repo::{ channel::{Channel, Provider as _}, - sequence::Provider as _, + sequence::{Provider as _, Sequence}, }, }; @@ -36,9 +36,9 @@ impl<'a> Channels<'a> { Ok(channel) } - pub async fn all(&self) -> Result, InternalError> { + pub async fn all(&self, resume_point: Option) -> Result, InternalError> { let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; + let channels = tx.channels().all(resume_point).await?; tx.commit().await?; Ok(channels) diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 1f8db5a..067d213 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -5,6 +5,7 @@ use axum::{ routing::{get, post}, Router, }; +use axum_extra::extract::Query; use super::app; use crate::{ @@ -15,6 +16,7 @@ use crate::{ repo::{ channel::{self, Channel}, login::Login, + sequence::Sequence, }, }; @@ -28,8 +30,17 @@ pub fn router() -> Router { .route("/api/channels/:channel", post(on_send)) } -async fn list(State(app): State, _: Login) -> Result { - let channels = app.channels().all().await?; +#[derive(Default, serde::Deserialize)] +struct ListQuery { + resume_point: Option, +} + +async fn list( + State(app): State, + _: Login, + Query(query): Query, +) -> Result { + let channels = app.channels().all(query.resume_point).await?; let response = Channels(channels); Ok(response) diff --git a/src/channel/routes/test/list.rs b/src/channel/routes/test/list.rs index bc94024..f15a53c 100644 --- a/src/channel/routes/test/list.rs +++ b/src/channel/routes/test/list.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use axum_extra::extract::Query; use crate::{channel::routes, test::fixtures}; @@ -11,7 +12,7 @@ async fn empty_list() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -30,7 +31,7 @@ async fn one_channel() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -52,7 +53,7 @@ async fn multiple_channels() { // Call the endpoint - let routes::Channels(response_channels) = routes::list(State(app), viewer) + let routes::Channels(response_channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 5deb88a..72980ac 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -33,7 +33,7 @@ async fn new_channel() { // Verify the semantics - let channels = app.channels().all().await.expect("always succeeds"); + let channels = app.channels().all(None).await.expect("always succeeds"); assert!(channels.contains(&response_channel)); let mut events = app diff --git a/src/events/routes.rs b/src/events/routes.rs index e3a959f..d81c7fb 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -7,6 +7,7 @@ use axum::{ routing::get, Router, }; +use axum_extra::extract::Query; use futures::stream::{Stream, StreamExt as _}; use super::{extract::LastEventId, types}; @@ -24,12 +25,20 @@ pub fn router() -> Router { Router::new().route("/api/events", get(events)) } 
+#[derive(Default, serde::Deserialize)] +struct EventsQuery { + resume_point: Option, +} + async fn events( State(app): State, identity: Identity, last_event_id: Option>, + Query(query): Query, ) -> Result + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id.map(LastEventId::into_inner); + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); let stream = app.events().subscribe(resume_at).await?; let stream = app.logins().limit_stream(identity.token, stream).await?; diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 1cfca4f..11f01b8 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use axum_extra::extract::Query; use futures::{ future, stream::{self, StreamExt as _}, @@ -22,7 +23,7 @@ async fn includes_historical_message() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -49,9 +50,10 @@ async fn includes_live_message() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the semantics @@ -94,7 +96,7 @@ async fn includes_multiple_channels() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -130,7 +132,7 @@ async fn sequential_messages() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -172,9 +174,14 @@ async fn resumes_from() { let resume_at = { // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); let event = events .filter(fixtures::filter::messages()) @@ -189,9 +196,14 @@ async fn resumes_from() { }; // Resume after disconnect - let routes::Events(resumed) = routes::events(State(app), subscriber, Some(resume_at.into())) - .await - .expect("subscribe never fails"); + let routes::Events(resumed) = routes::events( + State(app), + subscriber, + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the structure of the response. 
@@ -242,9 +254,14 @@ async fn serial_resume() { ]; // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); let events = events .filter(fixtures::filter::messages()) @@ -277,6 +294,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -312,6 +330,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -345,9 +364,10 @@ async fn terminates_on_token_expiry() { let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour @@ -387,9 +407,14 @@ async fn terminates_on_logout() { let subscriber = fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour diff --git a/src/login/app.rs b/src/login/app.rs index 95f0a07..f1dffb9 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -13,6 +13,7 @@ use crate::{ repo::{ error::NotFound as _, login::{Login, Provider as _}, + sequence::{Provider as _, Sequence}, token::{self, Provider as _}, }, }; @@ -27,6 +28,14 @@ impl<'a> Logins<'a> { Self { db, logins } } + pub async fn boot_point(&self) -> Result { + let mut tx = self.db.begin().await?; + let sequence = tx.sequence().current().await?; + tx.commit().await?; + + Ok(sequence) + } + pub async fn login( &self, name: &str, diff --git a/src/login/routes.rs b/src/login/routes.rs index d7cb9b1..ef75871 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -26,13 +26,18 @@ pub fn router() -> Router { .route("/api/auth/logout", post(on_logout)) } -async fn boot(login: Login) -> Boot { - Boot { login } +async fn boot(State(app): State, login: Login) -> Result { + let resume_point = app.logins().boot_point().await?; + Ok(Boot { + login, + resume_point: resume_point.to_string(), + }) } #[derive(serde::Serialize)] struct Boot { login: Login, + resume_point: String, } impl IntoResponse for Boot { diff --git a/src/login/routes/test/boot.rs b/src/login/routes/test/boot.rs index dee554f..9655354 100644 --- a/src/login/routes/test/boot.rs +++ b/src/login/routes/test/boot.rs @@ -1,9 +1,14 @@ +use axum::extract::State; + use crate::{login::routes, test::fixtures}; #[tokio::test] async fn returns_identity() { + let app = fixtures::scratch_app().await; let login = fixtures::login::fictitious(); - let response = routes::boot(login.clone()).await; + let response = routes::boot(State(app), login.clone()) + .await + .expect("boot always succeeds"); assert_eq!(login, response.login); } diff --git a/src/repo/channel.rs 
b/src/repo/channel.rs index efc2ced..ad42710 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -82,7 +82,10 @@ impl<'c> Channels<'c> { Ok(channel) } - pub async fn all(&mut self) -> Result, sqlx::Error> { + pub async fn all( + &mut self, + resume_point: Option, + ) -> Result, sqlx::Error> { let channels = sqlx::query_as!( Channel, r#" @@ -92,8 +95,10 @@ impl<'c> Channels<'c> { created_at as "created_at: DateTime", created_sequence as "created_sequence: Sequence" from channel + where coalesce(created_sequence <= $1, true) order by channel.name "#, + resume_point, ) .fetch_all(&mut *self.0) .await?; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs index 8fe9dab..c47b41c 100644 --- a/src/repo/sequence.rs +++ b/src/repo/sequence.rs @@ -1,3 +1,5 @@ +use std::fmt; + use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; pub trait Provider { @@ -10,6 +12,37 @@ impl<'c> Provider for Transaction<'c, Sqlite> { } } +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } + + pub async fn current(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} + #[derive( Clone, Copy, @@ -26,20 +59,9 @@ impl<'c> Provider for Transaction<'c, Sqlite> { #[sqlx(transparent)] pub struct Sequence(i64); -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) } } -- cgit v1.2.3 From f878f0b5eaa44e8ee8d67cbfd706926ff2119113 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 23:57:22 -0400 Subject: Organize IDs into top-level namespaces. (This is part of a larger reorganization.) 
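
The relocated IDs all share one shape; only the module and the prefix
differ. A condensed, self-contained sketch of that pattern follows, with
`BaseId` as a simplified mock of the real type in src/id.rs and
`ChannelId` playing the role of `channel::Id` (both the mock body and the
names are illustrative, not the actual implementation):

```rust
use std::fmt;

// Simplified stand-in for src/id.rs: the real BaseId generates a random,
// prefix-tagged identifier. A fixed suffix keeps this sketch self-contained.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BaseId(String);

impl BaseId {
    pub fn generate(prefix: &str) -> Self {
        Self(format!("{prefix}1234abcd"))
    }
}

impl fmt::Display for BaseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

// The per-entity newtype: the same for channels (`C`), logins (`L`),
// messages (`M`), and tokens (`T`); each lives in its own namespace.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ChannelId(BaseId);

impl ChannelId {
    pub fn generate() -> Self {
        Self(BaseId::generate("C"))
    }
}

impl fmt::Display for ChannelId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

fn main() {
    // IDs render with their namespace prefix, e.g. "C1234abcd".
    assert!(ChannelId::generate().to_string().starts_with('C'));
}
```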
--- src/channel/id.rs | 38 +++++++++++++++++++++++++++++++++++++ src/channel/mod.rs | 3 +++ src/channel/routes.rs | 7 ++----- src/channel/routes/test/on_send.rs | 2 +- src/events/app.rs | 3 ++- src/events/repo/message.rs | 11 ++++------- src/events/types.rs | 11 ++++------- src/lib.rs | 1 + src/login/app.rs | 6 ++++-- src/login/extract.rs | 4 ++-- src/login/id.rs | 24 +++++++++++++++++++++++ src/login/mod.rs | 7 +++++-- src/login/repo/auth.rs | 5 +---- src/login/token/id.rs | 27 ++++++++++++++++++++++++++ src/login/token/mod.rs | 3 +++ src/login/types.rs | 2 +- src/message/id.rs | 27 ++++++++++++++++++++++++++ src/message/mod.rs | 3 +++ src/repo/channel.rs | 39 +------------------------------------- src/repo/login/mod.rs | 2 +- src/repo/login/store.rs | 25 +----------------------- src/repo/message.rs | 28 +-------------------------- src/repo/token.rs | 33 +++++--------------------------- 23 files changed, 161 insertions(+), 150 deletions(-) create mode 100644 src/channel/id.rs create mode 100644 src/login/id.rs create mode 100644 src/login/token/id.rs create mode 100644 src/login/token/mod.rs create mode 100644 src/message/id.rs create mode 100644 src/message/mod.rs diff --git a/src/channel/id.rs b/src/channel/id.rs new file mode 100644 index 0000000..22a2700 --- /dev/null +++ b/src/channel/id.rs @@ -0,0 +1,38 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a [Channel]. Prefixed with `C`. +#[derive( + Clone, + Debug, + Eq, + Hash, + Ord, + PartialEq, + PartialOrd, + sqlx::Type, + serde::Deserialize, + serde::Serialize, +)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("C") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 9f79dbb..3115e98 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,7 @@ pub mod app; +mod id; mod routes; pub use self::routes::router; + +pub use self::id::Id; diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 067d213..72d6195 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -10,14 +10,11 @@ use axum_extra::extract::Query; use super::app; use crate::{ app::App, + channel, clock::RequestedAt, error::Internal, events::app::EventsError, - repo::{ - channel::{self, Channel}, - login::Login, - sequence::Sequence, - }, + repo::{channel::Channel, login::Login, sequence::Sequence}, }; #[cfg(test)] diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index d37ed21..987784d 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -2,9 +2,9 @@ use axum::extract::{Json, Path, State}; use futures::stream::StreamExt; use crate::{ + channel, channel::routes, events::{app, types}, - repo::channel, test::fixtures::{self, future::Immediately as _}, }; diff --git a/src/events/app.rs b/src/events/app.rs index c15f11e..1fa2f70 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -12,9 +12,10 @@ use super::{ types::{self, ChannelEvent}, }; use crate::{ + channel, clock::DateTime, repo::{ - channel::{self, Provider as _}, + channel::Provider as _, error::NotFound as _, login::Login, sequence::{Provider as _, Sequence}, diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs index 3237553..00c24b1 100644 --- a/src/events/repo/message.rs +++ 
b/src/events/repo/message.rs @@ -1,14 +1,11 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ + channel, clock::DateTime, events::types, - repo::{ - channel::{self, Channel}, - login::{self, Login}, - message::{self, Message}, - sequence::Sequence, - }, + login, message, + repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, }; pub trait Provider { @@ -172,7 +169,7 @@ impl<'c> Events<'c> { created_at: row.channel_created_at, created_sequence: row.channel_created_sequence, }, - sender: login::Login { + sender: Login { id: row.sender_id, name: row.sender_name, }, diff --git a/src/events/types.rs b/src/events/types.rs index aca3af4..762b6e5 100644 --- a/src/events/types.rs +++ b/src/events/types.rs @@ -1,11 +1,8 @@ use crate::{ + channel, clock::DateTime, - repo::{ - channel::{self, Channel}, - login::Login, - message, - sequence::Sequence, - }, + message, + repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, }; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] @@ -66,7 +63,7 @@ impl From for ChannelEventData { pub struct MessageEvent { pub channel: Channel, pub sender: Login, - pub message: message::Message, + pub message: Message, } impl From for ChannelEventData { diff --git a/src/lib.rs b/src/lib.rs index 271118b..2300071 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ mod events; mod expire; mod id; mod login; +mod message; mod password; mod repo; #[cfg(test)] diff --git a/src/login/app.rs b/src/login/app.rs index f1dffb9..8ea0a91 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -6,7 +6,9 @@ use futures::{ }; use sqlx::sqlite::SqlitePool; -use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types}; +use super::{ + broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, token, types, +}; use crate::{ clock::DateTime, password::Password, @@ -14,7 +16,7 @@ use crate::{ error::NotFound as _, login::{Login, Provider as _}, sequence::{Provider as _, Sequence}, - token::{self, Provider as _}, + token::Provider as _, }, }; diff --git a/src/login/extract.rs b/src/login/extract.rs index bfdbe8d..39dd9e4 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -11,8 +11,8 @@ use crate::{ app::App, clock::RequestedAt, error::{Internal, Unauthorized}, - login::app::ValidateError, - repo::{login::Login, token}, + login::{app::ValidateError, token}, + repo::login::Login, }; // The usage pattern here - receive the extractor as an argument, return it in diff --git a/src/login/id.rs b/src/login/id.rs new file mode 100644 index 0000000..c46d697 --- /dev/null +++ b/src/login/id.rs @@ -0,0 +1,24 @@ +use crate::id::Id as BaseId; + +// Stable identifier for a [Login]. Prefixed with `L`. 
+#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)] +#[sqlx(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("L") + } +} + +impl std::fmt::Display for Id { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/login/mod.rs b/src/login/mod.rs index 6ae82ac..0430f4b 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,8 +1,11 @@ -pub use self::routes::router; - pub mod app; pub mod broadcaster; pub mod extract; +mod id; mod repo; mod routes; +pub mod token; pub mod types; + +pub use self::id::Id; +pub use self::routes::router; diff --git a/src/login/repo/auth.rs b/src/login/repo/auth.rs index 3033c8f..9816c5c 100644 --- a/src/login/repo/auth.rs +++ b/src/login/repo/auth.rs @@ -1,9 +1,6 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::{ - password::StoredHash, - repo::login::{self, Login}, -}; +use crate::{login, password::StoredHash, repo::login::Login}; pub trait Provider { fn auth(&mut self) -> Auth; diff --git a/src/login/token/id.rs b/src/login/token/id.rs new file mode 100644 index 0000000..9ef063c --- /dev/null +++ b/src/login/token/id.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a token. Prefixed with `T`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("T") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/login/token/mod.rs b/src/login/token/mod.rs new file mode 100644 index 0000000..d563a88 --- /dev/null +++ b/src/login/token/mod.rs @@ -0,0 +1,3 @@ +mod id; + +pub use self::id::Id; diff --git a/src/login/types.rs b/src/login/types.rs index 7c7cbf9..a210977 100644 --- a/src/login/types.rs +++ b/src/login/types.rs @@ -1,4 +1,4 @@ -use crate::repo::token; +use crate::login::token; #[derive(Clone, Debug)] pub struct TokenRevoked { diff --git a/src/message/id.rs b/src/message/id.rs new file mode 100644 index 0000000..385b103 --- /dev/null +++ b/src/message/id.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a [Message]. Prefixed with `M`. 
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("M") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/message/mod.rs b/src/message/mod.rs new file mode 100644 index 0000000..d563a88 --- /dev/null +++ b/src/message/mod.rs @@ -0,0 +1,3 @@ +mod id; + +pub use self::id::Id; diff --git a/src/repo/channel.rs b/src/repo/channel.rs index ad42710..9f1d930 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -1,12 +1,10 @@ -use std::fmt; - use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use super::sequence::Sequence; use crate::{ + channel::Id, clock::DateTime, events::types::{self}, - id::Id as BaseId, }; pub trait Provider { @@ -176,38 +174,3 @@ impl<'c> Channels<'c> { Ok(channels) } } - -// Stable identifier for a [Channel]. Prefixed with `C`. -#[derive( - Clone, - Debug, - Eq, - Hash, - Ord, - PartialEq, - PartialOrd, - sqlx::Type, - serde::Deserialize, - serde::Serialize, -)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("C") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs index a1b4c6f..4ff7a96 100644 --- a/src/repo/login/mod.rs +++ b/src/repo/login/mod.rs @@ -1,4 +1,4 @@ mod extract; mod store; -pub use self::store::{Id, Login, Provider}; +pub use self::store::{Login, Provider}; diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs index b485941..47d1a7c 100644 --- a/src/repo/login/store.rs +++ b/src/repo/login/store.rs @@ -1,6 +1,6 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::{id::Id as BaseId, password::StoredHash}; +use crate::{login::Id, password::StoredHash}; pub trait Provider { fn logins(&mut self) -> Logins; @@ -61,26 +61,3 @@ impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { Self(tx) } } - -// Stable identifier for a [Login]. Prefixed with `L`. -#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)] -#[sqlx(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("L") - } -} - -impl std::fmt::Display for Id { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/repo/message.rs b/src/repo/message.rs index a1f73d5..acde3ea 100644 --- a/src/repo/message.rs +++ b/src/repo/message.rs @@ -1,30 +1,4 @@ -use std::fmt; - -use crate::id::Id as BaseId; - -// Stable identifier for a [Message]. Prefixed with `M`. 
-#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("M") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} +use crate::message::Id; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Message { diff --git a/src/repo/token.rs b/src/repo/token.rs index 1663f5e..79e5c54 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -1,10 +1,11 @@ -use std::fmt; - use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; -use super::login::{self, Login}; -use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret}; +use super::login::Login; +use crate::{ + clock::DateTime, + login::{self, extract::IdentitySecret, token::Id}, +}; pub trait Provider { fn tokens(&mut self) -> Tokens; @@ -148,27 +149,3 @@ impl<'c> Tokens<'c> { Ok(login) } } - -// Stable identifier for a token. Prefixed with `T`. -#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("T") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} -- cgit v1.2.3 From 357116366c1307bedaac6a3dfe9c5ed8e0e0c210 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 00:41:25 -0400 Subject: First pass on reorganizing the backend. This is primarily renames and repackagings. 
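
For orientation while reading the diffstat, this is the top-level layout
the pass converges on, read off the file moves below; the annotations are
interpretive glosses, not text from the source:

    src/
      channel/   Channel type, channel::Id, channel app logic and routes
      event/     renamed from events/: app, broadcaster, repo, routes, Sequence
      login/     logins and passwords (src/password.rs -> src/login/password.rs)
      message/   message::Id and the Message type
      token/     token::Id, token secrets, and the identity extractors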
--- ...4db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json | 20 + ...884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json | 20 - src/app.rs | 2 +- src/channel/app.rs | 9 +- src/channel/mod.rs | 14 +- src/channel/routes.rs | 6 +- src/channel/routes/test/on_create.rs | 2 +- src/channel/routes/test/on_send.rs | 2 +- src/cli.rs | 4 +- src/event/app.rs | 137 +++++++ src/event/broadcaster.rs | 3 + src/event/extract.rs | 85 ++++ src/event/mod.rs | 9 + src/event/repo/message.rs | 188 +++++++++ src/event/repo/mod.rs | 1 + src/event/routes.rs | 93 +++++ src/event/routes/test.rs | 439 +++++++++++++++++++++ src/event/sequence.rs | 24 ++ src/event/types.rs | 97 +++++ src/events/app.rs | 140 ------- src/events/broadcaster.rs | 3 - src/events/extract.rs | 85 ---- src/events/mod.rs | 8 - src/events/repo/message.rs | 188 --------- src/events/repo/mod.rs | 1 - src/events/routes.rs | 92 ----- src/events/routes/test.rs | 439 --------------------- src/events/types.rs | 96 ----- src/lib.rs | 4 +- src/login/app.rs | 17 +- src/login/extract.rs | 181 +-------- src/login/mod.rs | 18 +- src/login/password.rs | 58 +++ src/login/repo/auth.rs | 2 +- src/login/routes.rs | 6 +- src/login/token/id.rs | 27 -- src/login/token/mod.rs | 3 - src/login/types.rs | 2 +- src/message/mod.rs | 6 + src/password.rs | 58 --- src/repo/channel.rs | 15 +- src/repo/login.rs | 50 +++ src/repo/login/extract.rs | 15 - src/repo/login/mod.rs | 4 - src/repo/login/store.rs | 63 --- src/repo/message.rs | 7 - src/repo/mod.rs | 1 - src/repo/sequence.rs | 27 +- src/repo/token.rs | 10 +- src/test/fixtures/channel.rs | 2 +- src/test/fixtures/filter.rs | 2 +- src/test/fixtures/identity.rs | 9 +- src/test/fixtures/login.rs | 3 +- src/test/fixtures/message.rs | 7 +- src/token/extract/identity.rs | 75 ++++ src/token/extract/identity_token.rs | 94 +++++ src/token/extract/mod.rs | 4 + src/token/id.rs | 27 ++ src/token/mod.rs | 5 + src/token/secret.rs | 27 ++ 60 files changed, 1521 insertions(+), 1515 deletions(-) create mode 100644 .sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json delete mode 100644 .sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json create mode 100644 src/event/app.rs create mode 100644 src/event/broadcaster.rs create mode 100644 src/event/extract.rs create mode 100644 src/event/mod.rs create mode 100644 src/event/repo/message.rs create mode 100644 src/event/repo/mod.rs create mode 100644 src/event/routes.rs create mode 100644 src/event/routes/test.rs create mode 100644 src/event/sequence.rs create mode 100644 src/event/types.rs delete mode 100644 src/events/app.rs delete mode 100644 src/events/broadcaster.rs delete mode 100644 src/events/extract.rs delete mode 100644 src/events/mod.rs delete mode 100644 src/events/repo/message.rs delete mode 100644 src/events/repo/mod.rs delete mode 100644 src/events/routes.rs delete mode 100644 src/events/routes/test.rs delete mode 100644 src/events/types.rs create mode 100644 src/login/password.rs delete mode 100644 src/login/token/id.rs delete mode 100644 src/login/token/mod.rs delete mode 100644 src/password.rs create mode 100644 src/repo/login.rs delete mode 100644 src/repo/login/extract.rs delete mode 100644 src/repo/login/mod.rs delete mode 100644 src/repo/login/store.rs delete mode 100644 src/repo/message.rs create mode 100644 src/token/extract/identity.rs create mode 100644 src/token/extract/identity_token.rs create mode 100644 src/token/extract/mod.rs create mode 100644 src/token/id.rs create mode 100644 src/token/mod.rs create mode 100644 
src/token/secret.rs diff --git a/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json b/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json new file mode 100644 index 0000000..b433e4c --- /dev/null +++ b/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n insert\n into token (id, secret, login, issued_at, last_used_at)\n values ($1, $2, $3, $4, $4)\n returning secret as \"secret!: Secret\"\n ", + "describe": { + "columns": [ + { + "name": "secret!: Secret", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 4 + }, + "nullable": [ + false + ] + }, + "hash": "8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769" +} diff --git a/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json b/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json deleted file mode 100644 index eda6697..0000000 --- a/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n insert\n into token (id, secret, login, issued_at, last_used_at)\n values ($1, $2, $3, $4, $4)\n returning secret as \"secret!: IdentitySecret\"\n ", - "describe": { - "columns": [ - { - "name": "secret!: IdentitySecret", - "ordinal": 0, - "type_info": "Text" - } - ], - "parameters": { - "Right": 4 - }, - "nullable": [ - false - ] - }, - "hash": "e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669" -} diff --git a/src/app.rs b/src/app.rs index c13f52f..84a6357 100644 --- a/src/app.rs +++ b/src/app.rs @@ -2,7 +2,7 @@ use sqlx::sqlite::SqlitePool; use crate::{ channel::app::Channels, - events::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, + event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster}, }; diff --git a/src/channel/app.rs b/src/channel/app.rs index d89e733..1422651 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,12 +2,11 @@ use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; use crate::{ + channel::Channel, clock::DateTime, - events::{broadcaster::Broadcaster, types::ChannelEvent}, - repo::{ - channel::{Channel, Provider as _}, - sequence::{Provider as _, Sequence}, - }, + event::Sequence, + event::{broadcaster::Broadcaster, types::ChannelEvent}, + repo::{channel::Provider as _, sequence::Provider as _}, }; pub struct Channels<'a> { diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 3115e98..02d0ed4 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,7 +1,17 @@ +use crate::{clock::DateTime, event::Sequence}; + pub mod app; mod id; mod routes; -pub use self::routes::router; +pub use self::{id::Id, routes::router}; -pub use self::id::Id; +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Channel { + pub id: Id, + pub name: String, + #[serde(skip)] + pub created_at: DateTime, + #[serde(skip)] + pub created_sequence: Sequence, +} diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 72d6195..5d8b61e 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -10,11 +10,11 @@ use axum_extra::extract::Query; use super::app; use crate::{ app::App, - channel, + channel::{self, Channel}, clock::RequestedAt, error::Internal, - events::app::EventsError, - repo::{channel::Channel, login::Login, sequence::Sequence}, + event::{app::EventsError, 
Sequence}, + login::Login, }; #[cfg(test)] diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 72980ac..9988932 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -3,7 +3,7 @@ use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - events::types, + event::types, test::fixtures::{self, future::Immediately as _}, }; diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 987784d..6f844cd 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -4,7 +4,7 @@ use futures::stream::StreamExt; use crate::{ channel, channel::routes, - events::{app, types}, + event::{app, types}, test::fixtures::{self, future::Immediately as _}, }; diff --git a/src/cli.rs b/src/cli.rs index 132baf8..ee95ea6 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, events, expire, login, repo::pool}; +use crate::{app::App, channel, clock, event, expire, login, repo::pool}; /// Command-line entry point for running the `hi` server. /// @@ -105,7 +105,7 @@ impl Args { } fn routers() -> Router { - [channel::router(), events::router(), login::router()] + [channel::router(), event::router(), login::router()] .into_iter() .fold(Router::default(), Router::merge) } diff --git a/src/event/app.rs b/src/event/app.rs new file mode 100644 index 0000000..b5f2ecc --- /dev/null +++ b/src/event/app.rs @@ -0,0 +1,137 @@ +use chrono::TimeDelta; +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; +use sqlx::sqlite::SqlitePool; + +use super::{ + broadcaster::Broadcaster, + repo::message::Provider as _, + types::{self, ChannelEvent}, +}; +use crate::{ + channel, + clock::DateTime, + event::Sequence, + login::Login, + repo::{channel::Provider as _, error::NotFound as _, sequence::Provider as _}, +}; + +pub struct Events<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Events<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn send( + &self, + login: &Login, + channel: &channel::Id, + body: &str, + sent_at: &DateTime, + ) -> Result { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let sent_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .create(login, &channel, sent_at, sent_sequence, body) + .await?; + tx.commit().await?; + + self.events.broadcast(&event); + Ok(event) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. 
+ let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + let expired = tx.message_events().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for (channel, message) in expired { + let deleted_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .delete(&channel, &message, relative_to, deleted_sequence) + .await?; + events.push(event); + } + + tx.commit().await?; + + for event in events { + self.events.broadcast(&event); + } + + Ok(()) + } + + pub async fn subscribe( + &self, + resume_at: Option, + ) -> Result + std::fmt::Debug, sqlx::Error> { + // Subscribe before retrieving, to catch messages broadcast while we're + // querying the DB. We'll prune out duplicates later. + let live_messages = self.events.subscribe(); + + let mut tx = self.db.begin().await?; + let channels = tx.channels().replay(resume_at).await?; + + let channel_events = channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); + + let message_events = tx.message_events().replay(resume_at).await?; + + let mut replay_events = channel_events + .into_iter() + .chain(message_events.into_iter()) + .collect::>(); + replay_events.sort_by_key(|event| event.sequence); + let resume_live_at = replay_events.last().map(|event| event.sequence); + + let replay = stream::iter(replay_events); + + // no skip_expired or resume transforms for stored_messages, as it's + // constructed not to contain messages meeting either criterion. + // + // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; + // * resume is redundant with the resume_at argument to + // `tx.broadcasts().replay(…)`. + let live_messages = live_messages + // Filtering on the broadcast resume point filters out messages + // before resume_at, and filters out messages duplicated from + // stored_messages. + .filter(Self::resume(resume_live_at)); + + Ok(replay.chain(live_messages)) + } + + fn resume( + resume_at: Option, + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { + move |event| future::ready(resume_at < Some(event.sequence)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum EventsError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs new file mode 100644 index 0000000..92f631f --- /dev/null +++ b/src/event/broadcaster.rs @@ -0,0 +1,3 @@ +use crate::{broadcast, event::types}; + +pub type Broadcaster = broadcast::Broadcaster; diff --git a/src/event/extract.rs b/src/event/extract.rs new file mode 100644 index 0000000..e3021e2 --- /dev/null +++ b/src/event/extract.rs @@ -0,0 +1,85 @@ +use std::ops::Deref; + +use axum::{ + extract::FromRequestParts, + http::{request::Parts, HeaderName, HeaderValue}, +}; +use axum_extra::typed_header::TypedHeader; +use serde::{de::DeserializeOwned, Serialize}; + +// A typed header. When used as a bare extractor, reads from the +// `Last-Event-Id` HTTP header. 
+pub struct LastEventId(pub T); + +static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); + +impl headers::Header for LastEventId +where + T: Serialize + DeserializeOwned, +{ + fn name() -> &'static HeaderName { + &LAST_EVENT_ID + } + + fn decode<'i, I>(values: &mut I) -> Result + where + I: Iterator, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + let value = value.to_str().map_err(|_| headers::Error::invalid())?; + let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; + Ok(Self(value)) + } + + fn encode(&self, values: &mut E) + where + E: Extend, + { + let Self(value) = self; + // Must panic or suppress; the trait provides no other options. + let value = serde_json::to_string(value).expect("value can be encoded as JSON"); + let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); + + values.extend(std::iter::once(value)); + } +} + +#[async_trait::async_trait] +impl FromRequestParts for LastEventId +where + S: Send + Sync, + T: Serialize + DeserializeOwned, +{ + type Rejection = as FromRequestParts>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // This is purely for ergonomics: it allows `RequestedAt` to be extracted + // without having to wrap it in `Extension<>`. Callers _can_ still do that, + // but they aren't forced to. + let TypedHeader(requested_at) = TypedHeader::from_request_parts(parts, state).await?; + + Ok(requested_at) + } +} + +impl Deref for LastEventId { + type Target = T; + + fn deref(&self) -> &Self::Target { + let Self(header) = self; + header + } +} + +impl From for LastEventId { + fn from(value: T) -> Self { + Self(value) + } +} + +impl LastEventId { + pub fn into_inner(self) -> T { + let Self(value) = self; + value + } +} diff --git a/src/event/mod.rs b/src/event/mod.rs new file mode 100644 index 0000000..7ad3f9c --- /dev/null +++ b/src/event/mod.rs @@ -0,0 +1,9 @@ +pub mod app; +pub mod broadcaster; +mod extract; +pub mod repo; +mod routes; +mod sequence; +pub mod types; + +pub use self::{routes::router, sequence::Sequence}; diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs new file mode 100644 index 0000000..f051fec --- /dev/null +++ b/src/event/repo/message.rs @@ -0,0 +1,188 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::{types, Sequence}, + login::{self, Login}, + message::{self, Message}, +}; + +pub trait Provider { + fn message_events(&mut self) -> Events; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn message_events(&mut self) -> Events { + Events(self) + } +} + +pub struct Events<'t>(&'t mut SqliteConnection); + +impl<'c> Events<'c> { + pub async fn create( + &mut self, + sender: &Login, + channel: &Channel, + sent_at: &DateTime, + sent_sequence: Sequence, + body: &str, + ) -> Result { + let id = message::Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: message::Id", + sender as "sender: login::Id", + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence", + body + "#, + id, + channel.id, + sender.id, + sent_at, + sent_sequence, + body, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: channel.clone(), + sender: sender.clone(), + message: Message { + id: row.id, 
+ body: row.body, + }, + } + .into(), + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn delete( + &mut self, + channel: &Channel, + message: &message::Id, + deleted_at: &DateTime, + deleted_sequence: Sequence, + ) -> Result { + sqlx::query_scalar!( + r#" + delete from message + where id = $1 + returning 1 as "row: i64" + "#, + message, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence: deleted_sequence, + at: *deleted_at, + data: types::MessageDeletedEvent { + channel: channel.clone(), + message: message.clone(), + } + .into(), + }) + } + + pub async fn expired( + &mut self, + expire_at: &DateTime, + ) -> Result, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + message.id as "message: message::Id" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where sent_at < $1 + "#, + expire_at, + ) + .map(|row| { + ( + Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + row.message, + ) + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let events = sqlx::query!( + r#" + select + message.id as "id: message::Id", + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + sender.id as "sender_id: login::Id", + sender.name as sender_name, + message.sent_at as "sent_at: DateTime", + message.sent_sequence as "sent_sequence: Sequence", + message.body + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + order by sent_sequence asc + "#, + resume_at, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(events) + } +} diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs new file mode 100644 index 0000000..e216a50 --- /dev/null +++ b/src/event/repo/mod.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/src/event/routes.rs b/src/event/routes.rs new file mode 100644 index 0000000..77761ca --- /dev/null +++ b/src/event/routes.rs @@ -0,0 +1,93 @@ +use axum::{ + extract::State, + response::{ + sse::{self, Sse}, + IntoResponse, Response, + }, + routing::get, + Router, +}; +use axum_extra::extract::Query; +use futures::stream::{Stream, StreamExt as _}; + +use super::{extract::LastEventId, types}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + event::Sequence, + login::app::ValidateError, + token::extract::Identity, +}; + +#[cfg(test)] +mod test; + +pub fn router() -> Router { + Router::new().route("/api/events", get(events)) +} + +#[derive(Default, serde::Deserialize)] +struct EventsQuery { + 
resume_point: Option, +} + +async fn events( + State(app): State, + identity: Identity, + last_event_id: Option>, + Query(query): Query, +) -> Result + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); + + let stream = app.events().subscribe(resume_at).await?; + let stream = app.logins().limit_stream(identity.token, stream).await?; + + Ok(Events(stream)) +} + +#[derive(Debug)] +struct Events(S); + +impl IntoResponse for Events +where + S: Stream + Send + 'static, +{ + fn into_response(self) -> Response { + let Self(stream) = self; + let stream = stream.map(sse::Event::try_from); + Sse::new(stream) + .keep_alive(sse::KeepAlive::default()) + .into_response() + } +} + +impl TryFrom for sse::Event { + type Error = serde_json::Error; + + fn try_from(event: types::ChannelEvent) -> Result { + let id = serde_json::to_string(&event.sequence)?; + let data = serde_json::to_string_pretty(&event)?; + + let event = Self::default().id(id).data(data); + + Ok(event) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs new file mode 100644 index 0000000..9a3b12a --- /dev/null +++ b/src/event/routes/test.rs @@ -0,0 +1,439 @@ +use axum::extract::State; +use axum_extra::extract::Query; +use futures::{ + future, + stream::{self, StreamExt as _}, +}; + +use crate::{ + event::routes, + test::fixtures::{self, future::Immediately as _}, +}; + +#[tokio::test] +async fn includes_historical_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. 
+ + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered stored message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_live_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the semantics + + let sender = fixtures::login::create(&app).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered live message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_multiple_channels() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + + let channels = [ + fixtures::channel::create(&app, &fixtures::now()).await, + fixtures::channel::create(&app, &fixtures::now()).await, + ]; + + let messages = stream::iter(channels) + .then(|channel| { + let app = app.clone(); + let sender = sender.clone(); + let channel = channel.clone(); + async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } + }) + .collect::>() + .await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let events = events + .filter(fixtures::filter::messages()) + .take(messages.len()) + .collect::>() + .immediately() + .await; + + for message in &messages { + assert!(events.iter().any(|event| { event == message })); + } +} + +#[tokio::test] +async fn sequential_messages() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + let messages = vec![ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. 
+
+    let mut events = events.filter(|event| future::ready(messages.contains(event)));
+
+    // Verify delivery in order
+    for message in &messages {
+        let event = events
+            .next()
+            .immediately()
+            .await
+            .expect("undelivered messages remaining");
+
+        assert_eq!(message, &event);
+    }
+}
+
+#[tokio::test]
+async fn resumes_from() {
+    // Set up the environment
+
+    let app = fixtures::scratch_app().await;
+    let channel = fixtures::channel::create(&app, &fixtures::now()).await;
+    let sender = fixtures::login::create(&app).await;
+
+    let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+
+    let later_messages = vec![
+        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+    ];
+
+    // Call the endpoint
+
+    let subscriber_creds = fixtures::login::create_with_password(&app).await;
+    let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
+
+    let resume_at = {
+        // First subscription
+        let routes::Events(events) = routes::events(
+            State(app.clone()),
+            subscriber.clone(),
+            None,
+            Query::default(),
+        )
+        .await
+        .expect("subscribe never fails");
+
+        let event = events
+            .filter(fixtures::filter::messages())
+            .next()
+            .immediately()
+            .await
+            .expect("delivered events");
+
+        assert_eq!(initial_message, event);
+
+        event.sequence
+    };
+
+    // Resume after disconnect
+    let routes::Events(resumed) = routes::events(
+        State(app),
+        subscriber,
+        Some(resume_at.into()),
+        Query::default(),
+    )
+    .await
+    .expect("subscribe never fails");
+
+    // Verify the structure of the response.
+
+    let events = resumed
+        .take(later_messages.len())
+        .collect::<Vec<_>>()
+        .immediately()
+        .await;
+
+    for message in &later_messages {
+        assert!(events.iter().any(|event| event == message));
+    }
+}
+
+// This test verifies a real bug I hit developing the vector-of-sequences
+// approach to resuming events. A small omission caused the event IDs in a
+// resumed stream to _omit_ channels that were in the original stream until
+// those channels also appeared in the resumed stream.
+//
+// Clients would see something like
+// * In the original stream, Cfoo=5,Cbar=8
+// * In the resumed stream, Cfoo=6 (no Cbar sequence number)
+//
+// Disconnecting and reconnecting a second time, using event IDs from that
+// initial period of the first resume attempt, would then cause the second
+// resume attempt to restart all other channels from the beginning, and not
+// from where the first disconnection happened.
+//
+// Disconnecting and reconnecting like this is real, valid client behaviour,
+// so the server has to handle it.
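+//
+// As an illustration (these event-ID shapes are schematic, not the exact
+// wire format): a client that ends the original stream holding Cfoo=5,Cbar=8
+// and then receives only Cfoo=6 early in the first resumed stream has, at
+// that moment, no usable position for Cbar; resuming again from that ID
+// replays Cbar from the very beginning.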
+#[tokio::test]
+async fn serial_resume() {
+    // Set up the environment
+
+    let app = fixtures::scratch_app().await;
+    let sender = fixtures::login::create(&app).await;
+    let channel_a = fixtures::channel::create(&app, &fixtures::now()).await;
+    let channel_b = fixtures::channel::create(&app, &fixtures::now()).await;
+
+    // Call the endpoint
+
+    let subscriber_creds = fixtures::login::create_with_password(&app).await;
+    let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
+
+    let resume_at = {
+        let initial_messages = [
+            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+        ];
+
+        // First subscription
+        let routes::Events(events) = routes::events(
+            State(app.clone()),
+            subscriber.clone(),
+            None,
+            Query::default(),
+        )
+        .await
+        .expect("subscribe never fails");
+
+        let events = events
+            .filter(fixtures::filter::messages())
+            .take(initial_messages.len())
+            .collect::<Vec<_>>()
+            .immediately()
+            .await;
+
+        for message in &initial_messages {
+            assert!(events.iter().any(|event| event == message));
+        }
+
+        let event = events.last().expect("this vec is non-empty");
+
+        event.sequence
+    };
+
+    // Resume after disconnect
+    let resume_at = {
+        let resume_messages = [
+            // Note that channel_b does not appear here. The buggy behaviour
+            // would be masked if channel_b happened to send a new message
+            // into the resumed event stream.
+            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+        ];
+
+        // Second subscription
+        let routes::Events(events) = routes::events(
+            State(app.clone()),
+            subscriber.clone(),
+            Some(resume_at.into()),
+            Query::default(),
+        )
+        .await
+        .expect("subscribe never fails");
+
+        let events = events
+            .filter(fixtures::filter::messages())
+            .take(resume_messages.len())
+            .collect::<Vec<_>>()
+            .immediately()
+            .await;
+
+        for message in &resume_messages {
+            assert!(events.iter().any(|event| event == message));
+        }
+
+        let event = events.last().expect("this vec is non-empty");
+
+        event.sequence
+    };
+
+    // Resume after disconnect a second time
+    {
+        // At this point, we can send on either channel and demonstrate the
+        // problem. The resume point should be before both of these messages,
+        // but after _all_ prior messages.
+        let final_messages = [
+            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+        ];
+
+        // Third subscription
+        let routes::Events(events) = routes::events(
+            State(app.clone()),
+            subscriber.clone(),
+            Some(resume_at.into()),
+            Query::default(),
+        )
+        .await
+        .expect("subscribe never fails");
+
+        let events = events
+            .filter(fixtures::filter::messages())
+            .take(final_messages.len())
+            .collect::<Vec<_>>()
+            .immediately()
+            .await;
+
+        // This set of messages, in particular, _should not_ include any prior
+        // messages from `initial_messages` or `resume_messages`.
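+        // Taking exactly `final_messages.len()` events makes that check
+        // implicit: any replayed earlier message would crowd one of the
+        // final messages out of the collected set.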
+        for message in &final_messages {
+            assert!(events.iter().any(|event| event == message));
+        }
+    };
+}
+
+#[tokio::test]
+async fn terminates_on_token_expiry() {
+    // Set up the environment
+
+    let app = fixtures::scratch_app().await;
+    let channel = fixtures::channel::create(&app, &fixtures::now()).await;
+    let sender = fixtures::login::create(&app).await;
+
+    // Subscribe via the endpoint
+
+    let subscriber_creds = fixtures::login::create_with_password(&app).await;
+    let subscriber =
+        fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await;
+
+    let routes::Events(events) =
+        routes::events(State(app.clone()), subscriber, None, Query::default())
+            .await
+            .expect("subscribe never fails");
+
+    // Verify the resulting stream's behaviour
+
+    app.logins()
+        .expire(&fixtures::now())
+        .await
+        .expect("expiring tokens succeeds");
+
+    // These should not be delivered.
+    let messages = [
+        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+    ];
+
+    assert!(events
+        .filter(|event| future::ready(messages.contains(event)))
+        .next()
+        .immediately()
+        .await
+        .is_none());
+}
+
+#[tokio::test]
+async fn terminates_on_logout() {
+    // Set up the environment
+
+    let app = fixtures::scratch_app().await;
+    let channel = fixtures::channel::create(&app, &fixtures::now()).await;
+    let sender = fixtures::login::create(&app).await;
+
+    // Subscribe via the endpoint
+
+    let subscriber_creds = fixtures::login::create_with_password(&app).await;
+    let subscriber_token =
+        fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::now()).await;
+    let subscriber =
+        fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await;
+
+    let routes::Events(events) = routes::events(
+        State(app.clone()),
+        subscriber.clone(),
+        None,
+        Query::default(),
+    )
+    .await
+    .expect("subscribe never fails");
+
+    // Verify the resulting stream's behaviour
+
+    app.logins()
+        .logout(&subscriber.token)
+        .await
+        .expect("logging out succeeds");
+
+    // These should not be delivered.
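+    // `logout` revokes the subscriber's token, and `limit_stream` subscribed
+    // to revocation events before this point, so the guarded stream should
+    // end rather than deliver the messages below.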
+ let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|event| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} diff --git a/src/event/sequence.rs b/src/event/sequence.rs new file mode 100644 index 0000000..9ebddd7 --- /dev/null +++ b/src/event/sequence.rs @@ -0,0 +1,24 @@ +use std::fmt; + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) + } +} diff --git a/src/event/types.rs b/src/event/types.rs new file mode 100644 index 0000000..cd7dea6 --- /dev/null +++ b/src/event/types.rs @@ -0,0 +1,97 @@ +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::Sequence, + login::Login, + message::{self, Message}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct ChannelEvent { + #[serde(skip)] + pub sequence: Sequence, + pub at: DateTime, + #[serde(flatten)] + pub data: ChannelEventData, +} + +impl ChannelEvent { + pub fn created(channel: Channel) -> Self { + Self { + at: channel.created_at, + sequence: channel.created_sequence, + data: CreatedEvent { channel }.into(), + } + } + + pub fn channel_id(&self) -> &channel::Id { + match &self.data { + ChannelEventData::Created(event) => &event.channel.id, + ChannelEventData::Message(event) => &event.channel.id, + ChannelEventData::MessageDeleted(event) => &event.channel.id, + ChannelEventData::Deleted(event) => &event.channel, + } + } +} + +impl<'c> From<&'c ChannelEvent> for Sequence { + fn from(event: &'c ChannelEvent) -> Self { + event.sequence + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChannelEventData { + Created(CreatedEvent), + Message(MessageEvent), + MessageDeleted(MessageDeletedEvent), + Deleted(DeletedEvent), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct CreatedEvent { + pub channel: Channel, +} + +impl From for ChannelEventData { + fn from(event: CreatedEvent) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageEvent { + pub channel: Channel, + pub sender: Login, + pub message: Message, +} + +impl From for ChannelEventData { + fn from(event: MessageEvent) -> Self { + Self::Message(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageDeletedEvent { + pub channel: Channel, + pub message: message::Id, +} + +impl From for ChannelEventData { + fn from(event: MessageDeletedEvent) -> Self { + Self::MessageDeleted(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct DeletedEvent { + pub channel: channel::Id, +} + +impl From for ChannelEventData { + fn from(event: DeletedEvent) -> Self { + Self::Deleted(event) + } +} diff --git a/src/events/app.rs b/src/events/app.rs deleted file mode 100644 index 1fa2f70..0000000 --- a/src/events/app.rs +++ /dev/null @@ -1,140 +0,0 @@ -use chrono::TimeDelta; -use futures::{ - future, - stream::{self, StreamExt as _}, - Stream, -}; -use sqlx::sqlite::SqlitePool; - -use 
super::{ - broadcaster::Broadcaster, - repo::message::Provider as _, - types::{self, ChannelEvent}, -}; -use crate::{ - channel, - clock::DateTime, - repo::{ - channel::Provider as _, - error::NotFound as _, - login::Login, - sequence::{Provider as _, Sequence}, - }, -}; - -pub struct Events<'a> { - db: &'a SqlitePool, - events: &'a Broadcaster, -} - -impl<'a> Events<'a> { - pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { - Self { db, events } - } - - pub async fn send( - &self, - login: &Login, - channel: &channel::Id, - body: &str, - sent_at: &DateTime, - ) -> Result { - let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent_sequence = tx.sequence().next().await?; - let event = tx - .message_events() - .create(login, &channel, sent_at, sent_sequence, body) - .await?; - tx.commit().await?; - - self.events.broadcast(&event); - Ok(event) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 90 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(90); - - let mut tx = self.db.begin().await?; - let expired = tx.message_events().expired(&expire_at).await?; - - let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { - let deleted_sequence = tx.sequence().next().await?; - let event = tx - .message_events() - .delete(&channel, &message, relative_to, deleted_sequence) - .await?; - events.push(event); - } - - tx.commit().await?; - - for event in events { - self.events.broadcast(&event); - } - - Ok(()) - } - - pub async fn subscribe( - &self, - resume_at: Option, - ) -> Result + std::fmt::Debug, sqlx::Error> { - // Subscribe before retrieving, to catch messages broadcast while we're - // querying the DB. We'll prune out duplicates later. - let live_messages = self.events.subscribe(); - - let mut tx = self.db.begin().await?; - let channels = tx.channels().replay(resume_at).await?; - - let channel_events = channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); - - let message_events = tx.message_events().replay(resume_at).await?; - - let mut replay_events = channel_events - .into_iter() - .chain(message_events.into_iter()) - .collect::>(); - replay_events.sort_by_key(|event| event.sequence); - let resume_live_at = replay_events.last().map(|event| event.sequence); - - let replay = stream::iter(replay_events); - - // no skip_expired or resume transforms for stored_messages, as it's - // constructed not to contain messages meeting either criterion. - // - // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; - // * resume is redundant with the resume_at argument to - // `tx.broadcasts().replay(…)`. - let live_messages = live_messages - // Filtering on the broadcast resume point filters out messages - // before resume_at, and filters out messages duplicated from - // stored_messages. 
- .filter(Self::resume(resume_live_at)); - - Ok(replay.chain(live_messages)) - } - - fn resume( - resume_at: Option, - ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { - move |event| future::ready(resume_at < Some(event.sequence)) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum EventsError { - #[error("channel {0} not found")] - ChannelNotFound(channel::Id), - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs deleted file mode 100644 index 6b664cb..0000000 --- a/src/events/broadcaster.rs +++ /dev/null @@ -1,3 +0,0 @@ -use crate::{broadcast, events::types}; - -pub type Broadcaster = broadcast::Broadcaster; diff --git a/src/events/extract.rs b/src/events/extract.rs deleted file mode 100644 index e3021e2..0000000 --- a/src/events/extract.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::ops::Deref; - -use axum::{ - extract::FromRequestParts, - http::{request::Parts, HeaderName, HeaderValue}, -}; -use axum_extra::typed_header::TypedHeader; -use serde::{de::DeserializeOwned, Serialize}; - -// A typed header. When used as a bare extractor, reads from the -// `Last-Event-Id` HTTP header. -pub struct LastEventId(pub T); - -static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); - -impl headers::Header for LastEventId -where - T: Serialize + DeserializeOwned, -{ - fn name() -> &'static HeaderName { - &LAST_EVENT_ID - } - - fn decode<'i, I>(values: &mut I) -> Result - where - I: Iterator, - { - let value = values.next().ok_or_else(headers::Error::invalid)?; - let value = value.to_str().map_err(|_| headers::Error::invalid())?; - let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; - Ok(Self(value)) - } - - fn encode(&self, values: &mut E) - where - E: Extend, - { - let Self(value) = self; - // Must panic or suppress; the trait provides no other options. - let value = serde_json::to_string(value).expect("value can be encoded as JSON"); - let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); - - values.extend(std::iter::once(value)); - } -} - -#[async_trait::async_trait] -impl FromRequestParts for LastEventId -where - S: Send + Sync, - T: Serialize + DeserializeOwned, -{ - type Rejection = as FromRequestParts>::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - // This is purely for ergonomics: it allows `RequestedAt` to be extracted - // without having to wrap it in `Extension<>`. Callers _can_ still do that, - // but they aren't forced to. 
- let TypedHeader(requested_at) = TypedHeader::from_request_parts(parts, state).await?; - - Ok(requested_at) - } -} - -impl Deref for LastEventId { - type Target = T; - - fn deref(&self) -> &Self::Target { - let Self(header) = self; - header - } -} - -impl From for LastEventId { - fn from(value: T) -> Self { - Self(value) - } -} - -impl LastEventId { - pub fn into_inner(self) -> T { - let Self(value) = self; - value - } -} diff --git a/src/events/mod.rs b/src/events/mod.rs deleted file mode 100644 index 711ae64..0000000 --- a/src/events/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod app; -pub mod broadcaster; -mod extract; -pub mod repo; -mod routes; -pub mod types; - -pub use self::routes::router; diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs deleted file mode 100644 index 00c24b1..0000000 --- a/src/events/repo/message.rs +++ /dev/null @@ -1,188 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel, - clock::DateTime, - events::types, - login, message, - repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, -}; - -pub trait Provider { - fn message_events(&mut self) -> Events; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn message_events(&mut self) -> Events { - Events(self) - } -} - -pub struct Events<'t>(&'t mut SqliteConnection); - -impl<'c> Events<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - sent_at: &DateTime, - sent_sequence: Sequence, - body: &str, - ) -> Result { - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sender as "sender: login::Id", - sent_at as "sent_at: DateTime", - sent_sequence as "sent_sequence: Sequence", - body - "#, - id, - channel.id, - sender.id, - sent_at, - sent_sequence, - body, - ) - .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: channel.clone(), - sender: sender.clone(), - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn delete( - &mut self, - channel: &Channel, - message: &message::Id, - deleted_at: &DateTime, - deleted_sequence: Sequence, - ) -> Result { - sqlx::query_scalar!( - r#" - delete from message - where id = $1 - returning 1 as "row: i64" - "#, - message, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, - data: types::MessageDeletedEvent { - channel: channel.clone(), - message: message.clone(), - } - .into(), - }) - } - - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - message.id as "message: message::Id" - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where sent_at < $1 - "#, - expire_at, - ) - .map(|row| { - ( - Channel { - id: row.channel_id, - name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, - }, - row.message, - ) - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } 
- - pub async fn replay( - &mut self, - resume_at: Option, - ) -> Result, sqlx::Error> { - let events = sqlx::query!( - r#" - select - message.id as "id: message::Id", - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - sender.id as "sender_id: login::Id", - sender.name as sender_name, - message.sent_at as "sent_at: DateTime", - message.sent_sequence as "sent_sequence: Sequence", - message.body - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where coalesce(message.sent_sequence > $1, true) - order by sent_sequence asc - "#, - resume_at, - ) - .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: Channel { - id: row.channel_id, - name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, - }, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(events) - } -} diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs deleted file mode 100644 index e216a50..0000000 --- a/src/events/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod message; diff --git a/src/events/routes.rs b/src/events/routes.rs deleted file mode 100644 index d81c7fb..0000000 --- a/src/events/routes.rs +++ /dev/null @@ -1,92 +0,0 @@ -use axum::{ - extract::State, - response::{ - sse::{self, Sse}, - IntoResponse, Response, - }, - routing::get, - Router, -}; -use axum_extra::extract::Query; -use futures::stream::{Stream, StreamExt as _}; - -use super::{extract::LastEventId, types}; -use crate::{ - app::App, - error::{Internal, Unauthorized}, - login::{app::ValidateError, extract::Identity}, - repo::sequence::Sequence, -}; - -#[cfg(test)] -mod test; - -pub fn router() -> Router { - Router::new().route("/api/events", get(events)) -} - -#[derive(Default, serde::Deserialize)] -struct EventsQuery { - resume_point: Option, -} - -async fn events( - State(app): State, - identity: Identity, - last_event_id: Option>, - Query(query): Query, -) -> Result + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id - .map(LastEventId::into_inner) - .or(query.resume_point); - - let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream).await?; - - Ok(Events(stream)) -} - -#[derive(Debug)] -struct Events(S); - -impl IntoResponse for Events -where - S: Stream + Send + 'static, -{ - fn into_response(self) -> Response { - let Self(stream) = self; - let stream = stream.map(sse::Event::try_from); - Sse::new(stream) - .keep_alive(sse::KeepAlive::default()) - .into_response() - } -} - -impl TryFrom for sse::Event { - type Error = serde_json::Error; - - fn try_from(event: types::ChannelEvent) -> Result { - let id = serde_json::to_string(&event.sequence)?; - let data = serde_json::to_string_pretty(&event)?; - - let event = Self::default().id(id).data(data); - - Ok(event) - } -} - -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub enum EventsError { - DatabaseError(#[from] sqlx::Error), - ValidateError(#[from] ValidateError), -} - -impl IntoResponse for EventsError { - fn into_response(self) -> Response { - match self { - Self::ValidateError(ValidateError::InvalidToken) => 
Unauthorized.into_response(), - other => Internal::from(other).into_response(), - } - } -} diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs deleted file mode 100644 index 11f01b8..0000000 --- a/src/events/routes/test.rs +++ /dev/null @@ -1,439 +0,0 @@ -use axum::extract::State; -use axum_extra::extract::Query; -use futures::{ - future, - stream::{self, StreamExt as _}, -}; - -use crate::{ - events::routes, - test::fixtures::{self, future::Immediately as _}, -}; - -#[tokio::test] -async fn includes_historical_message() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered stored message"); - - assert_eq!(message, event); -} - -#[tokio::test] -async fn includes_live_message() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = - routes::events(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the semantics - - let sender = fixtures::login::create(&app).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered live message"); - - assert_eq!(message, event); -} - -#[tokio::test] -async fn includes_multiple_channels() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - - let channels = [ - fixtures::channel::create(&app, &fixtures::now()).await, - fixtures::channel::create(&app, &fixtures::now()).await, - ]; - - let messages = stream::iter(channels) - .then(|channel| { - let app = app.clone(); - let sender = sender.clone(); - let channel = channel.clone(); - async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } - }) - .collect::>() - .await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. 
- - let events = events - .filter(fixtures::filter::messages()) - .take(messages.len()) - .collect::>() - .immediately() - .await; - - for message in &messages { - assert!(events.iter().any(|event| { event == message })); - } -} - -#[tokio::test] -async fn sequential_messages() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - let messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let mut events = events.filter(|event| future::ready(messages.contains(event))); - - // Verify delivery in order - for message in &messages { - let event = events - .next() - .immediately() - .await - .expect("undelivered messages remaining"); - - assert_eq!(message, &event); - } -} - -#[tokio::test] -async fn resumes_from() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - let later_messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - - let resume_at = { - // First subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - None, - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered events"); - - assert_eq!(initial_message, event); - - event.sequence - }; - - // Resume after disconnect - let routes::Events(resumed) = routes::events( - State(app), - subscriber, - Some(resume_at.into()), - Query::default(), - ) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let events = resumed - .take(later_messages.len()) - .collect::>() - .immediately() - .await; - - for message in &later_messages { - assert!(events.iter().any(|event| event == message)); - } -} - -// This test verifies a real bug I hit developing the vector-of-sequences -// approach to resuming events. A small omission caused the event IDs in a -// resumed stream to _omit_ channels that were in the original stream until -// those channels also appeared in the resumed stream. 
-// -// Clients would see something like -// * In the original stream, Cfoo=5,Cbar=8 -// * In the resumed stream, Cfoo=6 (no Cbar sequence number) -// -// Disconnecting and reconnecting a second time, using event IDs from that -// initial period of the first resume attempt, would then cause the second -// resume attempt to restart all other channels from the beginning, and not -// from where the first disconnection happened. -// -// This is a real and valid behaviour for clients! -#[tokio::test] -async fn serial_resume() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel_a = fixtures::channel::create(&app, &fixtures::now()).await; - let channel_b = fixtures::channel::create(&app, &fixtures::now()).await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - - let resume_at = { - let initial_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, - ]; - - // First subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - None, - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let events = events - .filter(fixtures::filter::messages()) - .take(initial_messages.len()) - .collect::>() - .immediately() - .await; - - for message in &initial_messages { - assert!(events.iter().any(|event| event == message)); - } - - let event = events.last().expect("this vec is non-empty"); - - event.sequence - }; - - // Resume after disconnect - let resume_at = { - let resume_messages = [ - // Note that channel_b does not appear here. The buggy behaviour - // would be masked if channel_b happened to send a new message - // into the resumed event stream. - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - ]; - - // Second subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - Some(resume_at.into()), - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let events = events - .filter(fixtures::filter::messages()) - .take(resume_messages.len()) - .collect::>() - .immediately() - .await; - - for message in &resume_messages { - assert!(events.iter().any(|event| event == message)); - } - - let event = events.last().expect("this vec is non-empty"); - - event.sequence - }; - - // Resume after disconnect a second time - { - // At this point, we can send on either channel and demonstrate the - // problem. The resume point should before both of these messages, but - // after _all_ prior messages. 
- let final_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, - ]; - - // Third subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - Some(resume_at.into()), - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let events = events - .filter(fixtures::filter::messages()) - .take(final_messages.len()) - .collect::>() - .immediately() - .await; - - // This set of messages, in particular, _should not_ include any prior - // messages from `initial_messages` or `resume_messages`. - for message in &final_messages { - assert!(events.iter().any(|event| event == message)); - } - }; -} - -#[tokio::test] -async fn terminates_on_token_expiry() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - // Subscribe via the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = - fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; - - let routes::Events(events) = - routes::events(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the resulting stream's behaviour - - app.logins() - .expire(&fixtures::now()) - .await - .expect("expiring tokens succeeds"); - - // These should not be delivered. - let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - assert!(events - .filter(|event| future::ready(messages.contains(event))) - .next() - .immediately() - .await - .is_none()); -} - -#[tokio::test] -async fn terminates_on_logout() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - // Subscribe via the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber_token = - fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::now()).await; - let subscriber = - fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; - - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - None, - Query::default(), - ) - .await - .expect("subscribe never fails"); - - // Verify the resulting stream's behaviour - - app.logins() - .logout(&subscriber.token) - .await - .expect("expiring tokens succeeds"); - - // These should not be delivered. 
- let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - assert!(events - .filter(|event| future::ready(messages.contains(event))) - .next() - .immediately() - .await - .is_none()); -} diff --git a/src/events/types.rs b/src/events/types.rs deleted file mode 100644 index 762b6e5..0000000 --- a/src/events/types.rs +++ /dev/null @@ -1,96 +0,0 @@ -use crate::{ - channel, - clock::DateTime, - message, - repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, -}; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct ChannelEvent { - #[serde(skip)] - pub sequence: Sequence, - pub at: DateTime, - #[serde(flatten)] - pub data: ChannelEventData, -} - -impl ChannelEvent { - pub fn created(channel: Channel) -> Self { - Self { - at: channel.created_at, - sequence: channel.created_sequence, - data: CreatedEvent { channel }.into(), - } - } - - pub fn channel_id(&self) -> &channel::Id { - match &self.data { - ChannelEventData::Created(event) => &event.channel.id, - ChannelEventData::Message(event) => &event.channel.id, - ChannelEventData::MessageDeleted(event) => &event.channel.id, - ChannelEventData::Deleted(event) => &event.channel, - } - } -} - -impl<'c> From<&'c ChannelEvent> for Sequence { - fn from(event: &'c ChannelEvent) -> Self { - event.sequence - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ChannelEventData { - Created(CreatedEvent), - Message(MessageEvent), - MessageDeleted(MessageDeletedEvent), - Deleted(DeletedEvent), -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct CreatedEvent { - pub channel: Channel, -} - -impl From for ChannelEventData { - fn from(event: CreatedEvent) -> Self { - Self::Created(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageEvent { - pub channel: Channel, - pub sender: Login, - pub message: Message, -} - -impl From for ChannelEventData { - fn from(event: MessageEvent) -> Self { - Self::Message(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageDeletedEvent { - pub channel: Channel, - pub message: message::Id, -} - -impl From for ChannelEventData { - fn from(event: MessageDeletedEvent) -> Self { - Self::MessageDeleted(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct DeletedEvent { - pub channel: channel::Id, -} - -impl From for ChannelEventData { - fn from(event: DeletedEvent) -> Self { - Self::Deleted(event) - } -} diff --git a/src/lib.rs b/src/lib.rs index 2300071..bbcb314 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,12 +8,12 @@ mod channel; pub mod cli; mod clock; mod error; -mod events; +mod event; mod expire; mod id; mod login; mod message; -mod password; mod repo; #[cfg(test)] mod test; +mod token; diff --git a/src/login/app.rs b/src/login/app.rs index 8ea0a91..60475af 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -6,18 +6,15 @@ use futures::{ }; use sqlx::sqlite::SqlitePool; -use super::{ - broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, token, types, -}; +use super::{broadcaster::Broadcaster, repo::auth::Provider as _, types, Login}; use crate::{ clock::DateTime, - password::Password, + event::Sequence, + login::Password, repo::{ - error::NotFound as _, - 
login::{Login, Provider as _}, - sequence::{Provider as _, Sequence}, - token::Provider as _, + error::NotFound as _, login::Provider as _, sequence::Provider as _, token::Provider as _, }, + token::{self, Secret}, }; pub struct Logins<'a> { @@ -43,7 +40,7 @@ impl<'a> Logins<'a> { name: &str, password: &Password, login_at: &DateTime, - ) -> Result { + ) -> Result { let mut tx = self.db.begin().await?; let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { @@ -78,7 +75,7 @@ impl<'a> Logins<'a> { pub async fn validate( &self, - secret: &IdentitySecret, + secret: &Secret, used_at: &DateTime, ) -> Result<(token::Id, Login), ValidateError> { let mut tx = self.db.begin().await?; diff --git a/src/login/extract.rs b/src/login/extract.rs index 39dd9e4..c2d97f2 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -1,182 +1,15 @@ -use std::fmt; +use axum::{extract::FromRequestParts, http::request::Parts}; -use axum::{ - extract::{FromRequestParts, State}, - http::request::Parts, - response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, -}; -use axum_extra::extract::cookie::{Cookie, CookieJar}; - -use crate::{ - app::App, - clock::RequestedAt, - error::{Internal, Unauthorized}, - login::{app::ValidateError, token}, - repo::login::Login, -}; - -// The usage pattern here - receive the extractor as an argument, return it in -// the response - is heavily modelled after CookieJar's own intended usage. -#[derive(Clone)] -pub struct IdentityToken { - cookies: CookieJar, -} - -impl fmt::Debug for IdentityToken { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IdentityToken") - .field( - "identity", - &self.cookies.get(IDENTITY_COOKIE).map(|_| "********"), - ) - .finish() - } -} - -impl IdentityToken { - // Creates a new, unpopulated identity token store. - #[cfg(test)] - pub fn new() -> Self { - Self { - cookies: CookieJar::new(), - } - } - - // Get the identity secret sent in the request, if any. If the identity - // was not sent, or if it has previously been [clear]ed, then this will - // return [None]. If the identity has previously been [set], then this - // will return that secret, regardless of what the request originally - // included. - pub fn secret(&self) -> Option { - self.cookies - .get(IDENTITY_COOKIE) - .map(Cookie::value) - .map(IdentitySecret::from) - } - - // Positively set the identity secret, and ensure that it will be sent - // back to the client when this extractor is included in a response. - pub fn set(self, secret: impl Into) -> Self { - let IdentitySecret(secret) = secret.into(); - let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) - .http_only(true) - .path("/api/") - .permanent() - .build(); - - Self { - cookies: self.cookies.add(identity_cookie), - } - } - - // Remove the identity secret and ensure that it will be cleared when this - // extractor is included in a response. 
- pub fn clear(self) -> Self { - Self { - cookies: self.cookies.remove(IDENTITY_COOKIE), - } - } -} - -const IDENTITY_COOKIE: &str = "identity"; - -#[async_trait::async_trait] -impl FromRequestParts for IdentityToken -where - S: Send + Sync, -{ - type Rejection = >::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let cookies = CookieJar::from_request_parts(parts, state).await?; - Ok(Self { cookies }) - } -} - -impl IntoResponseParts for IdentityToken { - type Error = ::Error; - - fn into_response_parts(self, res: ResponseParts) -> Result { - let Self { cookies } = self; - cookies.into_response_parts(res) - } -} - -#[derive(sqlx::Type)] -#[sqlx(transparent)] -pub struct IdentitySecret(String); - -impl fmt::Debug for IdentitySecret { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("IdentityToken").field(&"********").finish() - } -} - -impl From for IdentitySecret -where - S: Into, -{ - fn from(value: S) -> Self { - Self(value.into()) - } -} - -#[derive(Clone, Debug)] -pub struct Identity { - pub token: token::Id, - pub login: Login, -} +use super::Login; +use crate::{app::App, token::extract::Identity}; #[async_trait::async_trait] -impl FromRequestParts for Identity { - type Rejection = LoginError; +impl FromRequestParts for Login { + type Rejection = >::Rejection; async fn from_request_parts(parts: &mut Parts, state: &App) -> Result { - // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on - // stable), the following can be replaced: - // - // ``` - // let Ok(identity_token) = IdentityToken::from_request_parts( - // parts, - // state, - // ).await; - // ``` - let identity_token = IdentityToken::from_request_parts(parts, state).await?; - let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; - - let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; - - let app = State::::from_request_parts(parts, state).await?; - match app.logins().validate(&secret, &used_at).await { - Ok((token, login)) => Ok(Identity { token, login }), - Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), - Err(other) => Err(other.into()), - } - } -} - -pub enum LoginError { - Failure(E), - Unauthorized, -} - -impl IntoResponse for LoginError -where - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Self::Unauthorized => Unauthorized.into_response(), - Self::Failure(e) => e.into_response(), - } - } -} + let identity = Identity::from_request_parts(parts, state).await?; -impl From for LoginError -where - E: Into, -{ - fn from(err: E) -> Self { - Self::Failure(err.into()) + Ok(identity.login) } } diff --git a/src/login/mod.rs b/src/login/mod.rs index 0430f4b..91c1821 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -2,10 +2,22 @@ pub mod app; pub mod broadcaster; pub mod extract; mod id; +pub mod password; mod repo; mod routes; -pub mod token; pub mod types; -pub use self::id::Id; -pub use self::routes::router; +pub use self::{id::Id, password::Password, routes::router}; + +// This also implements FromRequestParts (see `./extract.rs`). As a result, it +// can be used as an extractor for endpoints that want to require login, or for +// endpoints that need to behave differently depending on whether the client is +// or is not logged in. 
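+//
+// A minimal usage sketch (hypothetical handler, not part of this change):
+//
+//     async fn whoami(login: Login) -> axum::Json<Login> {
+//         axum::Json(login)
+//     }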
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Login { + pub id: Id, + pub name: String, + // The omission of the hashed password is deliberate, to minimize the + // chance that it ends up tangled up in debug output or in some other chunk + // of logic elsewhere. +} diff --git a/src/login/password.rs b/src/login/password.rs new file mode 100644 index 0000000..da3930f --- /dev/null +++ b/src/login/password.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use argon2::Argon2; +use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; +use rand_core::OsRng; + +#[derive(Debug, sqlx::Type)] +#[sqlx(transparent)] +pub struct StoredHash(String); + +impl StoredHash { + pub fn verify(&self, password: &Password) -> Result { + let hash = PasswordHash::new(&self.0)?; + + match Argon2::default().verify_password(password.as_bytes(), &hash) { + // Successful authentication, not an error + Ok(()) => Ok(true), + // Unsuccessful authentication, also not an error + Err(password_hash::errors::Error::Password) => Ok(false), + // Password validation failed for some other reason, treat as an error + Err(err) => Err(err), + } + } +} + +#[derive(serde::Deserialize)] +#[serde(transparent)] +pub struct Password(String); + +impl Password { + pub fn hash(&self) -> Result { + let Self(password) = self; + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let hash = argon2 + .hash_password(password.as_bytes(), &salt)? + .to_string(); + Ok(StoredHash(hash)) + } + + fn as_bytes(&self) -> &[u8] { + let Self(value) = self; + value.as_bytes() + } +} + +impl fmt::Debug for Password { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Password").field(&"********").finish() + } +} + +#[cfg(test)] +impl From for Password { + fn from(password: String) -> Self { + Self(password) + } +} diff --git a/src/login/repo/auth.rs b/src/login/repo/auth.rs index 9816c5c..b299697 100644 --- a/src/login/repo/auth.rs +++ b/src/login/repo/auth.rs @@ -1,6 +1,6 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::{login, password::StoredHash, repo::login::Login}; +use crate::login::{self, password::StoredHash, Login}; pub trait Provider { fn auth(&mut self) -> Auth; diff --git a/src/login/routes.rs b/src/login/routes.rs index ef75871..b571bd5 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -10,11 +10,11 @@ use crate::{ app::App, clock::RequestedAt, error::{Internal, Unauthorized}, - password::Password, - repo::login::Login, + login::{Login, Password}, }; -use super::{app, extract::IdentityToken}; +use super::app; +use crate::token::extract::IdentityToken; #[cfg(test)] mod test; diff --git a/src/login/token/id.rs b/src/login/token/id.rs deleted file mode 100644 index 9ef063c..0000000 --- a/src/login/token/id.rs +++ /dev/null @@ -1,27 +0,0 @@ -use std::fmt; - -use crate::id::Id as BaseId; - -// Stable identifier for a token. Prefixed with `T`. 
-#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("T") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/login/token/mod.rs b/src/login/token/mod.rs deleted file mode 100644 index d563a88..0000000 --- a/src/login/token/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod id; - -pub use self::id::Id; diff --git a/src/login/types.rs b/src/login/types.rs index a210977..d53d436 100644 --- a/src/login/types.rs +++ b/src/login/types.rs @@ -1,4 +1,4 @@ -use crate::login::token; +use crate::token; #[derive(Clone, Debug)] pub struct TokenRevoked { diff --git a/src/message/mod.rs b/src/message/mod.rs index d563a88..9a9bf14 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,3 +1,9 @@ mod id; pub use self::id::Id; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Message { + pub id: Id, + pub body: String, +} diff --git a/src/password.rs b/src/password.rs deleted file mode 100644 index da3930f..0000000 --- a/src/password.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fmt; - -use argon2::Argon2; -use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; -use rand_core::OsRng; - -#[derive(Debug, sqlx::Type)] -#[sqlx(transparent)] -pub struct StoredHash(String); - -impl StoredHash { - pub fn verify(&self, password: &Password) -> Result { - let hash = PasswordHash::new(&self.0)?; - - match Argon2::default().verify_password(password.as_bytes(), &hash) { - // Successful authentication, not an error - Ok(()) => Ok(true), - // Unsuccessful authentication, also not an error - Err(password_hash::errors::Error::Password) => Ok(false), - // Password validation failed for some other reason, treat as an error - Err(err) => Err(err), - } - } -} - -#[derive(serde::Deserialize)] -#[serde(transparent)] -pub struct Password(String); - -impl Password { - pub fn hash(&self) -> Result { - let Self(password) = self; - let salt = SaltString::generate(&mut OsRng); - let argon2 = Argon2::default(); - let hash = argon2 - .hash_password(password.as_bytes(), &salt)? 
- .to_string(); - Ok(StoredHash(hash)) - } - - fn as_bytes(&self) -> &[u8] { - let Self(value) = self; - value.as_bytes() - } -} - -impl fmt::Debug for Password { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Password").field(&"********").finish() - } -} - -#[cfg(test)] -impl From for Password { - fn from(password: String) -> Self { - Self(password) - } -} diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 9f1d930..18cd81f 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -1,10 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use super::sequence::Sequence; use crate::{ - channel::Id, + channel::{Channel, Id}, clock::DateTime, - events::types::{self}, + event::{types, Sequence}, }; pub trait Provider { @@ -19,16 +18,6 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Channel { - pub id: Id, - pub name: String, - #[serde(skip)] - pub created_at: DateTime, - #[serde(skip)] - pub created_sequence: Sequence, -} - impl<'c> Channels<'c> { pub async fn create( &mut self, diff --git a/src/repo/login.rs b/src/repo/login.rs new file mode 100644 index 0000000..d1a02c4 --- /dev/null +++ b/src/repo/login.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::login::{password::StoredHash, Id, Login}; + +pub trait Provider { + fn logins(&mut self) -> Logins; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn logins(&mut self) -> Logins { + Logins(self) + } +} + +pub struct Logins<'t>(&'t mut SqliteConnection); + +impl<'c> Logins<'c> { + pub async fn create( + &mut self, + name: &str, + password_hash: &StoredHash, + ) -> Result { + let id = Id::generate(); + + let login = sqlx::query_as!( + Login, + r#" + insert or fail + into login (id, name, password_hash) + values ($1, $2, $3) + returning + id as "id: Id", + name + "#, + id, + name, + password_hash, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } +} + +impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { + fn from(tx: &'t mut SqliteConnection) -> Self { + Self(tx) + } +} diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs deleted file mode 100644 index ab61106..0000000 --- a/src/repo/login/extract.rs +++ /dev/null @@ -1,15 +0,0 @@ -use axum::{extract::FromRequestParts, http::request::Parts}; - -use super::Login; -use crate::{app::App, login::extract::Identity}; - -#[async_trait::async_trait] -impl FromRequestParts for Login { - type Rejection = >::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &App) -> Result { - let identity = Identity::from_request_parts(parts, state).await?; - - Ok(identity.login) - } -} diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs deleted file mode 100644 index 4ff7a96..0000000 --- a/src/repo/login/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod extract; -mod store; - -pub use self::store::{Login, Provider}; diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs deleted file mode 100644 index 47d1a7c..0000000 --- a/src/repo/login/store.rs +++ /dev/null @@ -1,63 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{login::Id, password::StoredHash}; - -pub trait Provider { - fn logins(&mut self) -> Logins; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn logins(&mut self) -> Logins { - Logins(self) - } -} - -pub struct Logins<'t>(&'t mut SqliteConnection); - -// This also implements 
FromRequestParts (see `./extract.rs`). As a result, it -// can be used as an extractor for endpoints that want to require login, or for -// endpoints that need to behave differently depending on whether the client is -// or is not logged in. -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Login { - pub id: Id, - pub name: String, - // The omission of the hashed password is deliberate, to minimize the - // chance that it ends up tangled up in debug output or in some other chunk - // of logic elsewhere. -} - -impl<'c> Logins<'c> { - pub async fn create( - &mut self, - name: &str, - password_hash: &StoredHash, - ) -> Result { - let id = Id::generate(); - - let login = sqlx::query_as!( - Login, - r#" - insert or fail - into login (id, name, password_hash) - values ($1, $2, $3) - returning - id as "id: Id", - name - "#, - id, - name, - password_hash, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(login) - } -} - -impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { - fn from(tx: &'t mut SqliteConnection) -> Self { - Self(tx) - } -} diff --git a/src/repo/message.rs b/src/repo/message.rs deleted file mode 100644 index acde3ea..0000000 --- a/src/repo/message.rs +++ /dev/null @@ -1,7 +0,0 @@ -use crate::message::Id; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: Id, - pub body: String, -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs index 8f271f4..69ad82c 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -1,7 +1,6 @@ pub mod channel; pub mod error; pub mod login; -pub mod message; pub mod pool; pub mod sequence; pub mod token; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs index c47b41c..c985869 100644 --- a/src/repo/sequence.rs +++ b/src/repo/sequence.rs @@ -1,7 +1,7 @@ -use std::fmt; - use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; +use crate::event::Sequence; + pub trait Provider { fn sequence(&mut self) -> Sequences; } @@ -42,26 +42,3 @@ impl<'c> Sequences<'c> { Ok(next) } } - -#[derive( - Clone, - Copy, - Debug, - Eq, - Ord, - PartialEq, - PartialOrd, - serde::Deserialize, - serde::Serialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); - -impl fmt::Display for Sequence { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self(value) = self; - value.fmt(f) - } -} diff --git a/src/repo/token.rs b/src/repo/token.rs index 79e5c54..5f64dac 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -1,10 +1,10 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; -use super::login::Login; use crate::{ clock::DateTime, - login::{self, extract::IdentitySecret, token::Id}, + login::{self, Login}, + token::{Id, Secret}, }; pub trait Provider { @@ -26,7 +26,7 @@ impl<'c> Tokens<'c> { &mut self, login: &Login, issued_at: &DateTime, - ) -> Result { + ) -> Result { let id = Id::generate(); let secret = Uuid::new_v4().to_string(); @@ -35,7 +35,7 @@ impl<'c> Tokens<'c> { insert into token (id, secret, login, issued_at, last_used_at) values ($1, $2, $3, $4, $4) - returning secret as "secret!: IdentitySecret" + returning secret as "secret!: Secret" "#, id, secret, @@ -103,7 +103,7 @@ impl<'c> Tokens<'c> { // timestamp will be set to `used_at`. 
     pub async fn validate(
         &mut self,
-        secret: &IdentitySecret,
+        secret: &Secret,
         used_at: &DateTime,
     ) -> Result<(Id, Login), sqlx::Error> {
         // I would use `update … returning` to do this in one query, but
diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs
index 8744470..b678717 100644
--- a/src/test/fixtures/channel.rs
+++ b/src/test/fixtures/channel.rs
@@ -4,7 +4,7 @@ use faker_rand::{
 };
 use rand;
 
-use crate::{app::App, clock::RequestedAt, repo::channel::Channel};
+use crate::{app::App, channel::Channel, clock::RequestedAt};
 
 pub async fn create(app: &App, created_at: &RequestedAt) -> Channel {
     let name = propose();
diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs
index c31fa58..d1939a5 100644
--- a/src/test/fixtures/filter.rs
+++ b/src/test/fixtures/filter.rs
@@ -1,6 +1,6 @@
 use futures::future;
 
-use crate::events::types;
+use crate::event::types;
 
 pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready<bool> {
     |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_)))
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 633fb8a..9e8e403 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -3,8 +3,11 @@ use uuid::Uuid;
 use crate::{
     app::App,
     clock::RequestedAt,
-    login::extract::{Identity, IdentitySecret, IdentityToken},
-    password::Password,
+    login::Password,
+    token::{
+        extract::{Identity, IdentityToken},
+        Secret,
+    },
 };
 
 pub fn not_logged_in() -> IdentityToken {
@@ -38,7 +41,7 @@ pub async fn identity(app: &App, login: &(String, Password), issued_at: &Request
     from_token(app, &secret, issued_at).await
 }
 
-pub fn secret(identity: &IdentityToken) -> IdentitySecret {
+pub fn secret(identity: &IdentityToken) -> Secret {
     identity.secret().expect("identity contained a secret")
 }
diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs
index d6a321b..00c2789 100644
--- a/src/test/fixtures/login.rs
+++ b/src/test/fixtures/login.rs
@@ -3,8 +3,7 @@ use uuid::Uuid;
 
 use crate::{
     app::App,
-    password::Password,
-    repo::login::{self, Login},
+    login::{self, Login, Password},
 };
 
 pub async fn create_with_password(app: &App) -> (String, Password) {
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index bfca8cd..fd50887 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -1,11 +1,6 @@
 use faker_rand::lorem::Paragraphs;
 
-use crate::{
-    app::App,
-    clock::RequestedAt,
-    events::types,
-    repo::{channel::Channel, login::Login},
-};
+use crate::{app::App, channel::Channel, clock::RequestedAt, event::types, login::Login};
 
 pub async fn send(
     app: &App,
diff --git a/src/token/extract/identity.rs b/src/token/extract/identity.rs
new file mode 100644
index 0000000..42c7c60
--- /dev/null
+++ b/src/token/extract/identity.rs
@@ -0,0 +1,75 @@
+use axum::{
+    extract::{FromRequestParts, State},
+    http::request::Parts,
+    response::{IntoResponse, Response},
+};
+
+use super::IdentityToken;
+
+use crate::{
+    app::App,
+    clock::RequestedAt,
+    error::{Internal, Unauthorized},
+    login::{app::ValidateError, Login},
+    token,
+};
+
+#[derive(Clone, Debug)]
+pub struct Identity {
+    pub token: token::Id,
+    pub login: Login,
+}
+
+#[async_trait::async_trait]
+impl FromRequestParts<App> for Identity {
+    type Rejection = LoginError<Internal>;
+
+    async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
+        // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
+        // stable), the call below can be replaced with:
+        //
+        // ```
+        // let Ok(identity_token) = IdentityToken::from_request_parts(
+        //     parts,
+        //     state,
+        // ).await;
+        // ```
+        let identity_token = IdentityToken::from_request_parts(parts, state).await?;
+        let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+
+        let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
+
+        let app = State::<App>::from_request_parts(parts, state).await?;
+        match app.logins().validate(&secret, &used_at).await {
+            Ok((token, login)) => Ok(Identity { token, login }),
+            Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
+            Err(other) => Err(other.into()),
+        }
+    }
+}
+
+pub enum LoginError<E> {
+    Failure(E),
+    Unauthorized,
+}
+
+impl<E> IntoResponse for LoginError<E>
+where
+    E: IntoResponse,
+{
+    fn into_response(self) -> Response {
+        match self {
+            Self::Unauthorized => Unauthorized.into_response(),
+            Self::Failure(e) => e.into_response(),
+        }
+    }
+}
+
+impl<E> From<E> for LoginError<Internal>
+where
+    E: Into<Internal>,
+{
+    fn from(err: E) -> Self {
+        Self::Failure(err.into())
+    }
+}
diff --git a/src/token/extract/identity_token.rs b/src/token/extract/identity_token.rs
new file mode 100644
index 0000000..0a47a43
--- /dev/null
+++ b/src/token/extract/identity_token.rs
@@ -0,0 +1,94 @@
+use std::fmt;
+
+use axum::{
+    extract::FromRequestParts,
+    http::request::Parts,
+    response::{IntoResponseParts, ResponseParts},
+};
+use axum_extra::extract::cookie::{Cookie, CookieJar};
+
+use crate::token::Secret;
+
+// The usage pattern here - receive the extractor as an argument, return it in
+// the response - is heavily modelled after CookieJar's own intended usage.
+#[derive(Clone)]
+pub struct IdentityToken {
+    cookies: CookieJar,
+}
+
+impl fmt::Debug for IdentityToken {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("IdentityToken")
+            .field("identity", &self.secret())
+            .finish()
+    }
+}
+
+impl IdentityToken {
+    // Creates a new, unpopulated identity token store.
+    #[cfg(test)]
+    pub fn new() -> Self {
+        Self {
+            cookies: CookieJar::new(),
+        }
+    }
+
+    // Get the identity secret sent in the request, if any. If the identity
+    // was not sent, or if it has previously been [clear]ed, then this will
+    // return [None]. If the identity has previously been [set], then this
+    // will return that secret, regardless of what the request originally
+    // included.
+    pub fn secret(&self) -> Option<Secret> {
+        self.cookies
+            .get(IDENTITY_COOKIE)
+            .map(Cookie::value)
+            .map(Secret::from)
+    }
+
+    // Positively set the identity secret, and ensure that it will be sent
+    // back to the client when this extractor is included in a response.
+    pub fn set(self, secret: impl Into<Secret>) -> Self {
+        let secret = secret.into().reveal();
+        let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret))
+            .http_only(true)
+            .path("/api/")
+            .permanent()
+            .build();
+
+        Self {
+            cookies: self.cookies.add(identity_cookie),
+        }
+    }
+
+    // Remove the identity secret and ensure that it will be cleared when this
+    // extractor is included in a response.
+    pub fn clear(self) -> Self {
+        Self {
+            cookies: self.cookies.remove(IDENTITY_COOKIE),
+        }
+    }
+}
+
+const IDENTITY_COOKIE: &str = "identity";
+
+#[async_trait::async_trait]
+impl<S> FromRequestParts<S> for IdentityToken
+where
+    S: Send + Sync,
+{
+    type Rejection = <CookieJar as FromRequestParts<S>>::Rejection;
+
+    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+        let cookies = CookieJar::from_request_parts(parts, state).await?;
+        Ok(Self { cookies })
+    }
+}
+
+impl IntoResponseParts for IdentityToken {
+    type Error = <CookieJar as IntoResponseParts>::Error;
+
+    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
+        let Self { cookies } = self;
+        cookies.into_response_parts(res)
+    }
+}
diff --git a/src/token/extract/mod.rs b/src/token/extract/mod.rs
new file mode 100644
index 0000000..b4800ae
--- /dev/null
+++ b/src/token/extract/mod.rs
@@ -0,0 +1,4 @@
+mod identity;
+mod identity_token;
+
+pub use self::{identity::Identity, identity_token::IdentityToken};
diff --git a/src/token/id.rs b/src/token/id.rs
new file mode 100644
index 0000000..9ef063c
--- /dev/null
+++ b/src/token/id.rs
@@ -0,0 +1,27 @@
+use std::fmt;
+
+use crate::id::Id as BaseId;
+
+// Stable identifier for a token. Prefixed with `T`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+    fn from(id: BaseId) -> Self {
+        Self(id)
+    }
+}
+
+impl Id {
+    pub fn generate() -> Self {
+        BaseId::generate("T")
+    }
+}
+
+impl fmt::Display for Id {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
diff --git a/src/token/mod.rs b/src/token/mod.rs
new file mode 100644
index 0000000..c98b8c2
--- /dev/null
+++ b/src/token/mod.rs
@@ -0,0 +1,5 @@
+pub mod extract;
+mod id;
+mod secret;
+
+pub use self::{id::Id, secret::Secret};
diff --git a/src/token/secret.rs b/src/token/secret.rs
new file mode 100644
index 0000000..28c93bb
--- /dev/null
+++ b/src/token/secret.rs
@@ -0,0 +1,27 @@
+use std::fmt;
+
+#[derive(sqlx::Type)]
+#[sqlx(transparent)]
+pub struct Secret(String);
+
+impl Secret {
+    pub fn reveal(self) -> String {
+        let Self(secret) = self;
+        secret
+    }
+}
+
+impl fmt::Debug for Secret {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("IdentityToken").field(&"********").finish()
+    }
+}
+
+impl<S> From<S> for Secret
+where
+    S: Into<String>,
+{
+    fn from(value: S) -> Self {
+        Self(value.into())
+    }
+}
--
cgit v1.2.3


From 5d3392799f88c5a3d3f9c656c73d6e8ac5c4d793 Mon Sep 17 00:00:00 2001
From: Owen Jacobson
Date: Wed, 2 Oct 2024 01:02:58 -0400
Subject: Split login and token handling.
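
Logins now cover only the name-and-password-hash bookkeeping. Issuing,
validating, revoking, and broadcasting revocation of tokens moves into a
new `token` module, behind its own `Tokens` facade on `App`. Call sites
that deal in secrets or revocation switch from `app.logins()` to
`app.tokens()`.

A minimal before/after sketch of the call-site change, using the
signatures from this patch:

    // before: token operations lived behind the login facade
    let (token, _) = app.logins().validate(&secret, &now).await?;
    app.logins().logout(&token).await?;

    // after: the token facade owns validation and revocation
    let (token, _) = app.tokens().validate(&secret, &now).await?;
    app.tokens().logout(&token).await?;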
--- src/app.rs | 21 +++-- src/event/routes.rs | 5 +- src/event/routes/test.rs | 4 +- src/expire.rs | 2 +- src/login/app.rs | 169 ++------------------------------------- src/login/broadcaster.rs | 3 - src/login/mod.rs | 4 +- src/login/repo.rs | 50 ++++++++++++ src/login/repo/auth.rs | 50 ------------ src/login/repo/mod.rs | 1 - src/login/routes.rs | 10 +-- src/login/routes/test/login.rs | 13 ++- src/login/routes/test/logout.rs | 7 +- src/login/types.rs | 12 --- src/repo/login.rs | 50 ------------ src/repo/mod.rs | 2 - src/repo/token.rs | 151 ----------------------------------- src/test/fixtures/identity.rs | 4 +- src/token/app.rs | 170 ++++++++++++++++++++++++++++++++++++++++ src/token/broadcaster.rs | 4 + src/token/event.rs | 12 +++ src/token/extract/identity.rs | 6 +- src/token/mod.rs | 4 + src/token/repo/auth.rs | 50 ++++++++++++ src/token/repo/mod.rs | 4 + src/token/repo/token.rs | 151 +++++++++++++++++++++++++++++++++++ 26 files changed, 486 insertions(+), 473 deletions(-) delete mode 100644 src/login/broadcaster.rs create mode 100644 src/login/repo.rs delete mode 100644 src/login/repo/auth.rs delete mode 100644 src/login/repo/mod.rs delete mode 100644 src/login/types.rs delete mode 100644 src/repo/login.rs delete mode 100644 src/repo/token.rs create mode 100644 src/token/app.rs create mode 100644 src/token/broadcaster.rs create mode 100644 src/token/event.rs create mode 100644 src/token/repo/auth.rs create mode 100644 src/token/repo/mod.rs create mode 100644 src/token/repo/token.rs diff --git a/src/app.rs b/src/app.rs index 84a6357..5542e5f 100644 --- a/src/app.rs +++ b/src/app.rs @@ -3,34 +3,39 @@ use sqlx::sqlite::SqlitePool; use crate::{ channel::app::Channels, event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, - login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster}, + login::app::Logins, + token::{app::Tokens, broadcaster::Broadcaster as TokenBroadcaster}, }; #[derive(Clone)] pub struct App { db: SqlitePool, events: EventBroadcaster, - logins: LoginBroadcaster, + tokens: TokenBroadcaster, } impl App { pub fn from(db: SqlitePool) -> Self { let events = EventBroadcaster::default(); - let logins = LoginBroadcaster::default(); - Self { db, events, logins } + let tokens = TokenBroadcaster::default(); + Self { db, events, tokens } } } impl App { - pub const fn logins(&self) -> Logins { - Logins::new(&self.db, &self.logins) + pub const fn channels(&self) -> Channels { + Channels::new(&self.db, &self.events) } pub const fn events(&self) -> Events { Events::new(&self.db, &self.events) } - pub const fn channels(&self) -> Channels { - Channels::new(&self.db, &self.events) + pub const fn logins(&self) -> Logins { + Logins::new(&self.db) + } + + pub const fn tokens(&self) -> Tokens { + Tokens::new(&self.db, &self.tokens) } } diff --git a/src/event/routes.rs b/src/event/routes.rs index 77761ca..50ac435 100644 --- a/src/event/routes.rs +++ b/src/event/routes.rs @@ -15,8 +15,7 @@ use crate::{ app::App, error::{Internal, Unauthorized}, event::Sequence, - login::app::ValidateError, - token::extract::Identity, + token::{app::ValidateError, extract::Identity}, }; #[cfg(test)] @@ -42,7 +41,7 @@ async fn events( .or(query.resume_point); let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream).await?; + let stream = app.tokens().limit_stream(identity.token, stream).await?; Ok(Events(stream)) } diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs index 9a3b12a..d1ac3b4 100644 --- 
a/src/event/routes/test.rs +++ b/src/event/routes/test.rs @@ -371,7 +371,7 @@ async fn terminates_on_token_expiry() { // Verify the resulting stream's behaviour - app.logins() + app.tokens() .expire(&fixtures::now()) .await .expect("expiring tokens succeeds"); @@ -418,7 +418,7 @@ async fn terminates_on_logout() { // Verify the resulting stream's behaviour - app.logins() + app.tokens() .logout(&subscriber.token) .await .expect("expiring tokens succeeds"); diff --git a/src/expire.rs b/src/expire.rs index 16006d1..a8eb8ad 100644 --- a/src/expire.rs +++ b/src/expire.rs @@ -13,7 +13,7 @@ pub async fn middleware( req: Request, next: Next, ) -> Result { - app.logins().expire(&expired_at).await?; + app.tokens().expire(&expired_at).await?; app.events().expire(&expired_at).await?; app.channels().expire(&expired_at).await?; Ok(next.run(req).await) diff --git a/src/login/app.rs b/src/login/app.rs index 60475af..69c1055 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,30 +1,17 @@ -use chrono::TimeDelta; -use futures::{ - future, - stream::{self, StreamExt as _}, - Stream, -}; use sqlx::sqlite::SqlitePool; -use super::{broadcaster::Broadcaster, repo::auth::Provider as _, types, Login}; -use crate::{ - clock::DateTime, - event::Sequence, - login::Password, - repo::{ - error::NotFound as _, login::Provider as _, sequence::Provider as _, token::Provider as _, - }, - token::{self, Secret}, -}; +use crate::{event::Sequence, repo::sequence::Provider as _}; + +#[cfg(test)] +use super::{repo::Provider as _, Login, Password}; pub struct Logins<'a> { db: &'a SqlitePool, - logins: &'a Broadcaster, } impl<'a> Logins<'a> { - pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self { - Self { db, logins } + pub const fn new(db: &'a SqlitePool) -> Self { + Self { db } } pub async fn boot_point(&self) -> Result { @@ -35,33 +22,6 @@ impl<'a> Logins<'a> { Ok(sequence) } - pub async fn login( - &self, - name: &str, - password: &Password, - login_at: &DateTime, - ) -> Result { - let mut tx = self.db.begin().await?; - - let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { - if stored_hash.verify(password)? { - // Password verified; use the login. - login - } else { - // Password NOT verified. - return Err(LoginError::Rejected); - } - } else { - let password_hash = password.hash()?; - tx.logins().create(name, &password_hash).await? - }; - - let token = tx.tokens().issue(&login, login_at).await?; - tx.commit().await?; - - Ok(token) - } - #[cfg(test)] pub async fn create(&self, name: &str, password: &Password) -> Result { let password_hash = password.hash()?; @@ -72,109 +32,6 @@ impl<'a> Logins<'a> { Ok(login) } - - pub async fn validate( - &self, - secret: &Secret, - used_at: &DateTime, - ) -> Result<(token::Id, Login), ValidateError> { - let mut tx = self.db.begin().await?; - let login = tx - .tokens() - .validate(secret, used_at) - .await - .not_found(|| ValidateError::InvalidToken)?; - tx.commit().await?; - - Ok(login) - } - - pub async fn limit_stream( - &self, - token: token::Id, - events: impl Stream + std::fmt::Debug, - ) -> Result + std::fmt::Debug, ValidateError> - where - E: std::fmt::Debug, - { - // Subscribe, first. - let token_events = self.logins.subscribe(); - - // Check that the token is valid at this point in time, second. If it is, then - // any future revocations will appear in the subscription. If not, bail now. 
- // - // It's possible, otherwise, to get to this point with a token that _was_ valid - // at the start of the request, but which was invalided _before_ the - // `subscribe()` call. In that case, the corresponding revocation event will - // simply be missed, since the `token_events` stream subscribed after the fact. - // This check cancels guarding the stream here. - // - // Yes, this is a weird niche edge case. Most things don't double-check, because - // they aren't expected to run long enough for the token's revocation to - // matter. Supervising a stream, on the other hand, will run for a - // _long_ time; if we miss the race here, we'll never actually carry out the - // supervision. - let mut tx = self.db.begin().await?; - tx.tokens() - .require(&token) - .await - .not_found(|| ValidateError::InvalidToken)?; - tx.commit().await?; - - // Then construct the guarded stream. First, project both streams into - // `GuardedEvent`. - let token_events = token_events - .filter(move |event| future::ready(event.token == token)) - .map(|_| GuardedEvent::TokenRevoked); - let events = events.map(|event| GuardedEvent::Event(event)); - - // Merge the two streams, then unproject them, stopping at - // `GuardedEvent::TokenRevoked`. - let stream = stream::select(token_events, events).scan((), |(), event| { - future::ready(match event { - GuardedEvent::Event(event) => Some(event), - GuardedEvent::TokenRevoked => None, - }) - }); - - Ok(stream) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 7 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(7); - - let mut tx = self.db.begin().await?; - let tokens = tx.tokens().expire(&expire_at).await?; - tx.commit().await?; - - for event in tokens.into_iter().map(types::TokenRevoked::from) { - self.logins.broadcast(&event); - } - - Ok(()) - } - - pub async fn logout(&self, token: &token::Id) -> Result<(), ValidateError> { - let mut tx = self.db.begin().await?; - tx.tokens().revoke(token).await?; - tx.commit().await?; - - self.logins - .broadcast(&types::TokenRevoked::from(token.clone())); - - Ok(()) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum LoginError { - #[error("invalid login")] - Rejected, - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), - #[error(transparent)] - PasswordHashError(#[from] password_hash::Error), } #[cfg(test)] @@ -184,17 +41,3 @@ pub enum CreateError { DatabaseError(#[from] sqlx::Error), PasswordHashError(#[from] password_hash::Error), } - -#[derive(Debug, thiserror::Error)] -pub enum ValidateError { - #[error("invalid token")] - InvalidToken, - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} - -#[derive(Debug)] -enum GuardedEvent { - TokenRevoked, - Event(E), -} diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs deleted file mode 100644 index 8e1fb3a..0000000 --- a/src/login/broadcaster.rs +++ /dev/null @@ -1,3 +0,0 @@ -use crate::{broadcast, login::types}; - -pub type Broadcaster = broadcast::Broadcaster; diff --git a/src/login/mod.rs b/src/login/mod.rs index 91c1821..65e3ada 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,11 +1,9 @@ pub mod app; -pub mod broadcaster; pub mod extract; mod id; pub mod password; -mod repo; +pub mod repo; mod routes; -pub mod types; pub use self::{id::Id, password::Password, routes::router}; diff --git a/src/login/repo.rs b/src/login/repo.rs new file mode 100644 index 0000000..d1a02c4 --- /dev/null +++ b/src/login/repo.rs @@ -0,0 +1,50 @@ +use 
sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::login::{password::StoredHash, Id, Login}; + +pub trait Provider { + fn logins(&mut self) -> Logins; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn logins(&mut self) -> Logins { + Logins(self) + } +} + +pub struct Logins<'t>(&'t mut SqliteConnection); + +impl<'c> Logins<'c> { + pub async fn create( + &mut self, + name: &str, + password_hash: &StoredHash, + ) -> Result { + let id = Id::generate(); + + let login = sqlx::query_as!( + Login, + r#" + insert or fail + into login (id, name, password_hash) + values ($1, $2, $3) + returning + id as "id: Id", + name + "#, + id, + name, + password_hash, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } +} + +impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { + fn from(tx: &'t mut SqliteConnection) -> Self { + Self(tx) + } +} diff --git a/src/login/repo/auth.rs b/src/login/repo/auth.rs deleted file mode 100644 index b299697..0000000 --- a/src/login/repo/auth.rs +++ /dev/null @@ -1,50 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::login::{self, password::StoredHash, Login}; - -pub trait Provider { - fn auth(&mut self) -> Auth; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn auth(&mut self) -> Auth { - Auth(self) - } -} - -pub struct Auth<'t>(&'t mut SqliteConnection); - -impl<'t> Auth<'t> { - // Retrieves a login by name, plus its stored password hash for - // verification. If there's no login with the requested name, this will - // return [None]. - pub async fn for_name( - &mut self, - name: &str, - ) -> Result, sqlx::Error> { - let found = sqlx::query!( - r#" - select - id as "id: login::Id", - name, - password_hash as "password_hash: StoredHash" - from login - where name = $1 - "#, - name, - ) - .map(|rec| { - ( - Login { - id: rec.id, - name: rec.name, - }, - rec.password_hash, - ) - }) - .fetch_optional(&mut *self.0) - .await?; - - Ok(found) - } -} diff --git a/src/login/repo/mod.rs b/src/login/repo/mod.rs deleted file mode 100644 index 0e4a05d..0000000 --- a/src/login/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod auth; diff --git a/src/login/routes.rs b/src/login/routes.rs index b571bd5..0874cc3 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -11,11 +11,9 @@ use crate::{ clock::RequestedAt, error::{Internal, Unauthorized}, login::{Login, Password}, + token::{app, extract::IdentityToken}, }; -use super::app; -use crate::token::extract::IdentityToken; - #[cfg(test)] mod test; @@ -59,7 +57,7 @@ async fn on_login( Json(request): Json, ) -> Result<(IdentityToken, StatusCode), LoginError> { let token = app - .logins() + .tokens() .login(&request.name, &request.password, &now) .await .map_err(LoginError)?; @@ -95,8 +93,8 @@ async fn on_logout( Json(LogoutRequest {}): Json, ) -> Result<(IdentityToken, StatusCode), LogoutError> { if let Some(secret) = identity.secret() { - let (token, _) = app.logins().validate(&secret, &now).await?; - app.logins().logout(&token).await?; + let (token, _) = app.tokens().validate(&secret, &now).await?; + app.tokens().logout(&token).await?; } let identity = identity.clear(); diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs index 81653ff..3c82738 100644 --- a/src/login/routes/test/login.rs +++ b/src/login/routes/test/login.rs @@ -3,10 +3,7 @@ use axum::{ http::StatusCode, }; -use crate::{ - login::{app, routes}, - test::fixtures, -}; +use crate::{login::routes, test::fixtures, token::app}; #[tokio::test] async fn new_identity() { @@ -37,7 
+34,7 @@ async fn new_identity() { let validated_at = fixtures::now(); let (_, validated) = app - .logins() + .tokens() .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -74,7 +71,7 @@ async fn existing_identity() { let validated_at = fixtures::now(); let (_, validated_login) = app - .logins() + .tokens() .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -127,14 +124,14 @@ async fn token_expires() { // Verify the semantics let expired_at = fixtures::now(); - app.logins() + app.tokens() .expire(&expired_at) .await .expect("expiring tokens never fails"); let verified_at = fixtures::now(); let error = app - .logins() + .tokens() .validate(&secret, &verified_at) .await .expect_err("validating an expired token"); diff --git a/src/login/routes/test/logout.rs b/src/login/routes/test/logout.rs index 20b0d55..42b2534 100644 --- a/src/login/routes/test/logout.rs +++ b/src/login/routes/test/logout.rs @@ -3,10 +3,7 @@ use axum::{ http::StatusCode, }; -use crate::{ - login::{app, routes}, - test::fixtures, -}; +use crate::{login::routes, test::fixtures, token::app}; #[tokio::test] async fn successful() { @@ -37,7 +34,7 @@ async fn successful() { // Verify the semantics let error = app - .logins() + .tokens() .validate(&secret, &now) .await .expect_err("secret is invalid"); diff --git a/src/login/types.rs b/src/login/types.rs deleted file mode 100644 index d53d436..0000000 --- a/src/login/types.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::token; - -#[derive(Clone, Debug)] -pub struct TokenRevoked { - pub token: token::Id, -} - -impl From for TokenRevoked { - fn from(token: token::Id) -> Self { - Self { token } - } -} diff --git a/src/repo/login.rs b/src/repo/login.rs deleted file mode 100644 index d1a02c4..0000000 --- a/src/repo/login.rs +++ /dev/null @@ -1,50 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::login::{password::StoredHash, Id, Login}; - -pub trait Provider { - fn logins(&mut self) -> Logins; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn logins(&mut self) -> Logins { - Logins(self) - } -} - -pub struct Logins<'t>(&'t mut SqliteConnection); - -impl<'c> Logins<'c> { - pub async fn create( - &mut self, - name: &str, - password_hash: &StoredHash, - ) -> Result { - let id = Id::generate(); - - let login = sqlx::query_as!( - Login, - r#" - insert or fail - into login (id, name, password_hash) - values ($1, $2, $3) - returning - id as "id: Id", - name - "#, - id, - name, - password_hash, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(login) - } -} - -impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { - fn from(tx: &'t mut SqliteConnection) -> Self { - Self(tx) - } -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs index 69ad82c..7abd46b 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -1,6 +1,4 @@ pub mod channel; pub mod error; -pub mod login; pub mod pool; pub mod sequence; -pub mod token; diff --git a/src/repo/token.rs b/src/repo/token.rs deleted file mode 100644 index 5f64dac..0000000 --- a/src/repo/token.rs +++ /dev/null @@ -1,151 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use uuid::Uuid; - -use crate::{ - clock::DateTime, - login::{self, Login}, - token::{Id, Secret}, -}; - -pub trait Provider { - fn tokens(&mut self) -> Tokens; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn tokens(&mut self) -> Tokens { - Tokens(self) - } -} - -pub struct Tokens<'t>(&'t mut SqliteConnection); - -impl<'c> Tokens<'c> { - // Issue a new token for an 
existing login. The issued_at timestamp will - // be used to control expiry, until the token is actually used. - pub async fn issue( - &mut self, - login: &Login, - issued_at: &DateTime, - ) -> Result { - let id = Id::generate(); - let secret = Uuid::new_v4().to_string(); - - let secret = sqlx::query_scalar!( - r#" - insert - into token (id, secret, login, issued_at, last_used_at) - values ($1, $2, $3, $4, $4) - returning secret as "secret!: Secret" - "#, - id, - secret, - login.id, - issued_at, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(secret) - } - - pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> { - sqlx::query_scalar!( - r#" - select id as "id: Id" - from token - where id = $1 - "#, - token, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(()) - } - - // Revoke a token by its secret. - pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> { - sqlx::query_scalar!( - r#" - delete - from token - where id = $1 - returning id as "id: Id" - "#, - token, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(()) - } - - // Expire and delete all tokens that haven't been used more recently than - // `expire_at`. - pub async fn expire(&mut self, expire_at: &DateTime) -> Result, sqlx::Error> { - let tokens = sqlx::query_scalar!( - r#" - delete - from token - where last_used_at < $1 - returning id as "id: Id" - "#, - expire_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(tokens) - } - - // Validate a token by its secret, retrieving the associated Login record. - // Will return [None] if the token is not valid. The token's last-used - // timestamp will be set to `used_at`. - pub async fn validate( - &mut self, - secret: &Secret, - used_at: &DateTime, - ) -> Result<(Id, Login), sqlx::Error> { - // I would use `update … returning` to do this in one query, but - // sqlite3, as of this writing, does not allow an update's `returning` - // clause to reference columns from tables joined into the update. Two - // queries is fine, but it feels untidy. 
-        sqlx::query!(
-            r#"
-            update token
-            set last_used_at = $1
-            where secret = $2
-            "#,
-            used_at,
-            secret,
-        )
-        .execute(&mut *self.0)
-        .await?;
-
-        let login = sqlx::query!(
-            r#"
-            select
-                token.id as "token_id: Id",
-                login.id as "login_id: login::Id",
-                name as "login_name"
-            from login
-            join token on login.id = token.login
-            where token.secret = $1
-            "#,
-            secret,
-        )
-        .map(|row| {
-            (
-                row.token_id,
-                Login {
-                    id: row.login_id,
-                    name: row.login_name,
-                },
-            )
-        })
-        .fetch_one(&mut *self.0)
-        .await?;
-
-        Ok(login)
-    }
-}
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 9e8e403..56b4ffa 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -17,7 +17,7 @@ pub fn not_logged_in() -> IdentityToken {
 pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) -> IdentityToken {
     let (name, password) = login;
     let token = app
-        .logins()
+        .tokens()
        .login(name, password, now)
         .await
         .expect("should succeed given known-valid credentials");
@@ -28,7 +28,7 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt)
 pub async fn from_token(app: &App, token: &IdentityToken, issued_at: &RequestedAt) -> Identity {
     let secret = token.secret().expect("identity token has a secret");
     let (token, login) = app
-        .logins()
+        .tokens()
         .validate(&secret, issued_at)
         .await
         .expect("always validates newly-issued secret");
diff --git a/src/token/app.rs b/src/token/app.rs
new file mode 100644
index 0000000..1477a9f
--- /dev/null
+++ b/src/token/app.rs
@@ -0,0 +1,170 @@
+use chrono::TimeDelta;
+use futures::{
+    future,
+    stream::{self, StreamExt as _},
+    Stream,
+};
+use sqlx::sqlite::SqlitePool;
+
+use super::{
+    broadcaster::Broadcaster, event, repo::auth::Provider as _, repo::Provider as _, Id, Secret,
+};
+use crate::{
+    clock::DateTime,
+    login::{repo::Provider as _, Login, Password},
+    repo::error::NotFound as _,
+};
+
+pub struct Tokens<'a> {
+    db: &'a SqlitePool,
+    tokens: &'a Broadcaster,
+}
+
+impl<'a> Tokens<'a> {
+    pub const fn new(db: &'a SqlitePool, tokens: &'a Broadcaster) -> Self {
+        Self { db, tokens }
+    }
+    pub async fn login(
+        &self,
+        name: &str,
+        password: &Password,
+        login_at: &DateTime,
+    ) -> Result<Secret, LoginError> {
+        let mut tx = self.db.begin().await?;
+
+        let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? {
+            if stored_hash.verify(password)? {
+                // Password verified; use the login.
+                login
+            } else {
+                // Password NOT verified.
+                return Err(LoginError::Rejected);
+            }
+        } else {
+            let password_hash = password.hash()?;
+            tx.logins().create(name, &password_hash).await?
+        };
+
+        let token = tx.tokens().issue(&login, login_at).await?;
+        tx.commit().await?;
+
+        Ok(token)
+    }
+
+    pub async fn validate(
+        &self,
+        secret: &Secret,
+        used_at: &DateTime,
+    ) -> Result<(Id, Login), ValidateError> {
+        let mut tx = self.db.begin().await?;
+        let login = tx
+            .tokens()
+            .validate(secret, used_at)
+            .await
+            .not_found(|| ValidateError::InvalidToken)?;
+        tx.commit().await?;
+
+        Ok(login)
+    }
+
+    pub async fn limit_stream<E>(
+        &self,
+        token: Id,
+        events: impl Stream<Item = E> + std::fmt::Debug,
+    ) -> Result<impl Stream<Item = E> + std::fmt::Debug, ValidateError>
+    where
+        E: std::fmt::Debug,
+    {
+        // Subscribe, first.
+        let token_events = self.tokens.subscribe();
+
+        // Check that the token is valid at this point in time, second. If it is, then
+        // any future revocations will appear in the subscription. If not, bail now.
+        //
+        // It's possible, otherwise, to get to this point with a token that _was_ valid
+        // at the start of the request, but which was invalidated _before_ the
+        // `subscribe()` call. In that case, the corresponding revocation event will
+        // simply be missed, since the `token_events` stream subscribed after the fact.
+        // This check closes that gap: we bail here instead of guarding the stream.
+        //
+        // Yes, this is a weird niche edge case. Most things don't double-check, because
+        // they aren't expected to run long enough for the token's revocation to
+        // matter. Supervising a stream, on the other hand, will run for a
+        // _long_ time; if we miss the race here, we'll never actually carry out the
+        // supervision.
+        let mut tx = self.db.begin().await?;
+        tx.tokens()
+            .require(&token)
+            .await
+            .not_found(|| ValidateError::InvalidToken)?;
+        tx.commit().await?;
+
+        // Then construct the guarded stream. First, project both streams into
+        // `GuardedEvent`.
+        let token_events = token_events
+            .filter(move |event| future::ready(event.token == token))
+            .map(|_| GuardedEvent::TokenRevoked);
+        let events = events.map(|event| GuardedEvent::Event(event));
+
+        // Merge the two streams, then unproject them, stopping at
+        // `GuardedEvent::TokenRevoked`.
+        let stream = stream::select(token_events, events).scan((), |(), event| {
+            future::ready(match event {
+                GuardedEvent::Event(event) => Some(event),
+                GuardedEvent::TokenRevoked => None,
+            })
+        });
+
+        Ok(stream)
+    }
+
+    pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
+        // Somewhat arbitrarily, expire after 7 days.
+        let expire_at = relative_to.to_owned() - TimeDelta::days(7);
+
+        let mut tx = self.db.begin().await?;
+        let tokens = tx.tokens().expire(&expire_at).await?;
+        tx.commit().await?;
+
+        for event in tokens.into_iter().map(event::TokenRevoked::from) {
+            self.tokens.broadcast(&event);
+        }
+
+        Ok(())
+    }
+
+    pub async fn logout(&self, token: &Id) -> Result<(), ValidateError> {
+        let mut tx = self.db.begin().await?;
+        tx.tokens().revoke(token).await?;
+        tx.commit().await?;
+
+        self.tokens
+            .broadcast(&event::TokenRevoked::from(token.clone()));
+
+        Ok(())
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum LoginError {
+    #[error("invalid login")]
+    Rejected,
+    #[error(transparent)]
+    DatabaseError(#[from] sqlx::Error),
+    #[error(transparent)]
+    PasswordHashError(#[from] password_hash::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ValidateError {
+    #[error("invalid token")]
+    InvalidToken,
+    #[error(transparent)]
+    DatabaseError(#[from] sqlx::Error),
+}
+
+#[derive(Debug)]
+enum GuardedEvent<E> {
+    TokenRevoked,
+    Event(E),
+}
diff --git a/src/token/broadcaster.rs b/src/token/broadcaster.rs
new file mode 100644
index 0000000..8e2e006
--- /dev/null
+++ b/src/token/broadcaster.rs
@@ -0,0 +1,4 @@
+use super::event;
+use crate::broadcast;
+
+pub type Broadcaster = broadcast::Broadcaster<event::TokenRevoked>;
diff --git a/src/token/event.rs b/src/token/event.rs
new file mode 100644
index 0000000..d53d436
--- /dev/null
+++ b/src/token/event.rs
@@ -0,0 +1,12 @@
+use crate::token;
+
+#[derive(Clone, Debug)]
+pub struct TokenRevoked {
+    pub token: token::Id,
+}
+
+impl From<token::Id> for TokenRevoked {
+    fn from(token: token::Id) -> Self {
+        Self { token }
+    }
+}
diff --git a/src/token/extract/identity.rs b/src/token/extract/identity.rs
index 42c7c60..60ad220 100644
--- a/src/token/extract/identity.rs
+++ b/src/token/extract/identity.rs
@@ -10,8 +10,8 @@ use crate::{
     app::App,
     clock::RequestedAt,
     error::{Internal, Unauthorized},
-    login::{app::ValidateError, Login},
-    token,
+    login::Login,
+    token::{self, app::ValidateError},
 };
 
 #[derive(Clone, Debug)]
@@ -40,7 +40,7 @@ impl FromRequestParts<App> for Identity {
         let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
 
         let app = State::<App>::from_request_parts(parts, state).await?;
-        match app.logins().validate(&secret, &used_at).await {
+        match app.tokens().validate(&secret, &used_at).await {
             Ok((token, login)) => Ok(Identity { token, login }),
             Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
             Err(other) => Err(other.into()),
diff --git a/src/token/mod.rs b/src/token/mod.rs
index c98b8c2..d122611 100644
--- a/src/token/mod.rs
+++ b/src/token/mod.rs
@@ -1,5 +1,9 @@
+pub mod app;
+pub mod broadcaster;
+mod event;
 pub mod extract;
 mod id;
+mod repo;
 mod secret;
 
 pub use self::{id::Id, secret::Secret};
diff --git a/src/token/repo/auth.rs b/src/token/repo/auth.rs
new file mode 100644
index 0000000..b299697
--- /dev/null
+++ b/src/token/repo/auth.rs
@@ -0,0 +1,50 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use crate::login::{self, password::StoredHash, Login};
+
+pub trait Provider {
+    fn auth(&mut self) -> Auth;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+    fn auth(&mut self) -> Auth {
+        Auth(self)
+    }
+}
+
+pub struct Auth<'t>(&'t mut SqliteConnection);
+
+impl<'t> Auth<'t> {
+    // Retrieves a login by name, plus its stored password hash for
+    // verification. If there's no login with the requested name, this will
+    // return [None].
+    pub async fn for_name(
+        &mut self,
+        name: &str,
+    ) -> Result<Option<(Login, StoredHash)>, sqlx::Error> {
+        let found = sqlx::query!(
+            r#"
+            select
+                id as "id: login::Id",
+                name,
+                password_hash as "password_hash: StoredHash"
+            from login
+            where name = $1
+            "#,
+            name,
+        )
+        .map(|rec| {
+            (
+                Login {
+                    id: rec.id,
+                    name: rec.name,
+                },
+                rec.password_hash,
+            )
+        })
+        .fetch_optional(&mut *self.0)
+        .await?;
+
+        Ok(found)
+    }
+}
diff --git a/src/token/repo/mod.rs b/src/token/repo/mod.rs
new file mode 100644
index 0000000..9169743
--- /dev/null
+++ b/src/token/repo/mod.rs
@@ -0,0 +1,4 @@
+pub mod auth;
+mod token;
+
+pub use self::token::Provider;
diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs
new file mode 100644
index 0000000..5f64dac
--- /dev/null
+++ b/src/token/repo/token.rs
@@ -0,0 +1,151 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+use uuid::Uuid;
+
+use crate::{
+    clock::DateTime,
+    login::{self, Login},
+    token::{Id, Secret},
+};
+
+pub trait Provider {
+    fn tokens(&mut self) -> Tokens;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+    fn tokens(&mut self) -> Tokens {
+        Tokens(self)
+    }
+}
+
+pub struct Tokens<'t>(&'t mut SqliteConnection);
+
+impl<'c> Tokens<'c> {
+    // Issue a new token for an existing login. The issued_at timestamp will
+    // be used to control expiry, until the token is actually used.
+    pub async fn issue(
+        &mut self,
+        login: &Login,
+        issued_at: &DateTime,
+    ) -> Result<Secret, sqlx::Error> {
+        let id = Id::generate();
+        let secret = Uuid::new_v4().to_string();
+
+        let secret = sqlx::query_scalar!(
+            r#"
+            insert
+            into token (id, secret, login, issued_at, last_used_at)
+            values ($1, $2, $3, $4, $4)
+            returning secret as "secret!: Secret"
+            "#,
+            id,
+            secret,
+            login.id,
+            issued_at,
+        )
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(secret)
+    }
+
+    pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> {
+        sqlx::query_scalar!(
+            r#"
+            select id as "id: Id"
+            from token
+            where id = $1
+            "#,
+            token,
+        )
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(())
+    }
+
+    // Revoke a token by its id.
+    pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> {
+        sqlx::query_scalar!(
+            r#"
+            delete
+            from token
+            where id = $1
+            returning id as "id: Id"
+            "#,
+            token,
+        )
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(())
+    }
+
+    // Expire and delete all tokens that haven't been used more recently than
+    // `expire_at`.
+    pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+        let tokens = sqlx::query_scalar!(
+            r#"
+            delete
+            from token
+            where last_used_at < $1
+            returning id as "id: Id"
+            "#,
+            expire_at,
+        )
+        .fetch_all(&mut *self.0)
+        .await?;
+
+        Ok(tokens)
+    }
+
+    // Validate a token by its secret, retrieving the associated Login record.
+    // Fails with `RowNotFound` if the token is not valid. The token's last-used
+    // timestamp will be set to `used_at`.
+    pub async fn validate(
+        &mut self,
+        secret: &Secret,
+        used_at: &DateTime,
+    ) -> Result<(Id, Login), sqlx::Error> {
+        // I would use `update … returning` to do this in one query, but
+        // sqlite3, as of this writing, does not allow an update's `returning`
+        // clause to reference columns from tables joined into the update. Two
+        // queries is fine, but it feels untidy.
+        sqlx::query!(
+            r#"
+            update token
+            set last_used_at = $1
+            where secret = $2
+            "#,
+            used_at,
+            secret,
+        )
+        .execute(&mut *self.0)
+        .await?;
+
+        let login = sqlx::query!(
+            r#"
+            select
+                token.id as "token_id: Id",
+                login.id as "login_id: login::Id",
+                name as "login_name"
+            from login
+            join token on login.id = token.login
+            where token.secret = $1
+            "#,
+            secret,
+        )
+        .map(|row| {
+            (
+                row.token_id,
+                Login {
+                    id: row.login_id,
+                    name: row.login_name,
+                },
+            )
+        })
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(login)
+    }
+}
--
cgit v1.2.3


From 6f07e6869bbf62903ac83c9bc061e7bde997e6a8 Mon Sep 17 00:00:00 2001
From: Owen Jacobson
Date: Wed, 2 Oct 2024 01:10:09 -0400
Subject: Retire top-level `repo`.

This helped me discover an organizational scheme I like more.
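
Each domain now owns its persistence code: `channel::repo`, `event::repo`,
`login::repo`, and `token::repo` sit next to the modules they serve, and
the genuinely shared pieces - pool setup and the `NotFound` adapter - land
in a new top-level `db` module.

A sketch of the resulting import pattern, with paths as introduced in
this patch:

    use crate::{
        channel::repo::Provider as _, // each domain exposes its own repo
        db::NotFound as _,            // shared database helpers live under `db`
        event::repo::Provider as _,
    };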
--- src/channel/app.rs | 6 +- src/channel/mod.rs | 1 + src/channel/repo.rs | 165 +++++++++++++++++++++++++++++++++++++++++++++ src/cli.rs | 4 +- src/db.rs | 42 ++++++++++++ src/event/app.rs | 6 +- src/event/repo/mod.rs | 3 + src/event/repo/sequence.rs | 44 ++++++++++++ src/lib.rs | 2 +- src/login/app.rs | 2 +- src/repo/channel.rs | 165 --------------------------------------------- src/repo/error.rs | 23 ------- src/repo/mod.rs | 4 -- src/repo/pool.rs | 18 ----- src/repo/sequence.rs | 44 ------------ src/test/fixtures/mod.rs | 4 +- src/token/app.rs | 2 +- 17 files changed, 267 insertions(+), 268 deletions(-) create mode 100644 src/channel/repo.rs create mode 100644 src/db.rs create mode 100644 src/event/repo/sequence.rs delete mode 100644 src/repo/channel.rs delete mode 100644 src/repo/error.rs delete mode 100644 src/repo/mod.rs delete mode 100644 src/repo/pool.rs delete mode 100644 src/repo/sequence.rs diff --git a/src/channel/app.rs b/src/channel/app.rs index 1422651..ef0a63f 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,11 +2,9 @@ use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; use crate::{ - channel::Channel, + channel::{repo::Provider as _, Channel}, clock::DateTime, - event::Sequence, - event::{broadcaster::Broadcaster, types::ChannelEvent}, - repo::{channel::Provider as _, sequence::Provider as _}, + event::{broadcaster::Broadcaster, repo::Provider as _, types::ChannelEvent, Sequence}, }; pub struct Channels<'a> { diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 02d0ed4..2672084 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -2,6 +2,7 @@ use crate::{clock::DateTime, event::Sequence}; pub mod app; mod id; +pub mod repo; mod routes; pub use self::{id::Id, routes::router}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs new file mode 100644 index 0000000..18cd81f --- /dev/null +++ b/src/channel/repo.rs @@ -0,0 +1,165 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{Channel, Id}, + clock::DateTime, + event::{types, Sequence}, +}; + +pub trait Provider { + fn channels(&mut self) -> Channels; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn channels(&mut self) -> Channels { + Channels(self) + } +} + +pub struct Channels<'t>(&'t mut SqliteConnection); + +impl<'c> Channels<'c> { + pub async fn create( + &mut self, + name: &str, + created_at: &DateTime, + created_sequence: Sequence, + ) -> Result { + let id = Id::generate(); + let channel = sqlx::query_as!( + Channel, + r#" + insert + into channel (id, name, created_at, created_sequence) + values ($1, $2, $3, $4) + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + "#, + id, + name, + created_at, + created_sequence, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn by_id(&mut self, channel: &Id) -> Result { + let channel = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where id = $1 + "#, + channel, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn all( + &mut self, + resume_point: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence <= $1, true) + order by channel.name + 
"#, + resume_point, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence > $1, true) + "#, + resume_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn delete( + &mut self, + channel: &Channel, + deleted_at: &DateTime, + deleted_sequence: Sequence, + ) -> Result { + let channel = channel.id.clone(); + sqlx::query_scalar!( + r#" + delete from channel + where id = $1 + returning 1 as "row: i64" + "#, + channel, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence: deleted_sequence, + at: *deleted_at, + data: types::DeletedEvent { channel }.into(), + }) + } + + pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + channel.id as "id: Id", + channel.name, + channel.created_at as "created_at: DateTime", + channel.created_sequence as "created_sequence: Sequence" + from channel + left join message + where created_at < $1 + and message.id is null + "#, + expired_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } +} diff --git a/src/cli.rs b/src/cli.rs index ee95ea6..893fae2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, event, expire, login, repo::pool}; +use crate::{app::App, channel, clock, db, event, expire, login}; /// Command-line entry point for running the `hi` server. /// @@ -100,7 +100,7 @@ impl Args { } async fn pool(&self) -> sqlx::Result { - pool::prepare(&self.database_url).await + db::prepare(&self.database_url).await } } diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..93a1169 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,42 @@ +use std::str::FromStr; + +use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; + +pub async fn prepare(url: &str) -> sqlx::Result { + let pool = create(url).await?; + sqlx::migrate!().run(&pool).await?; + Ok(pool) +} + +async fn create(database_url: &str) -> sqlx::Result { + let options = SqliteConnectOptions::from_str(database_url)? 
+ .create_if_missing(true) + .optimize_on_close(true, /* analysis_limit */ None); + + let pool = SqlitePoolOptions::new().connect_with(options).await?; + Ok(pool) +} + +pub trait NotFound { + type Ok; + fn not_found(self, map: F) -> Result + where + E: From, + F: FnOnce() -> E; +} + +impl NotFound for Result { + type Ok = T; + + fn not_found(self, map: F) -> Result + where + E: From, + F: FnOnce() -> E, + { + match self { + Err(sqlx::Error::RowNotFound) => Err(map()), + Err(other) => Err(other.into()), + Ok(value) => Ok(value), + } + } +} diff --git a/src/event/app.rs b/src/event/app.rs index b5f2ecc..3d35f1a 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -12,11 +12,11 @@ use super::{ types::{self, ChannelEvent}, }; use crate::{ - channel, + channel::{self, repo::Provider as _}, clock::DateTime, - event::Sequence, + db::NotFound as _, + event::{repo::Provider as _, Sequence}, login::Login, - repo::{channel::Provider as _, error::NotFound as _, sequence::Provider as _}, }; pub struct Events<'a> { diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs index e216a50..cee840c 100644 --- a/src/event/repo/mod.rs +++ b/src/event/repo/mod.rs @@ -1 +1,4 @@ pub mod message; +mod sequence; + +pub use self::sequence::Provider; diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs new file mode 100644 index 0000000..c985869 --- /dev/null +++ b/src/event/repo/sequence.rs @@ -0,0 +1,44 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::event::Sequence; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } + + pub async fn current(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/lib.rs b/src/lib.rs index bbcb314..8ec13da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,13 +7,13 @@ mod broadcast; mod channel; pub mod cli; mod clock; +mod db; mod error; mod event; mod expire; mod id; mod login; mod message; -mod repo; #[cfg(test)] mod test; mod token; diff --git a/src/login/app.rs b/src/login/app.rs index 69c1055..15adb31 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,6 +1,6 @@ use sqlx::sqlite::SqlitePool; -use crate::{event::Sequence, repo::sequence::Provider as _}; +use crate::event::{repo::Provider as _, Sequence}; #[cfg(test)] use super::{repo::Provider as _, Login, Password}; diff --git a/src/repo/channel.rs b/src/repo/channel.rs deleted file mode 100644 index 18cd81f..0000000 --- a/src/repo/channel.rs +++ /dev/null @@ -1,165 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel::{Channel, Id}, - clock::DateTime, - event::{types, Sequence}, -}; - -pub trait Provider { - fn channels(&mut self) -> Channels; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn channels(&mut self) -> Channels { - Channels(self) - } -} - -pub struct Channels<'t>(&'t mut SqliteConnection); - -impl<'c> Channels<'c> { - pub async fn create( - &mut self, - name: &str, - created_at: &DateTime, - 
created_sequence: Sequence, - ) -> Result { - let id = Id::generate(); - let channel = sqlx::query_as!( - Channel, - r#" - insert - into channel (id, name, created_at, created_sequence) - values ($1, $2, $3, $4) - returning - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - "#, - id, - name, - created_at, - created_sequence, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn by_id(&mut self, channel: &Id) -> Result { - let channel = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where id = $1 - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn all( - &mut self, - resume_point: Option, - ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where coalesce(created_sequence <= $1, true) - order by channel.name - "#, - resume_point, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } - - pub async fn replay( - &mut self, - resume_at: Option, - ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where coalesce(created_sequence > $1, true) - "#, - resume_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } - - pub async fn delete( - &mut self, - channel: &Channel, - deleted_at: &DateTime, - deleted_sequence: Sequence, - ) -> Result { - let channel = channel.id.clone(); - sqlx::query_scalar!( - r#" - delete from channel - where id = $1 - returning 1 as "row: i64" - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, - data: types::DeletedEvent { channel }.into(), - }) - } - - pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - channel.id as "id: Id", - channel.name, - channel.created_at as "created_at: DateTime", - channel.created_sequence as "created_sequence: Sequence" - from channel - left join message - where created_at < $1 - and message.id is null - "#, - expired_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } -} diff --git a/src/repo/error.rs b/src/repo/error.rs deleted file mode 100644 index a5961e2..0000000 --- a/src/repo/error.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub trait NotFound { - type Ok; - fn not_found(self, map: F) -> Result - where - E: From, - F: FnOnce() -> E; -} - -impl NotFound for Result { - type Ok = T; - - fn not_found(self, map: F) -> Result - where - E: From, - F: FnOnce() -> E, - { - match self { - Err(sqlx::Error::RowNotFound) => Err(map()), - Err(other) => Err(other.into()), - Ok(value) => Ok(value), - } - } -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs deleted file mode 100644 index 7abd46b..0000000 --- a/src/repo/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod channel; -pub mod error; -pub mod pool; -pub mod sequence; diff --git a/src/repo/pool.rs b/src/repo/pool.rs deleted file mode 100644 index b4aa6fc..0000000 --- a/src/repo/pool.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::str::FromStr; - -use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, 
SqlitePoolOptions}; - -pub async fn prepare(url: &str) -> sqlx::Result { - let pool = create(url).await?; - sqlx::migrate!().run(&pool).await?; - Ok(pool) -} - -async fn create(database_url: &str) -> sqlx::Result { - let options = SqliteConnectOptions::from_str(database_url)? - .create_if_missing(true) - .optimize_on_close(true, /* analysis_limit */ None); - - let pool = SqlitePoolOptions::new().connect_with(options).await?; - Ok(pool) -} diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs deleted file mode 100644 index c985869..0000000 --- a/src/repo/sequence.rs +++ /dev/null @@ -1,44 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::event::Sequence; - -pub trait Provider { - fn sequence(&mut self) -> Sequences; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn sequence(&mut self) -> Sequences { - Sequences(self) - } -} - -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } - - pub async fn current(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - select last_value as "last_value: Sequence" - from event_sequence - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } -} diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index d1dd0c3..76467ab 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -1,6 +1,6 @@ use chrono::{TimeDelta, Utc}; -use crate::{app::App, clock::RequestedAt, repo::pool}; +use crate::{app::App, clock::RequestedAt, db}; pub mod channel; pub mod filter; @@ -10,7 +10,7 @@ pub mod login; pub mod message; pub async fn scratch_app() -> App { - let pool = pool::prepare("sqlite::memory:") + let pool = db::prepare("sqlite::memory:") .await .expect("setting up in-memory sqlite database"); App::from(pool) diff --git a/src/token/app.rs b/src/token/app.rs index 1477a9f..030ec69 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -11,8 +11,8 @@ use super::{ }; use crate::{ clock::DateTime, + db::NotFound as _, login::{repo::Provider as _, Login, Password}, - repo::error::NotFound as _, }; pub struct Tokens<'a> { -- cgit v1.2.3 From 469613872f6fb19f4579b387e19b2bc38fa52f51 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 01:31:43 -0400 Subject: Package up common event fields as Instant --- src/channel/app.rs | 11 +++--- src/channel/mod.rs | 6 ++-- src/channel/repo.rs | 74 ++++++++++++++++++++++++++------------ src/channel/routes/test/on_send.rs | 2 +- src/event/app.rs | 18 +++++----- src/event/mod.rs | 9 +++++ src/event/repo/message.rs | 42 +++++++++++++--------- src/event/repo/sequence.rs | 12 +++++-- src/event/routes.rs | 2 +- src/event/routes/test.rs | 8 ++--- src/event/types.rs | 22 +++--------- 11 files changed, 121 insertions(+), 85 deletions(-) diff --git a/src/channel/app.rs b/src/channel/app.rs index ef0a63f..b7e3a10 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -19,10 +19,10 @@ impl<'a> Channels<'a> { pub async fn create(&self, name: &str, created_at: &DateTime) -> Result { let mut tx = self.db.begin().await?; - let created_sequence = tx.sequence().next().await?; + let created = tx.sequence().next(created_at).await?; let channel = tx .channels() - .create(name, created_at, created_sequence) + .create(name, &created) .await .map_err(|err| 
CreateError::from_duplicate_name(err, name))?; tx.commit().await?; @@ -50,11 +50,8 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { - let deleted_sequence = tx.sequence().next().await?; - let event = tx - .channels() - .delete(&channel, relative_to, deleted_sequence) - .await?; + let deleted = tx.sequence().next(relative_to).await?; + let event = tx.channels().delete(&channel, &deleted).await?; events.push(event); } diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 2672084..4baa7e3 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,4 @@ -use crate::{clock::DateTime, event::Sequence}; +use crate::event::Instant; pub mod app; mod id; @@ -12,7 +12,5 @@ pub struct Channel { pub id: Id, pub name: String, #[serde(skip)] - pub created_at: DateTime, - #[serde(skip)] - pub created_sequence: Sequence, + pub created: Instant, } diff --git a/src/channel/repo.rs b/src/channel/repo.rs index 18cd81f..c000b56 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -3,7 +3,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ channel::{Channel, Id}, clock::DateTime, - event::{types, Sequence}, + event::{types, Instant, Sequence}, }; pub trait Provider { @@ -19,15 +19,9 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); impl<'c> Channels<'c> { - pub async fn create( - &mut self, - name: &str, - created_at: &DateTime, - created_sequence: Sequence, - ) -> Result { + pub async fn create(&mut self, name: &str, created: &Instant) -> Result { let id = Id::generate(); - let channel = sqlx::query_as!( - Channel, + let channel = sqlx::query!( r#" insert into channel (id, name, created_at, created_sequence) @@ -40,9 +34,17 @@ impl<'c> Channels<'c> { "#, id, name, - created_at, - created_sequence, + created.at, + created.sequence, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_one(&mut *self.0) .await?; @@ -50,8 +52,7 @@ impl<'c> Channels<'c> { } pub async fn by_id(&mut self, channel: &Id) -> Result { - let channel = sqlx::query_as!( - Channel, + let channel = sqlx::query!( r#" select id as "id: Id", @@ -63,6 +64,14 @@ impl<'c> Channels<'c> { "#, channel, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_one(&mut *self.0) .await?; @@ -73,8 +82,7 @@ impl<'c> Channels<'c> { &mut self, resume_point: Option, ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select id as "id: Id", @@ -87,6 +95,14 @@ impl<'c> Channels<'c> { "#, resume_point, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; @@ -97,8 +113,7 @@ impl<'c> Channels<'c> { &mut self, resume_at: Option, ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select id as "id: Id", @@ -110,6 +125,14 @@ impl<'c> Channels<'c> { "#, resume_at, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; @@ -119,8 +142,7 @@ impl<'c> Channels<'c> { pub async fn delete( &mut self, channel: &Channel, - deleted_at: &DateTime, - deleted_sequence: Sequence, + deleted: 
&Instant, ) -> Result { let channel = channel.id.clone(); sqlx::query_scalar!( @@ -135,15 +157,13 @@ impl<'c> Channels<'c> { .await?; Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, + instant: *deleted, data: types::DeletedEvent { channel }.into(), }) } pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select channel.id as "id: Id", @@ -157,6 +177,14 @@ impl<'c> Channels<'c> { "#, expired_at, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 6f844cd..33ec3b7 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -52,7 +52,7 @@ async fn messages_in_order() { let events = events.collect::>().immediately().await; for ((sent_at, message), event) in requests.into_iter().zip(events) { - assert_eq!(*sent_at, event.at); + assert_eq!(*sent_at, event.instant.at); assert!(matches!( event.data, types::ChannelEventData::Message(event_message) diff --git a/src/event/app.rs b/src/event/app.rs index 3d35f1a..5e9e79a 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -42,10 +42,10 @@ impl<'a> Events<'a> { .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent_sequence = tx.sequence().next().await?; + let sent = tx.sequence().next(sent_at).await?; let event = tx .message_events() - .create(login, &channel, sent_at, sent_sequence, body) + .create(login, &channel, &sent, body) .await?; tx.commit().await?; @@ -62,10 +62,10 @@ impl<'a> Events<'a> { let mut events = Vec::with_capacity(expired.len()); for (channel, message) in expired { - let deleted_sequence = tx.sequence().next().await?; + let deleted = tx.sequence().next(relative_to).await?; let event = tx .message_events() - .delete(&channel, &message, relative_to, deleted_sequence) + .delete(&channel, &message, &deleted) .await?; events.push(event); } @@ -93,7 +93,9 @@ impl<'a> Events<'a> { let channel_events = channels .into_iter() .map(ChannelEvent::created) - .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); + .filter(move |event| { + resume_at.map_or(true, |resume_at| Sequence::from(event) > resume_at) + }); let message_events = tx.message_events().replay(resume_at).await?; @@ -101,8 +103,8 @@ impl<'a> Events<'a> { .into_iter() .chain(message_events.into_iter()) .collect::>(); - replay_events.sort_by_key(|event| event.sequence); - let resume_live_at = replay_events.last().map(|event| event.sequence); + replay_events.sort_by_key(|event| Sequence::from(event)); + let resume_live_at = replay_events.last().map(Sequence::from); let replay = stream::iter(replay_events); @@ -124,7 +126,7 @@ impl<'a> Events<'a> { fn resume( resume_at: Option, ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { - move |event| future::ready(resume_at < Some(event.sequence)) + move |event| future::ready(resume_at < Some(Sequence::from(event))) } } diff --git a/src/event/mod.rs b/src/event/mod.rs index 7ad3f9c..c982d3a 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -6,4 +6,13 @@ mod routes; mod sequence; pub mod types; +use crate::clock::DateTime; + pub use self::{routes::router, sequence::Sequence}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Instant { + 
pub at: DateTime, + #[serde(skip)] + pub sequence: Sequence, +} diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs index f051fec..f29c8a4 100644 --- a/src/event/repo/message.rs +++ b/src/event/repo/message.rs @@ -3,7 +3,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ channel::{self, Channel}, clock::DateTime, - event::{types, Sequence}, + event::{types, Instant, Sequence}, login::{self, Login}, message::{self, Message}, }; @@ -25,8 +25,7 @@ impl<'c> Events<'c> { &mut self, sender: &Login, channel: &Channel, - sent_at: &DateTime, - sent_sequence: Sequence, + sent: &Instant, body: &str, ) -> Result { let id = message::Id::generate(); @@ -46,13 +45,15 @@ impl<'c> Events<'c> { id, channel.id, sender.id, - sent_at, - sent_sequence, + sent.at, + sent.sequence, body, ) .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, + instant: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, data: types::MessageEvent { channel: channel.clone(), sender: sender.clone(), @@ -73,8 +74,7 @@ impl<'c> Events<'c> { &mut self, channel: &Channel, message: &message::Id, - deleted_at: &DateTime, - deleted_sequence: Sequence, + deleted: &Instant, ) -> Result { sqlx::query_scalar!( r#" @@ -88,8 +88,10 @@ impl<'c> Events<'c> { .await?; Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, + instant: Instant { + at: deleted.at, + sequence: deleted.sequence, + }, data: types::MessageDeletedEvent { channel: channel.clone(), message: message.clone(), @@ -122,8 +124,10 @@ impl<'c> Events<'c> { Channel { id: row.channel_id, name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, + created: Instant { + at: row.channel_created_at, + sequence: row.channel_created_sequence, + }, }, row.message, ) @@ -160,14 +164,18 @@ impl<'c> Events<'c> { resume_at, ) .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, + instant: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, data: types::MessageEvent { channel: Channel { id: row.channel_id, name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, + created: Instant { + at: row.channel_created_at, + sequence: row.channel_created_sequence, + }, }, sender: Login { id: row.sender_id, diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs index c985869..40d6a53 100644 --- a/src/event/repo/sequence.rs +++ b/src/event/repo/sequence.rs @@ -1,6 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::event::Sequence; +use crate::{ + clock::DateTime, + event::{Instant, Sequence}, +}; pub trait Provider { fn sequence(&mut self) -> Sequences; @@ -15,7 +18,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Sequences<'t>(&'t mut SqliteConnection); impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result { + pub async fn next(&mut self, at: &DateTime) -> Result { let next = sqlx::query_scalar!( r#" update event_sequence @@ -26,7 +29,10 @@ impl<'c> Sequences<'c> { .fetch_one(&mut *self.0) .await?; - Ok(next) + Ok(Instant { + at: *at, + sequence: next, + }) } pub async fn current(&mut self) -> Result { diff --git a/src/event/routes.rs b/src/event/routes.rs index 50ac435..c87bfb2 100644 --- a/src/event/routes.rs +++ b/src/event/routes.rs @@ -66,7 +66,7 @@ impl TryFrom for sse::Event { type Error = serde_json::Error; fn try_from(event: types::ChannelEvent) -> Result { - let id = 
serde_json::to_string(&event.sequence)?;
+        let id = serde_json::to_string(&Sequence::from(&event))?;
         let data = serde_json::to_string_pretty(&event)?;
         let event = Self::default().id(id).data(data);
diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs
index d1ac3b4..68b55cc 100644
--- a/src/event/routes/test.rs
+++ b/src/event/routes/test.rs
@@ -6,7 +6,7 @@ use futures::{
 };
 
 use crate::{
-    event::routes,
+    event::{routes, Sequence},
     test::fixtures::{self, future::Immediately as _},
 };
 
@@ -192,7 +192,7 @@ async fn resumes_from() {
 
         assert_eq!(initial_message, event);
 
-        event.sequence
+        Sequence::from(&event)
     };
 
     // Resume after disconnect
@@ -276,7 +276,7 @@ async fn serial_resume() {
 
         let event = events.last().expect("this vec is non-empty");
 
-        event.sequence
+        Sequence::from(event)
     };
 
     // Resume after disconnect
@@ -312,7 +312,7 @@ async fn serial_resume() {
 
         let event = events.last().expect("this vec is non-empty");
 
-        event.sequence
+        Sequence::from(event)
     };
 
     // Resume after disconnect a second time
diff --git a/src/event/types.rs b/src/event/types.rs
index cd7dea6..2324dc1 100644
--- a/src/event/types.rs
+++ b/src/event/types.rs
@@ -1,16 +1,14 @@
 use crate::{
     channel::{self, Channel},
-    clock::DateTime,
-    event::Sequence,
+    event::{Instant, Sequence},
     login::Login,
     message::{self, Message},
 };
 
 #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
 pub struct ChannelEvent {
-    #[serde(skip)]
-    pub sequence: Sequence,
-    pub at: DateTime,
+    #[serde(flatten)]
+    pub instant: Instant,
     #[serde(flatten)]
     pub data: ChannelEventData,
 }
@@ -18,25 +16,15 @@ pub struct ChannelEvent {
 impl ChannelEvent {
     pub fn created(channel: Channel) -> Self {
         Self {
-            at: channel.created_at,
-            sequence: channel.created_sequence,
+            instant: channel.created,
             data: CreatedEvent { channel }.into(),
         }
     }
-
-    pub fn channel_id(&self) -> &channel::Id {
-        match &self.data {
-            ChannelEventData::Created(event) => &event.channel.id,
-            ChannelEventData::Message(event) => &event.channel.id,
-            ChannelEventData::MessageDeleted(event) => &event.channel.id,
-            ChannelEventData::Deleted(event) => &event.channel,
-        }
-    }
 }
 
 impl<'c> From<&'c ChannelEvent> for Sequence {
     fn from(event: &'c ChannelEvent) -> Self {
-        event.sequence
+        event.instant.sequence
     }
 }
-- 
cgit v1.2.3


From ec804134c33aedb001c426c5f42f43f53c47848f Mon Sep 17 00:00:00 2001
From: Owen Jacobson
Date: Wed, 2 Oct 2024 12:25:36 -0400
Subject: Represent channels and messages using a split "History" and
 "Snapshot" model.

This separates the code that figures out what happened to an entity
from the code that represents it to a user, and makes it easier to
compute a snapshot at a point in time (for things like bootstrap).

It also makes the internal logic a bit easier to follow, since it's
easier to tell whether you're working with a point in time or with the
whole recorded history.

This one is hefty.
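In sketch form, the split looks roughly like this. The types here are
simplified stand-ins (integer ids, a bare sequence number, no
timestamps or serde), not the exact definitions in the diffs that
follow; the point is the shape: a History replays everything recorded
about an entity as ordered events, and folding those events up to a
resume point yields the snapshot a user would have seen at that point,
or None once the entity is gone.

    // Simplified stand-ins for the real types in src/channel/*.
    #[derive(Clone, Copy, PartialEq, PartialOrd)]
    struct Sequence(u64);

    #[derive(Clone, Debug, PartialEq)]
    struct Channel {
        id: u32,
        name: String,
    }

    enum Event {
        Created { at: Sequence, channel: Channel },
        Deleted { at: Sequence, channel: u32 },
    }

    impl Event {
        fn at(&self) -> Sequence {
            match self {
                Event::Created { at, .. } | Event::Deleted { at, .. } => *at,
            }
        }
    }

    // History: everything recorded about one channel, in one place.
    struct History {
        channel: Channel,
        created: Sequence,
        deleted: Option<Sequence>,
    }

    impl History {
        // Replay the recorded facts as an ordered event stream.
        fn events(&self) -> impl Iterator<Item = Event> {
            let created = Event::Created {
                at: self.created,
                channel: self.channel.clone(),
            };
            let deleted = self
                .deleted
                .map(|at| Event::Deleted { at, channel: self.channel.id });
            std::iter::once(created).chain(deleted)
        }
    }

    // Snapshot: fold the events visible at `resume_point` into what a
    // user would have seen then. None means the channel either did not
    // exist yet or had already been deleted.
    fn snapshot_at(history: &History, resume_point: Sequence) -> Option<Channel> {
        history
            .events()
            .filter(|event| event.at() <= resume_point)
            .fold(None, |state, event| match (state, event) {
                (None, Event::Created { channel, .. }) => Some(channel),
                (Some(live), Event::Deleted { channel, .. }) if live.id == channel => None,
                (state, _) => state,
            })
    }

    fn main() {
        let history = History {
            channel: Channel { id: 1, name: "general".into() },
            created: Sequence(3),
            deleted: Some(Sequence(7)),
        };
        assert_eq!(snapshot_at(&history, Sequence(2)), None); // not created yet
        assert_eq!(
            snapshot_at(&history, Sequence(5)).map(|c| c.name).as_deref(),
            Some("general"), // live at this point
        );
        assert_eq!(snapshot_at(&history, Sequence(9)), None); // deleted
    }

Filtering a History's events through a resume point before collecting
them mirrors how the channel listing for a resuming client is rebuilt
in src/channel/app.rs below.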
--- ...86018161718e2a6788413bffffb252de3e1959f341.json | 38 ++++ ...4abfbb9e06a5889460743202ca7956acabf006843e.json | 20 ++ ...2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json | 26 +++ ...05a144db4901c302bbcd3e76da6c61742ac44345c9.json | 44 ----- ...0b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json | 32 +++ ...6854398e2151ba2dba10c03a9d2d93184141f1425c.json | 44 ----- ...4eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json | 62 ++++++ ...b8df33083c2da765dfda3023c78c25c06735670457.json | 38 ---- ...72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json | 74 ------- ...df5de8c7d9ad11bb09a0d0d243181d97c79d771071.json | 20 ++ ...6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json | 20 -- ...99e837106c799e84015425286b79f42e4001d8a4c7.json | 62 ++++++ ...76cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json | 20 -- Cargo.lock | 10 + Cargo.toml | 1 + src/app.rs | 5 + src/broadcast.rs | 4 +- src/channel/app.rs | 35 +++- src/channel/event.rs | 48 +++++ src/channel/history.rs | 42 ++++ src/channel/mod.rs | 15 +- src/channel/repo.rs | 94 +++++---- src/channel/routes.rs | 11 +- src/channel/routes/test/on_create.rs | 6 +- src/channel/routes/test/on_send.rs | 13 +- src/channel/snapshot.rs | 38 ++++ src/event/app.rs | 122 +++--------- src/event/broadcaster.rs | 4 +- src/event/mod.rs | 73 ++++++- src/event/repo.rs | 50 +++++ src/event/repo/message.rs | 196 ------------------- src/event/repo/mod.rs | 4 - src/event/repo/sequence.rs | 50 ----- src/event/routes.rs | 14 +- src/event/routes/test.rs | 98 ++++++---- src/event/sequence.rs | 59 ++++++ src/event/types.rs | 85 -------- src/expire.rs | 2 +- src/message/app.rs | 88 +++++++++ src/message/event.rs | 50 +++++ src/message/history.rs | 43 +++++ src/message/mod.rs | 13 +- src/message/repo.rs | 214 +++++++++++++++++++++ src/message/snapshot.rs | 74 +++++++ src/test/fixtures/event.rs | 11 ++ src/test/fixtures/filter.rs | 10 +- src/test/fixtures/message.rs | 13 +- src/test/fixtures/mod.rs | 1 + src/token/app.rs | 4 +- 49 files changed, 1277 insertions(+), 823 deletions(-) create mode 100644 .sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json create mode 100644 .sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json create mode 100644 .sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json delete mode 100644 .sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json create mode 100644 .sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json delete mode 100644 .sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json create mode 100644 .sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json delete mode 100644 .sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json delete mode 100644 .sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json create mode 100644 .sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json delete mode 100644 .sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json create mode 100644 .sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json delete mode 100644 .sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json create mode 100644 src/channel/event.rs create mode 100644 src/channel/history.rs create mode 100644 src/channel/snapshot.rs create mode 100644 src/event/repo.rs delete mode 100644 src/event/repo/message.rs delete mode 100644 
src/event/repo/mod.rs delete mode 100644 src/event/repo/sequence.rs delete mode 100644 src/event/types.rs create mode 100644 src/message/app.rs create mode 100644 src/message/event.rs create mode 100644 src/message/history.rs create mode 100644 src/message/repo.rs create mode 100644 src/message/snapshot.rs create mode 100644 src/test/fixtures/event.rs diff --git a/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json b/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json new file mode 100644 index 0000000..cc716ed --- /dev/null +++ b/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n delete from channel\n where id = $1\n returning\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341" +} diff --git a/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json b/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json new file mode 100644 index 0000000..1480953 --- /dev/null +++ b/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n delete from message\n where\n id = $1\n returning 1 as \"deleted: i64\"\n ", + "describe": { + "columns": [ + { + "name": "deleted: i64", + "ordinal": 0, + "type_info": "Null" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + null + ] + }, + "hash": "33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e" +} diff --git a/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json b/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json new file mode 100644 index 0000000..2974cb0 --- /dev/null +++ b/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json @@ -0,0 +1,26 @@ +{ + "db_name": "SQLite", + "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sender, sent_at, sent_sequence, body)\n\t\t\t\tvalues ($1, $2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: Id\",\n\t\t\t\t\tbody\n\t\t\t", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 1, + "type_info": "Text" + } + ], + "parameters": { + "Right": 6 + }, + "nullable": [ + false, + false + ] + }, + "hash": "45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f" +} diff --git a/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json b/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json deleted file mode 100644 index 494e1db..0000000 --- a/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sender, sent_at, sent_sequence, body)\n\t\t\t\tvalues ($1, 
$2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: message::Id\",\n\t\t\t\t\tsender as \"sender: login::Id\",\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\",\n\t\t\t\t\tbody\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "sender: login::Id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "sent_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - }, - { - "name": "body", - "ordinal": 4, - "type_info": "Text" - } - ], - "parameters": { - "Right": 6 - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9" -} diff --git a/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json b/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json new file mode 100644 index 0000000..fb5f94b --- /dev/null +++ b/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json @@ -0,0 +1,32 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n message.id as \"message: Id\"\n from message\n join channel on message.channel = channel.id\n where sent_at < $1\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "message: Id", + "ordinal": 2, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc" +} diff --git a/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json b/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json deleted file mode 100644 index 820b43f..0000000 --- a/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n message.id as \"message: message::Id\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where sent_at < $1\n ", - "describe": { - "columns": [ - { - "name": "channel_id: channel::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "channel_created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "channel_created_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - }, - { - "name": "message: message::Id", - "ordinal": 4, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c" -} diff --git a/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json b/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json new file mode 100644 index 0000000..4ca6786 --- /dev/null +++ 
b/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where coalesce(message.sent_sequence > $1, true)\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3" +} diff --git a/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json b/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json deleted file mode 100644 index b34443f..0000000 --- a/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"id: Id\",\n channel.name,\n channel.created_at as \"created_at: DateTime\",\n channel.created_sequence as \"created_sequence: Sequence\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457" -} diff --git a/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json b/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json deleted file mode 100644 index f546438..0000000 --- a/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tselect\n\t\t\t\t\tmessage.id as \"id: message::Id\",\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n\t\t\t\t\tsender.id as \"sender_id: login::Id\",\n\t\t\t\t\tsender.name as sender_name,\n message.sent_at as \"sent_at: DateTime\",\n message.sent_sequence as \"sent_sequence: Sequence\",\n message.body\n\t\t\t\tfrom message\n join channel on message.channel = channel.id\n\t\t\t\t\tjoin 
login as sender on message.sender = sender.id\n\t\t\t\twhere coalesce(message.sent_sequence > $1, true)\n\t\t\t\torder by sent_sequence asc\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_id: channel::Id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "channel_created_at: DateTime", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "channel_created_sequence: Sequence", - "ordinal": 4, - "type_info": "Integer" - }, - { - "name": "sender_id: login::Id", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "sender_name", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "sent_sequence: Sequence", - "ordinal": 8, - "type_info": "Integer" - }, - { - "name": "body", - "ordinal": 9, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false - ] - }, - "hash": "7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c" -} diff --git a/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json b/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json new file mode 100644 index 0000000..b82727f --- /dev/null +++ b/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"id: Id\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071" +} diff --git a/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json b/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json deleted file mode 100644 index 1d448d4..0000000 --- a/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n delete from channel\n where id = $1\n returning 1 as \"row: i64\"\n ", - "describe": { - "columns": [ - { - "name": "row: i64", - "ordinal": 0, - "type_info": "Null" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - null - ] - }, - "hash": "d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb" -} diff --git a/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json b/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json new file mode 100644 index 0000000..288a657 --- /dev/null +++ b/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where message.id = $1\n and message.channel = $2\n ", + "describe": { + 
"columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 2 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7" +} diff --git a/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json b/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json deleted file mode 100644 index 7b1d2d8..0000000 --- a/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n delete from message\n where id = $1\n returning 1 as \"row: i64\"\n ", - "describe": { - "columns": [ - { - "name": "row: i64", - "ordinal": 0, - "type_info": "Null" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - null - ] - }, - "hash": "f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9" -} diff --git a/Cargo.lock b/Cargo.lock index aa2120a..b1f0582 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -787,6 +787,7 @@ dependencies = [ "faker_rand", "futures", "headers", + "itertools", "password-hash", "rand", "rand_core", @@ -956,6 +957,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" diff --git a/Cargo.toml b/Cargo.toml index cb46b41..2b2e774 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ chrono = { version = "0.4.38", features = ["serde"] } clap = { version = "4.5.18", features = ["derive", "env"] } futures = "0.3.30" headers = "0.4.0" +itertools = "0.13.0" password-hash = { version = "0.5.0", features = ["std"] } rand = "0.8.5" rand_core = { version = "0.6.4", features = ["getrandom"] } diff --git a/src/app.rs b/src/app.rs index 5542e5f..186e5f8 100644 --- a/src/app.rs +++ b/src/app.rs @@ -4,6 +4,7 @@ use crate::{ channel::app::Channels, event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, login::app::Logins, + message::app::Messages, token::{app::Tokens, broadcaster::Broadcaster as TokenBroadcaster}, }; @@ -35,6 +36,10 @@ impl App { Logins::new(&self.db) } + pub const fn messages(&self) -> Messages { + Messages::new(&self.db, &self.events) + } + pub const fn tokens(&self) -> Tokens { Tokens::new(&self.db, &self.tokens) } diff --git a/src/broadcast.rs b/src/broadcast.rs index 083a301..bedc263 100644 --- a/src/broadcast.rs +++ b/src/broadcast.rs @@ -32,7 +32,7 @@ where { // panic: if ``message.channel.id`` has not been previously registered, // and was not part of the initial set of channels. 
- pub fn broadcast(&self, message: &M) { + pub fn broadcast(&self, message: impl Into) { let tx = self.sender(); // Per the Tokio docs, the returned error is only used to indicate that @@ -42,7 +42,7 @@ where // // The successful return value, which includes the number of active // receivers, also isn't that interesting to us. - let _ = tx.send(message.clone()); + let _ = tx.send(message.into()); } // panic: if ``channel`` has not been previously registered, and was not diff --git a/src/channel/app.rs b/src/channel/app.rs index b7e3a10..6ce826b 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,10 +1,11 @@ use chrono::TimeDelta; +use itertools::Itertools; use sqlx::sqlite::SqlitePool; use crate::{ channel::{repo::Provider as _, Channel}, clock::DateTime, - event::{broadcaster::Broadcaster, repo::Provider as _, types::ChannelEvent, Sequence}, + event::{broadcaster::Broadcaster, repo::Provider as _, Sequence}, }; pub struct Channels<'a> { @@ -27,10 +28,11 @@ impl<'a> Channels<'a> { .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; - self.events - .broadcast(&ChannelEvent::created(channel.clone())); + for event in channel.events() { + self.events.broadcast(event); + } - Ok(channel) + Ok(channel.snapshot()) } pub async fn all(&self, resume_point: Option) -> Result, InternalError> { @@ -38,6 +40,16 @@ impl<'a> Channels<'a> { let channels = tx.channels().all(resume_point).await?; tx.commit().await?; + let channels = channels + .into_iter() + .filter_map(|channel| { + channel + .events() + .filter(Sequence::up_to(resume_point)) + .collect() + }) + .collect(); + Ok(channels) } @@ -51,14 +63,21 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { let deleted = tx.sequence().next(relative_to).await?; - let event = tx.channels().delete(&channel, &deleted).await?; - events.push(event); + let channel = tx.channels().delete(&channel, &deleted).await?; + events.push( + channel + .events() + .filter(Sequence::start_from(deleted.sequence)), + ); } tx.commit().await?; - for event in events { - self.events.broadcast(&event); + for event in events + .into_iter() + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + { + self.events.broadcast(event); } Ok(()) diff --git a/src/channel/event.rs b/src/channel/event.rs new file mode 100644 index 0000000..9c54174 --- /dev/null +++ b/src/channel/event.rs @@ -0,0 +1,48 @@ +use super::Channel; +use crate::{ + channel, + event::{Instant, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + Created(Created), + Deleted(Deleted), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Created { + pub channel: Channel, +} + +impl From for Kind { + fn from(event: Created) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Deleted { + pub channel: channel::Id, +} + +impl From for Kind { + fn from(event: Deleted) -> Self { + Self::Deleted(event) + } +} diff --git a/src/channel/history.rs b/src/channel/history.rs new file mode 100644 index 0000000..3cc7d9d --- /dev/null +++ b/src/channel/history.rs @@ -0,0 +1,42 @@ +use super::{ + event::{Created, Deleted, 
Event}, + Channel, +}; +use crate::event::Instant; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct History { + pub channel: Channel, + pub created: Instant, + pub deleted: Option, +} + +impl History { + fn created(&self) -> Event { + Event { + instant: self.created, + kind: Created { + channel: self.channel.clone(), + } + .into(), + } + } + + fn deleted(&self) -> Option { + self.deleted.map(|instant| Event { + instant, + kind: Deleted { + channel: self.channel.id.clone(), + } + .into(), + }) + } + + pub fn events(&self) -> impl Iterator { + [self.created()].into_iter().chain(self.deleted()) + } + + pub fn snapshot(&self) -> Channel { + self.channel.clone() + } +} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 4baa7e3..eb8200b 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,16 +1,9 @@ -use crate::event::Instant; - pub mod app; +pub mod event; +mod history; mod id; pub mod repo; mod routes; +mod snapshot; -pub use self::{id::Id, routes::router}; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Channel { - pub id: Id, - pub name: String, - #[serde(skip)] - pub created: Instant, -} +pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs index c000b56..8bb761b 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -1,9 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ - channel::{Channel, Id}, + channel::{Channel, History, Id}, clock::DateTime, - event::{types, Instant, Sequence}, + event::{Instant, Sequence}, }; pub trait Provider { @@ -19,7 +19,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); impl<'c> Channels<'c> { - pub async fn create(&mut self, name: &str, created: &Instant) -> Result { + pub async fn create(&mut self, name: &str, created: &Instant) -> Result { let id = Id::generate(); let channel = sqlx::query!( r#" @@ -37,13 +37,16 @@ impl<'c> Channels<'c> { created.at, created.sequence, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_one(&mut *self.0) .await?; @@ -51,7 +54,7 @@ impl<'c> Channels<'c> { Ok(channel) } - pub async fn by_id(&mut self, channel: &Id) -> Result { + pub async fn by_id(&mut self, channel: &Id) -> Result { let channel = sqlx::query!( r#" select @@ -64,13 +67,16 @@ impl<'c> Channels<'c> { "#, channel, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_one(&mut *self.0) .await?; @@ -81,7 +87,7 @@ impl<'c> Channels<'c> { pub async fn all( &mut self, resume_point: Option, - ) -> Result, sqlx::Error> { + ) -> Result, sqlx::Error> { let channels = sqlx::query!( r#" select @@ -95,13 +101,16 @@ impl<'c> Channels<'c> { "#, resume_point, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_all(&mut *self.0) .await?; @@ -112,7 +121,7 @@ impl<'c> Channels<'c> { pub async fn replay( &mut self, resume_at: Option, - ) -> Result, sqlx::Error> { + ) -> Result, sqlx::Error> { let channels = 
sqlx::query!( r#" select @@ -125,13 +134,16 @@ impl<'c> Channels<'c> { "#, resume_at, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_all(&mut *self.0) .await?; @@ -141,35 +153,43 @@ impl<'c> Channels<'c> { pub async fn delete( &mut self, - channel: &Channel, + channel: &Id, deleted: &Instant, - ) -> Result { - let channel = channel.id.clone(); - sqlx::query_scalar!( + ) -> Result { + let channel = sqlx::query!( r#" delete from channel where id = $1 - returning 1 as "row: i64" + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" "#, channel, ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: Some(*deleted), + }) .fetch_one(&mut *self.0) .await?; - Ok(types::ChannelEvent { - instant: *deleted, - data: types::DeletedEvent { channel }.into(), - }) + Ok(channel) } - pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { - let channels = sqlx::query!( + pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { + let channels = sqlx::query_scalar!( r#" select - channel.id as "id: Id", - channel.name, - channel.created_at as "created_at: DateTime", - channel.created_sequence as "created_sequence: Sequence" + channel.id as "id: Id" from channel left join message where created_at < $1 @@ -177,14 +197,6 @@ impl<'c> Channels<'c> { "#, expired_at, ) - .map(|row| Channel { - id: row.id, - name: row.name, - created: Instant { - at: row.created_at, - sequence: row.created_sequence, - }, - }) .fetch_all(&mut *self.0) .await?; diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 5d8b61e..5bb1ee9 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -13,8 +13,9 @@ use crate::{ channel::{self, Channel}, clock::RequestedAt, error::Internal, - event::{app::EventsError, Sequence}, + event::Sequence, login::Login, + message::app::Error as MessageError, }; #[cfg(test)] @@ -99,8 +100,8 @@ async fn on_send( login: Login, Json(request): Json, ) -> Result { - app.events() - .send(&login, &channel, &request.message, &sent_at) + app.messages() + .send(&channel, &login, &sent_at, &request.message) .await // Could impl `From` here, but it's more code and this is used once. 
.map_err(ErrorResponse)?; @@ -109,13 +110,13 @@ async fn on_send( } #[derive(Debug)] -struct ErrorResponse(EventsError); +struct ErrorResponse(MessageError); impl IntoResponse for ErrorResponse { fn into_response(self) -> Response { let Self(error) = self; match error { - not_found @ EventsError::ChannelNotFound(_) => { + not_found @ MessageError::ChannelNotFound(_) => { (StatusCode::NOT_FOUND, not_found.to_string()).into_response() } other => Internal::from(other).into_response(), diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 9988932..5733c9e 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -3,7 +3,7 @@ use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - event::types, + event, test::fixtures::{self, future::Immediately as _}, }; @@ -50,8 +50,8 @@ async fn new_channel() { .expect("creation event published"); assert!(matches!( - event.data, - types::ChannelEventData::Created(event) + event.kind, + event::Kind::ChannelCreated(event) if event.channel == response_channel )); } diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 33ec3b7..1027b29 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -4,7 +4,8 @@ use futures::stream::StreamExt; use crate::{ channel, channel::routes, - event::{app, types}, + event, + message::app, test::fixtures::{self, future::Immediately as _}, }; @@ -54,10 +55,10 @@ async fn messages_in_order() { for ((sent_at, message), event) in requests.into_iter().zip(events) { assert_eq!(*sent_at, event.instant.at); assert!(matches!( - event.data, - types::ChannelEventData::Message(event_message) - if event_message.sender == sender - && event_message.message.body == message + event.kind, + event::Kind::MessageSent(event) + if event.message.sender == sender + && event.message.body == message )); } } @@ -90,6 +91,6 @@ async fn nonexistent_channel() { assert!(matches!( error, - app::EventsError::ChannelNotFound(error_channel) if channel == error_channel + app::Error::ChannelNotFound(error_channel) if channel == error_channel )); } diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs new file mode 100644 index 0000000..6462f25 --- /dev/null +++ b/src/channel/snapshot.rs @@ -0,0 +1,38 @@ +use super::{ + event::{Created, Event, Kind}, + Id, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Channel { + pub id: Id, + pub name: String, +} + +impl Channel { + fn apply(state: Option, event: Event) -> Option { + match (state, event.kind) { + (None, Kind::Created(event)) => Some(event.into()), + (Some(channel), Kind::Deleted(event)) if channel.id == event.channel => None, + (state, event) => panic!("invalid channel event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator for Option { + fn from_iter>(events: I) -> Self { + events.into_iter().fold(None, Channel::apply) + } +} + +impl From<&Created> for Channel { + fn from(event: &Created) -> Self { + event.channel.clone() + } +} + +impl From for Channel { + fn from(event: Created) -> Self { + event.channel + } +} diff --git a/src/event/app.rs b/src/event/app.rs index 5e9e79a..e58bea9 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -1,22 +1,15 @@ -use chrono::TimeDelta; use futures::{ future, stream::{self, StreamExt as _}, Stream, }; +use itertools::Itertools as _; use sqlx::sqlite::SqlitePool; -use super::{ - broadcaster::Broadcaster, - repo::message::Provider as _, - types::{self, 
ChannelEvent}, -}; +use super::{broadcaster::Broadcaster, Event, Sequence, Sequenced}; use crate::{ channel::{self, repo::Provider as _}, - clock::DateTime, - db::NotFound as _, - event::{repo::Provider as _, Sequence}, - login::Login, + message::{self, repo::Provider as _}, }; pub struct Events<'a> { @@ -29,111 +22,52 @@ impl<'a> Events<'a> { Self { db, events } } - pub async fn send( - &self, - login: &Login, - channel: &channel::Id, - body: &str, - sent_at: &DateTime, - ) -> Result { - let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent = tx.sequence().next(sent_at).await?; - let event = tx - .message_events() - .create(login, &channel, &sent, body) - .await?; - tx.commit().await?; - - self.events.broadcast(&event); - Ok(event) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 90 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(90); - - let mut tx = self.db.begin().await?; - let expired = tx.message_events().expired(&expire_at).await?; - - let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { - let deleted = tx.sequence().next(relative_to).await?; - let event = tx - .message_events() - .delete(&channel, &message, &deleted) - .await?; - events.push(event); - } - - tx.commit().await?; - - for event in events { - self.events.broadcast(&event); - } - - Ok(()) - } - pub async fn subscribe( &self, resume_at: Option, - ) -> Result + std::fmt::Debug, sqlx::Error> { + ) -> Result + std::fmt::Debug, sqlx::Error> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. let live_messages = self.events.subscribe(); let mut tx = self.db.begin().await?; - let channels = tx.channels().replay(resume_at).await?; + let channels = tx.channels().replay(resume_at).await?; let channel_events = channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| { - resume_at.map_or(true, |resume_at| Sequence::from(event) > resume_at) - }); - - let message_events = tx.message_events().replay(resume_at).await?; - - let mut replay_events = channel_events - .into_iter() - .chain(message_events.into_iter()) + .iter() + .map(channel::History::events) + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let messages = tx.messages().replay(resume_at).await?; + let message_events = messages + .iter() + .map(message::History::events) + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let replay_events = channel_events + .merge_by(message_events, |a, b| { + a.instant.sequence < b.instant.sequence + }) .collect::>(); - replay_events.sort_by_key(|event| Sequence::from(event)); - let resume_live_at = replay_events.last().map(Sequence::from); + let resume_live_at = replay_events.last().map(Sequenced::sequence); let replay = stream::iter(replay_events); - // no skip_expired or resume transforms for stored_messages, as it's - // constructed not to contain messages meeting either criterion. - // - // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; - // * resume is redundant with the resume_at argument to - // `tx.broadcasts().replay(…)`. 
let live_messages = live_messages // Filtering on the broadcast resume point filters out messages // before resume_at, and filters out messages duplicated from - // stored_messages. + // `replay_events`. .filter(Self::resume(resume_live_at)); Ok(replay.chain(live_messages)) } - fn resume( - resume_at: Option, - ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready { - move |event| future::ready(resume_at < Some(Sequence::from(event))) + fn resume(resume_at: Option) -> impl for<'m> FnMut(&'m Event) -> future::Ready { + let filter = Sequence::after(resume_at); + move |event| future::ready(filter(event)) } } - -#[derive(Debug, thiserror::Error)] -pub enum EventsError { - #[error("channel {0} not found")] - ChannelNotFound(channel::Id), - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs index 92f631f..de2513a 100644 --- a/src/event/broadcaster.rs +++ b/src/event/broadcaster.rs @@ -1,3 +1,3 @@ -use crate::{broadcast, event::types}; +use crate::broadcast; -pub type Broadcaster = broadcast::Broadcaster; +pub type Broadcaster = broadcast::Broadcaster; diff --git a/src/event/mod.rs b/src/event/mod.rs index c982d3a..1503b77 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -1,18 +1,75 @@ +use crate::{channel, message}; + pub mod app; pub mod broadcaster; mod extract; pub mod repo; mod routes; mod sequence; -pub mod types; -use crate::clock::DateTime; +pub use self::{ + routes::router, + sequence::{Instant, Sequence, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +impl From for Event { + fn from(event: channel::Event) -> Self { + Self { + instant: event.instant, + kind: event.kind.into(), + } + } +} + +impl From for Event { + fn from(event: message::Event) -> Self { + Self { + instant: event.instant, + kind: event.kind.into(), + } + } +} -pub use self::{routes::router, sequence::Sequence}; +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + #[serde(rename = "created")] + ChannelCreated(channel::event::Created), + #[serde(rename = "message")] + MessageSent(message::event::Sent), + MessageDeleted(message::event::Deleted), + #[serde(rename = "deleted")] + ChannelDeleted(channel::event::Deleted), +} + +impl From for Kind { + fn from(kind: channel::event::Kind) -> Self { + match kind { + channel::event::Kind::Created(created) => Self::ChannelCreated(created), + channel::event::Kind::Deleted(deleted) => Self::ChannelDeleted(deleted), + } + } +} -#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Instant { - pub at: DateTime, - #[serde(skip)] - pub sequence: Sequence, +impl From for Kind { + fn from(kind: message::event::Kind) -> Self { + match kind { + message::event::Kind::Sent(created) => Self::MessageSent(created), + message::event::Kind::Deleted(deleted) => Self::MessageDeleted(deleted), + } + } } diff --git a/src/event/repo.rs b/src/event/repo.rs new file mode 100644 index 0000000..40d6a53 --- /dev/null +++ b/src/event/repo.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + clock::DateTime, + event::{Instant, Sequence}, +}; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn 
sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self, at: &DateTime) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(Instant { + at: *at, + sequence: next, + }) + } + + pub async fn current(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs deleted file mode 100644 index f29c8a4..0000000 --- a/src/event/repo/message.rs +++ /dev/null @@ -1,196 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel::{self, Channel}, - clock::DateTime, - event::{types, Instant, Sequence}, - login::{self, Login}, - message::{self, Message}, -}; - -pub trait Provider { - fn message_events(&mut self) -> Events; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn message_events(&mut self) -> Events { - Events(self) - } -} - -pub struct Events<'t>(&'t mut SqliteConnection); - -impl<'c> Events<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - sent: &Instant, - body: &str, - ) -> Result { - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sender as "sender: login::Id", - sent_at as "sent_at: DateTime", - sent_sequence as "sent_sequence: Sequence", - body - "#, - id, - channel.id, - sender.id, - sent.at, - sent.sequence, - body, - ) - .map(|row| types::ChannelEvent { - instant: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, - data: types::MessageEvent { - channel: channel.clone(), - sender: sender.clone(), - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn delete( - &mut self, - channel: &Channel, - message: &message::Id, - deleted: &Instant, - ) -> Result { - sqlx::query_scalar!( - r#" - delete from message - where id = $1 - returning 1 as "row: i64" - "#, - message, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - instant: Instant { - at: deleted.at, - sequence: deleted.sequence, - }, - data: types::MessageDeletedEvent { - channel: channel.clone(), - message: message.clone(), - } - .into(), - }) - } - - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - message.id as "message: message::Id" - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where sent_at < $1 - "#, - expire_at, - ) - .map(|row| { - ( - Channel { - id: row.channel_id, - name: row.channel_name, - created: Instant { - at: row.channel_created_at, - sequence: row.channel_created_sequence, - }, - }, - row.message, - ) - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } - - pub async fn replay( - &mut self, - 
resume_at: Option, - ) -> Result, sqlx::Error> { - let events = sqlx::query!( - r#" - select - message.id as "id: message::Id", - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - sender.id as "sender_id: login::Id", - sender.name as sender_name, - message.sent_at as "sent_at: DateTime", - message.sent_sequence as "sent_sequence: Sequence", - message.body - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where coalesce(message.sent_sequence > $1, true) - order by sent_sequence asc - "#, - resume_at, - ) - .map(|row| types::ChannelEvent { - instant: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, - data: types::MessageEvent { - channel: Channel { - id: row.channel_id, - name: row.channel_name, - created: Instant { - at: row.channel_created_at, - sequence: row.channel_created_sequence, - }, - }, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(events) - } -} diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs deleted file mode 100644 index cee840c..0000000 --- a/src/event/repo/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod message; -mod sequence; - -pub use self::sequence::Provider; diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs deleted file mode 100644 index 40d6a53..0000000 --- a/src/event/repo/sequence.rs +++ /dev/null @@ -1,50 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - event::{Instant, Sequence}, -}; - -pub trait Provider { - fn sequence(&mut self) -> Sequences; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn sequence(&mut self) -> Sequences { - Sequences(self) - } -} - -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self, at: &DateTime) -> Result { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(Instant { - at: *at, - sequence: next, - }) - } - - pub async fn current(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - select last_value as "last_value: Sequence" - from event_sequence - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } -} diff --git a/src/event/routes.rs b/src/event/routes.rs index c87bfb2..5b9c7e3 100644 --- a/src/event/routes.rs +++ b/src/event/routes.rs @@ -10,11 +10,11 @@ use axum::{ use axum_extra::extract::Query; use futures::stream::{Stream, StreamExt as _}; -use super::{extract::LastEventId, types}; +use super::{extract::LastEventId, Event}; use crate::{ app::App, error::{Internal, Unauthorized}, - event::Sequence, + event::{Sequence, Sequenced as _}, token::{app::ValidateError, extract::Identity}, }; @@ -35,7 +35,7 @@ async fn events( identity: Identity, last_event_id: Option>, Query(query): Query, -) -> Result + std::fmt::Debug>, EventsError> { +) -> Result + std::fmt::Debug>, EventsError> { let resume_at = last_event_id .map(LastEventId::into_inner) .or(query.resume_point); @@ -51,7 +51,7 @@ struct Events(S); impl IntoResponse for Events where - S: Stream + Send + 'static, + S: Stream + Send + 'static, { fn into_response(self) -> Response { let Self(stream) 
= self; @@ -62,11 +62,11 @@ where } } -impl TryFrom for sse::Event { +impl TryFrom for sse::Event { type Error = serde_json::Error; - fn try_from(event: types::ChannelEvent) -> Result { - let id = serde_json::to_string(&Sequence::from(&event))?; + fn try_from(event: Event) -> Result { + let id = serde_json::to_string(&event.sequence())?; let data = serde_json::to_string_pretty(&event)?; let event = Self::default().id(id).data(data); diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs index 68b55cc..ba9953e 100644 --- a/src/event/routes/test.rs +++ b/src/event/routes/test.rs @@ -6,7 +6,7 @@ use futures::{ }; use crate::{ - event::{routes, Sequence}, + event::{routes, Sequenced as _}, test::fixtures::{self, future::Immediately as _}, }; @@ -17,7 +17,7 @@ async fn includes_historical_message() { let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app).await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; // Call the endpoint @@ -36,7 +36,7 @@ async fn includes_historical_message() { .await .expect("delivered stored message"); - assert_eq!(message, event); + assert!(fixtures::event::message_sent(&event, &message)); } #[tokio::test] @@ -58,7 +58,7 @@ async fn includes_live_message() { // Verify the semantics let sender = fixtures::login::create(&app).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; let event = events .filter(fixtures::filter::messages()) @@ -67,7 +67,7 @@ async fn includes_live_message() { .await .expect("delivered live message"); - assert_eq!(message, event); + assert!(fixtures::event::message_sent(&event, &message)); } #[tokio::test] @@ -87,7 +87,7 @@ async fn includes_multiple_channels() { let app = app.clone(); let sender = sender.clone(); let channel = channel.clone(); - async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } + async move { fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await } }) .collect::>() .await; @@ -110,7 +110,9 @@ async fn includes_multiple_channels() { .await; for message in &messages { - assert!(events.iter().any(|event| { event == message })); + assert!(events + .iter() + .any(|event| fixtures::event::message_sent(event, message))); } } @@ -123,9 +125,9 @@ async fn sequential_messages() { let sender = fixtures::login::create(&app).await; let messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; // Call the endpoint @@ -138,7 +140,13 @@ async fn sequential_messages() { // Verify the structure of the response. 
- let mut events = events.filter(|event| future::ready(messages.contains(event))); + let mut events = events.filter(|event| { + future::ready( + messages + .iter() + .any(|message| fixtures::event::message_sent(event, message)), + ) + }); // Verify delivery in order for message in &messages { @@ -148,7 +156,7 @@ async fn sequential_messages() { .await .expect("undelivered messages remaining"); - assert_eq!(message, &event); + assert!(fixtures::event::message_sent(&event, message)); } } @@ -160,11 +168,11 @@ async fn resumes_from() { let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app).await; - let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let initial_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; let later_messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; // Call the endpoint @@ -190,9 +198,9 @@ async fn resumes_from() { .await .expect("delivered events"); - assert_eq!(initial_message, event); + assert!(fixtures::event::message_sent(&event, &initial_message)); - Sequence::from(&event) + event.sequence() }; // Resume after disconnect @@ -214,7 +222,9 @@ async fn resumes_from() { .await; for message in &later_messages { - assert!(events.iter().any(|event| event == message)); + assert!(events + .iter() + .any(|event| fixtures::event::message_sent(event, message))); } } @@ -249,8 +259,8 @@ async fn serial_resume() { let resume_at = { let initial_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await, ]; // First subscription @@ -271,12 +281,14 @@ async fn serial_resume() { .await; for message in &initial_messages { - assert!(events.iter().any(|event| event == message)); + assert!(events + .iter() + .any(|event| fixtures::event::message_sent(event, message))); } let event = events.last().expect("this vec is non-empty"); - Sequence::from(event) + event.sequence() }; // Resume after disconnect @@ -285,8 +297,8 @@ async fn serial_resume() { // Note that channel_b does not appear here. The buggy behaviour // would be masked if channel_b happened to send a new message // into the resumed event stream. - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, ]; // Second subscription @@ -307,12 +319,14 @@ async fn serial_resume() { .await; for message in &resume_messages { - assert!(events.iter().any(|event| event == message)); + assert!(events + .iter() + .any(|event| fixtures::event::message_sent(event, message))); } let event = events.last().expect("this vec is non-empty"); - Sequence::from(event) + event.sequence() }; // Resume after disconnect a second time @@ -321,8 +335,8 @@ async fn serial_resume() { // problem. 
The resume point should be before both of these messages, but
         // after _all_ prior messages.
         let final_messages = [
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await,
         ];
 
         // Third subscription
@@ -345,7 +359,9 @@ async fn serial_resume() {
         // This set of messages, in particular, _should not_ include any prior
         // messages from `initial_messages` or `resume_messages`.
         for message in &final_messages {
-            assert!(events.iter().any(|event| event == message));
+            assert!(events
+                .iter()
+                .any(|event| fixtures::event::message_sent(event, message)));
         }
     };
 }
@@ -378,13 +394,17 @@ async fn terminates_on_token_expiry() {
 
     // These should not be delivered.
     let messages = [
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
    ];
 
    assert!(events
-        .filter(|event| future::ready(messages.contains(event)))
+        .filter(|event| future::ready(
+            messages
+                .iter()
+                .any(|message| fixtures::event::message_sent(event, message))
+        ))
        .next()
        .immediately()
        .await
@@ -425,13 +445,17 @@ async fn terminates_on_logout() {
 
    // These should not be delivered.
    let messages = [
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
    ];
 
    assert!(events
-        .filter(|event| future::ready(messages.contains(event)))
+        .filter(|event| future::ready(
+            messages
+                .iter()
+                .any(|message| fixtures::event::message_sent(event, message))
+        ))
        .next()
        .immediately()
        .await
diff --git a/src/event/sequence.rs b/src/event/sequence.rs
index 9ebddd7..c566156 100644
--- a/src/event/sequence.rs
+++ b/src/event/sequence.rs
@@ -1,5 +1,20 @@
 use std::fmt;
 
+use crate::clock::DateTime;
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Instant {
+    pub at: DateTime,
+    #[serde(skip)]
+    pub sequence: Sequence,
+}
+
+impl From<Instant> for Sequence {
+    fn from(instant: Instant) -> Self {
+        instant.sequence
+    }
+}
+
 #[derive(
     Clone,
     Copy,
@@ -22,3 +37,47 @@ impl fmt::Display for Sequence {
         value.fmt(f)
     }
 }
+
+impl Sequence {
+    pub fn up_to<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point.map_or(true, |resume_point| event.sequence() <= resume_point)
+    }
+
+    pub fn after<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point < Some(event.sequence())
+    }
+
+    pub fn start_from<E>(resume_point: Self) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point <= event.sequence()
+    }
+}
+
+pub trait Sequenced {
+    fn instant(&self) -> Instant;
+
+    fn
sequence(&self) -> Sequence { + self.instant().into() + } +} + +impl Sequenced for &E +where + E: Sequenced, +{ + fn instant(&self) -> Instant { + (*self).instant() + } + + fn sequence(&self) -> Sequence { + (*self).sequence() + } +} diff --git a/src/event/types.rs b/src/event/types.rs deleted file mode 100644 index 2324dc1..0000000 --- a/src/event/types.rs +++ /dev/null @@ -1,85 +0,0 @@ -use crate::{ - channel::{self, Channel}, - event::{Instant, Sequence}, - login::Login, - message::{self, Message}, -}; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct ChannelEvent { - #[serde(flatten)] - pub instant: Instant, - #[serde(flatten)] - pub data: ChannelEventData, -} - -impl ChannelEvent { - pub fn created(channel: Channel) -> Self { - Self { - instant: channel.created, - data: CreatedEvent { channel }.into(), - } - } -} - -impl<'c> From<&'c ChannelEvent> for Sequence { - fn from(event: &'c ChannelEvent) -> Self { - event.instant.sequence - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ChannelEventData { - Created(CreatedEvent), - Message(MessageEvent), - MessageDeleted(MessageDeletedEvent), - Deleted(DeletedEvent), -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct CreatedEvent { - pub channel: Channel, -} - -impl From for ChannelEventData { - fn from(event: CreatedEvent) -> Self { - Self::Created(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageEvent { - pub channel: Channel, - pub sender: Login, - pub message: Message, -} - -impl From for ChannelEventData { - fn from(event: MessageEvent) -> Self { - Self::Message(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageDeletedEvent { - pub channel: Channel, - pub message: message::Id, -} - -impl From for ChannelEventData { - fn from(event: MessageDeletedEvent) -> Self { - Self::MessageDeleted(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct DeletedEvent { - pub channel: channel::Id, -} - -impl From for ChannelEventData { - fn from(event: DeletedEvent) -> Self { - Self::Deleted(event) - } -} diff --git a/src/expire.rs b/src/expire.rs index a8eb8ad..e50bcb4 100644 --- a/src/expire.rs +++ b/src/expire.rs @@ -14,7 +14,7 @@ pub async fn middleware( next: Next, ) -> Result { app.tokens().expire(&expired_at).await?; - app.events().expire(&expired_at).await?; + app.messages().expire(&expired_at).await?; app.channels().expire(&expired_at).await?; Ok(next.run(req).await) } diff --git a/src/message/app.rs b/src/message/app.rs new file mode 100644 index 0000000..51f772e --- /dev/null +++ b/src/message/app.rs @@ -0,0 +1,88 @@ +use chrono::TimeDelta; +use itertools::Itertools; +use sqlx::sqlite::SqlitePool; + +use super::{repo::Provider as _, Message}; +use crate::{ + channel::{self, repo::Provider as _}, + clock::DateTime, + db::NotFound as _, + event::{broadcaster::Broadcaster, repo::Provider as _, Sequence}, + login::Login, +}; + +pub struct Messages<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Messages<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn send( + &self, + channel: &channel::Id, + sender: &Login, + sent_at: &DateTime, + body: &str, + ) -> Result { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| Error::ChannelNotFound(channel.clone()))?; 
+ let sent = tx.sequence().next(sent_at).await?; + let message = tx + .messages() + .create(&channel.snapshot(), sender, &sent, body) + .await?; + tx.commit().await?; + + for event in message.events() { + self.events.broadcast(event); + } + + Ok(message.snapshot()) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + let expired = tx.messages().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for (channel, message) in expired { + let deleted = tx.sequence().next(relative_to).await?; + let message = tx.messages().delete(&channel, &message, &deleted).await?; + events.push( + message + .events() + .filter(Sequence::start_from(deleted.sequence)), + ); + } + + tx.commit().await?; + + for event in events + .into_iter() + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + { + self.events.broadcast(event); + } + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} diff --git a/src/message/event.rs b/src/message/event.rs new file mode 100644 index 0000000..bcc2238 --- /dev/null +++ b/src/message/event.rs @@ -0,0 +1,50 @@ +use super::{snapshot::Message, Id}; +use crate::{ + channel::Channel, + event::{Instant, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + Sent(Sent), + Deleted(Deleted), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Sent { + #[serde(flatten)] + pub message: Message, +} + +impl From for Kind { + fn from(event: Sent) -> Self { + Self::Sent(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Deleted { + pub channel: Channel, + pub message: Id, +} + +impl From for Kind { + fn from(event: Deleted) -> Self { + Self::Deleted(event) + } +} diff --git a/src/message/history.rs b/src/message/history.rs new file mode 100644 index 0000000..5aca47e --- /dev/null +++ b/src/message/history.rs @@ -0,0 +1,43 @@ +use super::{ + event::{Deleted, Event, Sent}, + Message, +}; +use crate::event::Instant; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct History { + pub message: Message, + pub sent: Instant, + pub deleted: Option, +} + +impl History { + fn sent(&self) -> Event { + Event { + instant: self.sent, + kind: Sent { + message: self.message.clone(), + } + .into(), + } + } + + fn deleted(&self) -> Option { + self.deleted.map(|instant| Event { + instant, + kind: Deleted { + channel: self.message.channel.clone(), + message: self.message.id.clone(), + } + .into(), + }) + } + + pub fn events(&self) -> impl Iterator { + [self.sent()].into_iter().chain(self.deleted()) + } + + pub fn snapshot(&self) -> Message { + self.message.clone() + } +} diff --git a/src/message/mod.rs b/src/message/mod.rs index 9a9bf14..52d56c1 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,9 +1,8 @@ +pub mod app; +pub mod event; +mod history; mod id; +pub mod repo; +mod snapshot; -pub use self::id::Id; - -#[derive(Clone, Debug, 
Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: Id, - pub body: String, -} +pub use self::{event::Event, history::History, id::Id, snapshot::Message}; diff --git a/src/message/repo.rs b/src/message/repo.rs new file mode 100644 index 0000000..3b2b8f7 --- /dev/null +++ b/src/message/repo.rs @@ -0,0 +1,214 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use super::{snapshot::Message, History, Id}; +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::{Instant, Sequence}, + login::{self, Login}, +}; + +pub trait Provider { + fn messages(&mut self) -> Messages; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn messages(&mut self) -> Messages { + Messages(self) + } +} + +pub struct Messages<'t>(&'t mut SqliteConnection); + +impl<'c> Messages<'c> { + pub async fn create( + &mut self, + channel: &Channel, + sender: &Login, + sent: &Instant, + body: &str, + ) -> Result { + let id = Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: Id", + body + "#, + id, + channel.id, + sender.id, + sent.at, + sent.sequence, + body, + ) + .map(|row| History { + message: Message { + channel: channel.clone(), + sender: sender.clone(), + id: row.id, + body: row.body, + }, + sent: *sent, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + async fn by_id(&mut self, channel: &Channel, message: &Id) -> Result { + let message = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where message.id = $1 + and message.channel = $2 + "#, + message, + channel.id, + ) + .map(|row| History { + message: Message { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn delete( + &mut self, + channel: &Channel, + message: &Id, + deleted: &Instant, + ) -> Result { + let history = self.by_id(channel, message).await?; + + sqlx::query_scalar!( + r#" + delete from message + where + id = $1 + returning 1 as "deleted: i64" + "#, + history.message.id, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(History { + deleted: Some(*deleted), + ..history + }) + } + + pub async fn expired( + &mut self, + expire_at: &DateTime, + ) -> Result, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + message.id as "message: Id" + from message + join channel on message.channel = channel.id + where sent_at < $1 + "#, + expire_at, + ) + .map(|row| { + ( + Channel { + id: row.channel_id, + name: row.channel_name, + }, + row.message, + ) + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as 
"channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + "#, + resume_at, + ) + .map(|row| History { + message: Message { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs new file mode 100644 index 0000000..3adccbe --- /dev/null +++ b/src/message/snapshot.rs @@ -0,0 +1,74 @@ +use super::{ + event::{Event, Kind, Sent}, + Id, +}; +use crate::{channel::Channel, login::Login}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(into = "self::serialize::Message")] +pub struct Message { + pub channel: Channel, + pub sender: Login, + pub id: Id, + pub body: String, +} + +mod serialize { + use crate::{channel::Channel, login::Login, message::Id}; + + #[derive(serde::Serialize)] + pub struct Message { + channel: Channel, + sender: Login, + #[allow(clippy::struct_field_names)] + // Deliberately redundant with the module path; this produces a specific serialization. + message: MessageData, + } + + #[derive(serde::Serialize)] + pub struct MessageData { + id: Id, + body: String, + } + + impl From for Message { + fn from(message: super::Message) -> Self { + Self { + channel: message.channel, + sender: message.sender, + message: MessageData { + id: message.id, + body: message.body, + }, + } + } + } +} + +impl Message { + fn apply(state: Option, event: Event) -> Option { + match (state, event.kind) { + (None, Kind::Sent(event)) => Some(event.into()), + (Some(message), Kind::Deleted(event)) if message.id == event.message => None, + (state, event) => panic!("invalid message event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator for Option { + fn from_iter>(events: I) -> Self { + events.into_iter().fold(None, Message::apply) + } +} + +impl From<&Sent> for Message { + fn from(event: &Sent) -> Self { + event.message.clone() + } +} + +impl From for Message { + fn from(event: Sent) -> Self { + event.message + } +} diff --git a/src/test/fixtures/event.rs b/src/test/fixtures/event.rs new file mode 100644 index 0000000..09f0490 --- /dev/null +++ b/src/test/fixtures/event.rs @@ -0,0 +1,11 @@ +use crate::{ + event::{Event, Kind}, + message::Message, +}; + +pub fn message_sent(event: &Event, message: &Message) -> bool { + matches!( + &event.kind, + Kind::MessageSent(event) if message == &event.into() + ) +} diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index d1939a5..6e62aea 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -1,11 +1,11 @@ use futures::future; -use crate::event::types; +use crate::event::{Event, Kind}; -pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready { - |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_))) +pub fn messages() -> impl FnMut(&Event) -> future::Ready { + |event| future::ready(matches!(event.kind, Kind::MessageSent(_))) } -pub fn created() -> impl FnMut(&types::ChannelEvent) -> future::Ready { - 
|event| future::ready(matches!(event.data, types::ChannelEventData::Created(_))) +pub fn created() -> impl FnMut(&Event) -> future::Ready { + |event| future::ready(matches!(event.kind, Kind::ChannelCreated(_))) } diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index fd50887..381b10b 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -1,17 +1,12 @@ use faker_rand::lorem::Paragraphs; -use crate::{app::App, channel::Channel, clock::RequestedAt, event::types, login::Login}; +use crate::{app::App, channel::Channel, clock::RequestedAt, login::Login, message::Message}; -pub async fn send( - app: &App, - login: &Login, - channel: &Channel, - sent_at: &RequestedAt, -) -> types::ChannelEvent { +pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &RequestedAt) -> Message { let body = propose(); - app.events() - .send(login, &channel.id, &body, sent_at) + app.messages() + .send(&channel.id, login, sent_at, &body) .await .expect("should succeed if the channel exists") } diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index 76467ab..c5efa9b 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -3,6 +3,7 @@ use chrono::{TimeDelta, Utc}; use crate::{app::App, clock::RequestedAt, db}; pub mod channel; +pub mod event; pub mod filter; pub mod future; pub mod identity; diff --git a/src/token/app.rs b/src/token/app.rs index 030ec69..5c4fcd5 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -127,7 +127,7 @@ impl<'a> Tokens<'a> { tx.commit().await?; for event in tokens.into_iter().map(event::TokenRevoked::from) { - self.tokens.broadcast(&event); + self.tokens.broadcast(event); } Ok(()) @@ -139,7 +139,7 @@ impl<'a> Tokens<'a> { tx.commit().await?; self.tokens - .broadcast(&event::TokenRevoked::from(token.clone())); + .broadcast(event::TokenRevoked::from(token.clone())); Ok(()) } -- cgit v1.2.3 From 0a5599c60d20ccc2223779eeba5dc91a95ea0fe5 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Thu, 3 Oct 2024 20:17:07 -0400 Subject: Add endpoints for deleting channels and messages. It is deliberate that the expire() functions do not use them. To avoid races, the transactions must be committed before events get sent, in both cases, which makes them structurally pretty different. 
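To make the ordering concrete, here is a minimal sketch of the commit-then-broadcast shape both paths share. Illustrative only: `db` and `events` stand in for the application's pool and `Broadcaster` fields, a `tokio` broadcast channel and a `String` payload replace the real event plumbing, and the message ID is the hypothetical `M1312acab` from docs/api.md.

```rust
use sqlx::sqlite::SqlitePool;
use tokio::sync::broadcast;

// Sketch of the race-avoidance ordering described above; not the crate's
// actual types.
async fn delete_then_broadcast(
    db: &SqlitePool,
    events: &broadcast::Sender<String>,
) -> Result<(), sqlx::Error> {
    let mut tx = db.begin().await?;
    sqlx::query("delete from message where id = $1")
        .bind("M1312acab") // hypothetical ID, borrowed from docs/api.md
        .execute(&mut *tx)
        .await?;
    // Commit first: once this returns, the deletion is durable.
    tx.commit().await?;
    // Broadcast second: a send error (no subscribers) is safe to ignore, and
    // the event can never describe state that a rollback would undo.
    let _ = events.send("message_deleted".into());
    Ok(())
}
```

Broadcasting inside the transaction would invert this: a subscriber could act on an event for a change that subsequently rolls back, or miss the event entirely while racing the commit.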
--- ...4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json | 20 ++++++ ...0b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json | 32 ---------- ...9cedc6bee1750d28a6176980ed7040b8a3301fc7e5.json | 62 +++++++++++++++++++ ...99e837106c799e84015425286b79f42e4001d8a4c7.json | 62 ------------------- ...ad2d2dec42949522f182a61bfb249f13ee78564179.json | 20 ++++++ docs/api.md | 28 +++++++++ src/channel/app.rs | 72 ++++++++++++++++++---- src/channel/routes.rs | 61 ++++++++++++------ src/channel/routes/test/on_send.rs | 6 +- src/cli.rs | 13 ++-- src/event/app.rs | 1 + src/event/broadcaster.rs | 2 +- src/message/app.rs | 61 +++++++++++++----- src/message/mod.rs | 3 +- src/message/repo.rs | 46 +++++++------- src/message/routes.rs | 46 ++++++++++++++ 16 files changed, 363 insertions(+), 172 deletions(-) create mode 100644 .sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json delete mode 100644 .sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json create mode 100644 .sqlx/query-6fc4be85527af518da17c49cedc6bee1750d28a6176980ed7040b8a3301fc7e5.json delete mode 100644 .sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json create mode 100644 .sqlx/query-f3a338b9e4a65856decd79ad2d2dec42949522f182a61bfb249f13ee78564179.json create mode 100644 src/message/routes.rs diff --git a/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json b/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json new file mode 100644 index 0000000..ee0f235 --- /dev/null +++ b/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select\n message.id as \"id: Id\"\n from message\n join channel on message.channel = channel.id\n where channel.id = $1\n order by message.sent_sequence\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80" +} diff --git a/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json b/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json deleted file mode 100644 index fb5f94b..0000000 --- a/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n message.id as \"message: Id\"\n from message\n join channel on message.channel = channel.id\n where sent_at < $1\n ", - "describe": { - "columns": [ - { - "name": "channel_id: channel::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "message: Id", - "ordinal": 2, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc" -} diff --git a/.sqlx/query-6fc4be85527af518da17c49cedc6bee1750d28a6176980ed7040b8a3301fc7e5.json b/.sqlx/query-6fc4be85527af518da17c49cedc6bee1750d28a6176980ed7040b8a3301fc7e5.json new file mode 100644 index 0000000..257e1f6 --- /dev/null +++ b/.sqlx/query-6fc4be85527af518da17c49cedc6bee1750d28a6176980ed7040b8a3301fc7e5.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n 
channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where message.id = $1\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "6fc4be85527af518da17c49cedc6bee1750d28a6176980ed7040b8a3301fc7e5" +} diff --git a/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json b/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json deleted file mode 100644 index 288a657..0000000 --- a/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where message.id = $1\n and message.channel = $2\n ", - "describe": { - "columns": [ - { - "name": "channel_id: channel::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "sender_id: login::Id", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "sender_name", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "id: Id", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "body", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "sent_sequence: Sequence", - "ordinal": 7, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 2 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false - ] - }, - "hash": "e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7" -} diff --git a/.sqlx/query-f3a338b9e4a65856decd79ad2d2dec42949522f182a61bfb249f13ee78564179.json b/.sqlx/query-f3a338b9e4a65856decd79ad2d2dec42949522f182a61bfb249f13ee78564179.json new file mode 100644 index 0000000..92a64a3 --- /dev/null +++ b/.sqlx/query-f3a338b9e4a65856decd79ad2d2dec42949522f182a61bfb249f13ee78564179.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select\n id as \"message: Id\"\n from message\n where sent_at < $1\n ", + "describe": { + "columns": [ + { + "name": "message: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + 
"parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "f3a338b9e4a65856decd79ad2d2dec42949522f182a61bfb249f13ee78564179" +} diff --git a/docs/api.md b/docs/api.md index 5adf28d..ef211bc 100644 --- a/docs/api.md +++ b/docs/api.md @@ -151,6 +151,34 @@ Once the message is accepted, this will return a 202 Accepted response. The mess If the channel ID is not valid, this will return a 404 Not Found response. +### `DELETE /api/channels/:channel` + +Deletes a channel (and all messages in it). + +The `:channel` placeholder must be a channel ID, as returned by `GET /api/channels` or `POST /api/channels`. + +#### On success + +This will return a 202 Accepted response on success, and delete the channel. + +#### Invalid channel ID + +If the channel ID is not valid, this will return a 404 Not Found response. + +### `DELETE /api/messages/:message` + +Deletes a message. + +The `:message` placeholder must be a message ID, as returned from the event stream or from a list of messages. + +#### On success + +This will return a 202 Accepted response on success, and delete the message. + +#### Invalid message ID + +If the message ID is not valid, this will return a 404 Not Found response. + ### `GET /api/events` Subscribes to events. This endpoint returns an `application/event-stream` response, and is intended for use with the `EventSource` browser API. Events will be delivered on this stream as they occur, and the request will remain open to deliver events. diff --git a/src/channel/app.rs b/src/channel/app.rs index 6ce826b..24be2ff 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,10 +2,12 @@ use chrono::TimeDelta; use itertools::Itertools; use sqlx::sqlite::SqlitePool; +use super::{repo::Provider as _, Channel, Id}; use crate::{ - channel::{repo::Provider as _, Channel}, clock::DateTime, - event::{broadcaster::Broadcaster, repo::Provider as _, Sequence}, + db::NotFound, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence, Sequenced}, + message::repo::Provider as _, }; pub struct Channels<'a> { @@ -28,9 +30,8 @@ impl<'a> Channels<'a> { .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; - for event in channel.events() { - self.events.broadcast(event); - } + self.events + .broadcast(channel.events().map(Event::from).collect::>()); Ok(channel.snapshot()) } @@ -53,6 +54,46 @@ impl<'a> Channels<'a> { Ok(channels) } + pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { + let mut tx = self.db.begin().await?; + + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| DeleteError::NotFound(channel.clone()))? + .snapshot(); + + let mut events = Vec::new(); + + let messages = tx.messages().in_channel(&channel).await?; + for message in messages { + let deleted = tx.sequence().next(deleted_at).await?; + let message = tx.messages().delete(&message, &deleted).await?; + events.extend( + message + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from), + ); + } + + let deleted = tx.sequence().next(deleted_at).await?; + let channel = tx.channels().delete(&channel.id, &deleted).await?; + events.extend( + channel + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from), + ); + + tx.commit().await?; + + self.events.broadcast(events); + + Ok(()) + } + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { // Somewhat arbitrarily, expire after 90 days. 
let expire_at = relative_to.to_owned() - TimeDelta::days(90); @@ -73,12 +114,13 @@ impl<'a> Channels<'a> { tx.commit().await?; - for event in events - .into_iter() - .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) - { - self.events.broadcast(event); - } + self.events.broadcast( + events + .into_iter() + .kmerge_by(|a, b| a.sequence() < b.sequence()) + .map(Event::from) + .collect::>(), + ); Ok(()) } @@ -92,6 +134,14 @@ pub enum CreateError { DatabaseError(#[from] sqlx::Error), } +#[derive(Debug, thiserror::Error)] +pub enum DeleteError { + #[error("channel {0} not found")] + NotFound(Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} + impl CreateError { fn from_duplicate_name(error: sqlx::Error, name: &str) -> Self { if let Some(error) = error.as_database_error() { diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 5bb1ee9..bce634e 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -2,20 +2,18 @@ use axum::{ extract::{Json, Path, State}, http::StatusCode, response::{IntoResponse, Response}, - routing::{get, post}, + routing::{delete, get, post}, Router, }; use axum_extra::extract::Query; -use super::app; +use super::{ + app::{self, DeleteError}, + Channel, Id, +}; use crate::{ - app::App, - channel::{self, Channel}, - clock::RequestedAt, - error::Internal, - event::Sequence, - login::Login, - message::app::Error as MessageError, + app::App, clock::RequestedAt, error::Internal, event::Sequence, login::Login, + message::app::SendError, }; #[cfg(test)] @@ -26,6 +24,7 @@ pub fn router() -> Router { .route("/api/channels", get(list)) .route("/api/channels", post(on_create)) .route("/api/channels/:channel", post(on_send)) + .route("/api/channels/:channel", delete(on_delete)) } #[derive(Default, serde::Deserialize)] @@ -95,28 +94,54 @@ struct SendRequest { async fn on_send( State(app): State, - Path(channel): Path, + Path(channel): Path, RequestedAt(sent_at): RequestedAt, login: Login, Json(request): Json, -) -> Result { +) -> Result { app.messages() .send(&channel, &login, &sent_at, &request.message) - .await - // Could impl `From` here, but it's more code and this is used once. 
- .map_err(ErrorResponse)?; + .await?; Ok(StatusCode::ACCEPTED) } -#[derive(Debug)] -struct ErrorResponse(MessageError); +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct SendErrorResponse(#[from] SendError); + +impl IntoResponse for SendErrorResponse { + fn into_response(self) -> Response { + let Self(error) = self; + match error { + not_found @ SendError::ChannelNotFound(_) => { + (StatusCode::NOT_FOUND, not_found.to_string()).into_response() + } + other => Internal::from(other).into_response(), + } + } +} + +async fn on_delete( + State(app): State, + Path(channel): Path, + RequestedAt(deleted_at): RequestedAt, + _: Login, +) -> Result { + app.channels().delete(&channel, &deleted_at).await?; + + Ok(StatusCode::ACCEPTED) +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct DeleteErrorResponse(#[from] DeleteError); -impl IntoResponse for ErrorResponse { +impl IntoResponse for DeleteErrorResponse { fn into_response(self) -> Response { let Self(error) = self; match error { - not_found @ MessageError::ChannelNotFound(_) => { + not_found @ DeleteError::NotFound(_) => { (StatusCode::NOT_FOUND, not_found.to_string()).into_response() } other => Internal::from(other).into_response(), diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 1027b29..3297093 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -5,7 +5,7 @@ use crate::{ channel, channel::routes, event, - message::app, + message::app::SendError, test::fixtures::{self, future::Immediately as _}, }; @@ -77,7 +77,7 @@ async fn nonexistent_channel() { let request = routes::SendRequest { message: fixtures::message::propose(), }; - let routes::ErrorResponse(error) = routes::on_send( + let routes::SendErrorResponse(error) = routes::on_send( State(app), Path(channel.clone()), sent_at, @@ -91,6 +91,6 @@ async fn nonexistent_channel() { assert!(matches!( error, - app::Error::ChannelNotFound(error_channel) if channel == error_channel + SendError::ChannelNotFound(error_channel) if channel == error_channel )); } diff --git a/src/cli.rs b/src/cli.rs index 893fae2..2d9f512 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, db, event, expire, login}; +use crate::{app::App, channel, clock, db, event, expire, login, message}; /// Command-line entry point for running the `hi` server. /// @@ -105,9 +105,14 @@ impl Args { } fn routers() -> Router { - [channel::router(), event::router(), login::router()] - .into_iter() - .fold(Router::default(), Router::merge) + [ + channel::router(), + event::router(), + login::router(), + message::router(), + ] + .into_iter() + .fold(Router::default(), Router::merge) } fn started_msg(listener: &net::TcpListener) -> io::Result { diff --git a/src/event/app.rs b/src/event/app.rs index e58bea9..32f0a97 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -61,6 +61,7 @@ impl<'a> Events<'a> { // Filtering on the broadcast resume point filters out messages // before resume_at, and filters out messages duplicated from // `replay_events`. 
+ .flat_map(stream::iter) .filter(Self::resume(resume_live_at)); Ok(replay.chain(live_messages)) diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs index de2513a..3c4efac 100644 --- a/src/event/broadcaster.rs +++ b/src/event/broadcaster.rs @@ -1,3 +1,3 @@ use crate::broadcast; -pub type Broadcaster = broadcast::Broadcaster; +pub type Broadcaster = broadcast::Broadcaster>; diff --git a/src/message/app.rs b/src/message/app.rs index 51f772e..1d34c14 100644 --- a/src/message/app.rs +++ b/src/message/app.rs @@ -2,12 +2,12 @@ use chrono::TimeDelta; use itertools::Itertools; use sqlx::sqlite::SqlitePool; -use super::{repo::Provider as _, Message}; +use super::{repo::Provider as _, Id, Message}; use crate::{ channel::{self, repo::Provider as _}, clock::DateTime, db::NotFound as _, - event::{broadcaster::Broadcaster, repo::Provider as _, Sequence}, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, login::Login, }; @@ -27,13 +27,13 @@ impl<'a> Messages<'a> { sender: &Login, sent_at: &DateTime, body: &str, - ) -> Result { + ) -> Result { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await - .not_found(|| Error::ChannelNotFound(channel.clone()))?; + .not_found(|| SendError::ChannelNotFound(channel.clone()))?; let sent = tx.sequence().next(sent_at).await?; let message = tx .messages() @@ -41,24 +41,40 @@ impl<'a> Messages<'a> { .await?; tx.commit().await?; - for event in message.events() { - self.events.broadcast(event); - } + self.events + .broadcast(message.events().map(Event::from).collect::>()); Ok(message.snapshot()) } + pub async fn delete(&self, message: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { + let mut tx = self.db.begin().await?; + let deleted = tx.sequence().next(deleted_at).await?; + let message = tx.messages().delete(message, &deleted).await?; + tx.commit().await?; + + self.events.broadcast( + message + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from) + .collect::>(), + ); + + Ok(()) + } + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { // Somewhat arbitrarily, expire after 90 days. 
let expire_at = relative_to.to_owned() - TimeDelta::days(90); let mut tx = self.db.begin().await?; - let expired = tx.messages().expired(&expire_at).await?; + let expired = tx.messages().expired(&expire_at).await?; let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { + for message in expired { let deleted = tx.sequence().next(relative_to).await?; - let message = tx.messages().delete(&channel, &message, &deleted).await?; + let message = tx.messages().delete(&message, &deleted).await?; events.push( message .events() @@ -68,21 +84,32 @@ impl<'a> Messages<'a> { tx.commit().await?; - for event in events - .into_iter() - .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) - { - self.events.broadcast(event); - } + self.events.broadcast( + events + .into_iter() + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .map(Event::from) + .collect::>(), + ); Ok(()) } } #[derive(Debug, thiserror::Error)] -pub enum Error { +pub enum SendError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum DeleteError { #[error("channel {0} not found")] ChannelNotFound(channel::Id), + #[error("message {0} not found")] + NotFound(Id), #[error(transparent)] DatabaseError(#[from] sqlx::Error), } diff --git a/src/message/mod.rs b/src/message/mod.rs index 52d56c1..a8f51ab 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -3,6 +3,7 @@ pub mod event; mod history; mod id; pub mod repo; +mod routes; mod snapshot; -pub use self::{event::Event, history::History, id::Id, snapshot::Message}; +pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Message}; diff --git a/src/message/repo.rs b/src/message/repo.rs index 3b2b8f7..ae41736 100644 --- a/src/message/repo.rs +++ b/src/message/repo.rs @@ -62,7 +62,25 @@ impl<'c> Messages<'c> { Ok(message) } - async fn by_id(&mut self, channel: &Channel, message: &Id) -> Result { + pub async fn in_channel(&mut self, channel: &Channel) -> Result, sqlx::Error> { + let messages = sqlx::query_scalar!( + r#" + select + message.id as "id: Id" + from message + join channel on message.channel = channel.id + where channel.id = $1 + order by message.sent_sequence + "#, + channel.id, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + async fn by_id(&mut self, message: &Id) -> Result { let message = sqlx::query!( r#" select @@ -78,10 +96,8 @@ impl<'c> Messages<'c> { join channel on message.channel = channel.id join login as sender on message.sender = sender.id where message.id = $1 - and message.channel = $2 "#, message, - channel.id, ) .map(|row| History { message: Message { @@ -110,11 +126,10 @@ impl<'c> Messages<'c> { pub async fn delete( &mut self, - channel: &Channel, message: &Id, deleted: &Instant, ) -> Result { - let history = self.by_id(channel, message).await?; + let history = self.by_id(message).await?; sqlx::query_scalar!( r#" @@ -134,31 +149,16 @@ impl<'c> Messages<'c> { }) } - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result, sqlx::Error> { - let messages = sqlx::query!( + pub async fn expired(&mut self, expire_at: &DateTime) -> Result, sqlx::Error> { + let messages = sqlx::query_scalar!( r#" select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - message.id as "message: Id" + id as "message: Id" from message - join channel on message.channel = channel.id where sent_at < $1 "#, expire_at, ) - .map(|row| { 
- ( - Channel { - id: row.channel_id, - name: row.channel_name, - }, - row.message, - ) - }) .fetch_all(&mut *self.0) .await?; diff --git a/src/message/routes.rs b/src/message/routes.rs new file mode 100644 index 0000000..29fe3d7 --- /dev/null +++ b/src/message/routes.rs @@ -0,0 +1,46 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::delete, + Router, +}; + +use crate::{ + app::App, + clock::RequestedAt, + error::Internal, + login::Login, + message::{self, app::DeleteError}, +}; + +pub fn router() -> Router { + Router::new().route("/api/messages/:message", delete(on_delete)) +} + +async fn on_delete( + State(app): State, + Path(message): Path, + RequestedAt(deleted_at): RequestedAt, + _: Login, +) -> Result { + app.messages().delete(&message, &deleted_at).await?; + + Ok(StatusCode::ACCEPTED) +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct ErrorResponse(#[from] DeleteError); + +impl IntoResponse for ErrorResponse { + fn into_response(self) -> Response { + let Self(error) = self; + match error { + not_found @ (DeleteError::ChannelNotFound(_) | DeleteError::NotFound(_)) => { + (StatusCode::NOT_FOUND, not_found.to_string()).into_response() + } + other => Internal::from(other).into_response(), + } + } +} -- cgit v1.2.3 From 617172576b95bbb935a75f98a98787da5a4e9a9d Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Thu, 3 Oct 2024 20:44:07 -0400 Subject: List messages per channel. --- ...4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json | 20 ------ ...4a2137be57b3a45fe38a675262ceaaebb3d346a9ca.json | 62 ++++++++++++++++++ docs/api.md | 30 +++++++++ src/channel/app.rs | 44 +++++++++++-- src/channel/repo.rs | 7 +- src/channel/routes.rs | 76 ++++++++++++++++++---- src/event/app.rs | 8 +-- src/event/mod.rs | 2 +- src/event/sequence.rs | 7 ++ src/message/app.rs | 4 +- src/message/event.rs | 27 +++++++- src/message/history.rs | 4 +- src/message/repo.rs | 57 ++++++++++++---- src/message/snapshot.rs | 4 +- 14 files changed, 281 insertions(+), 71 deletions(-) delete mode 100644 .sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json create mode 100644 .sqlx/query-9606853f2ea9f776f7e4384a2137be57b3a45fe38a675262ceaaebb3d346a9ca.json diff --git a/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json b/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json deleted file mode 100644 index ee0f235..0000000 --- a/.sqlx/query-46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n message.id as \"id: Id\"\n from message\n join channel on message.channel = channel.id\n where channel.id = $1\n order by message.sent_sequence\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false - ] - }, - "hash": "46403b84bfc79a53aec36b4a808afb115f6e47d545dfbeb18f9c54e6eb15eb80" -} diff --git a/.sqlx/query-9606853f2ea9f776f7e4384a2137be57b3a45fe38a675262ceaaebb3d346a9ca.json b/.sqlx/query-9606853f2ea9f776f7e4384a2137be57b3a45fe38a675262ceaaebb3d346a9ca.json new file mode 100644 index 0000000..82246ac --- /dev/null +++ b/.sqlx/query-9606853f2ea9f776f7e4384a2137be57b3a45fe38a675262ceaaebb3d346a9ca.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as 
\"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where channel.id = $1\n and coalesce(message.sent_sequence <= $2, true)\n order by message.sent_sequence\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 2 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "9606853f2ea9f776f7e4384a2137be57b3a45fe38a675262ceaaebb3d346a9ca" +} diff --git a/docs/api.md b/docs/api.md index ef211bc..e8c8c8c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -127,6 +127,36 @@ Channel names must be unique. If a channel with the same name already exists, th The API delivers events to clients to update them on other clients' actions and messages. While there is no specific delivery deadline, messages are delivered as soon as possible on a best-effort basis, and the event system allows clients to replay events or resume interrupted streams, to allow recovery if a message is lost. +### `GET /api/channels/:channel/messages` + +Retrieves historical messages in a channel. + +The `:channel` placeholder must be a channel ID, as returned by `GET /api/channels` or `POST /api/channels`. + +#### Query parameters + +This endpoint accepts an optional `resume_point` query parameter. If provided, the value must be the value obtained from the `/api/boot` method. This parameter will restrict the returned list to messages as they existed at a fixed point in time, with any later changes only appearing in the event stream. + +#### On success + +Responds with a list of message objects, one per message: + +```json +[ + { + "at": "2024-09-27T23:19:10.208147Z", + "sender": { + "id": "L1234abcd", + "name": "example username" + }, + "message": { + "id": "M1312acab", + "body": "beep" + } + } +] +``` + ### `POST /api/channels/:channel` Sends a chat message to a channel. It will be relayed to clients subscribed to the channel's events, and recorded for replay. 
diff --git a/src/channel/app.rs b/src/channel/app.rs index 24be2ff..b3bfbee 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -7,7 +7,7 @@ use crate::{ clock::DateTime, db::NotFound, event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence, Sequenced}, - message::repo::Provider as _, + message::{repo::Provider as _, Message}, }; pub struct Channels<'a> { @@ -54,22 +54,52 @@ impl<'a> Channels<'a> { Ok(channels) } - pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { + pub async fn messages( + &self, + channel: &Id, + resume_point: Option, + ) -> Result, Error> { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| Error::NotFound(channel.clone()))? + .snapshot(); + + let messages = tx + .messages() + .in_channel(&channel, resume_point) + .await? + .into_iter() + .filter_map(|message| { + message + .events() + .filter(Sequence::up_to(resume_point)) + .collect() + }) + .collect(); + + Ok(messages) + } + + pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), Error> { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await - .not_found(|| DeleteError::NotFound(channel.clone()))? + .not_found(|| Error::NotFound(channel.clone()))? .snapshot(); let mut events = Vec::new(); - let messages = tx.messages().in_channel(&channel).await?; + let messages = tx.messages().in_channel(&channel, None).await?; for message in messages { + let message = message.snapshot(); let deleted = tx.sequence().next(deleted_at).await?; - let message = tx.messages().delete(&message, &deleted).await?; + let message = tx.messages().delete(&message.id, &deleted).await?; events.extend( message .events() @@ -117,7 +147,7 @@ impl<'a> Channels<'a> { self.events.broadcast( events .into_iter() - .kmerge_by(|a, b| a.sequence() < b.sequence()) + .kmerge_by(Sequence::merge) .map(Event::from) .collect::>(), ); @@ -135,7 +165,7 @@ pub enum CreateError { } #[derive(Debug, thiserror::Error)] -pub enum DeleteError { +pub enum Error { #[error("channel {0} not found")] NotFound(Id), #[error(transparent)] diff --git a/src/channel/repo.rs b/src/channel/repo.rs index 8bb761b..2b48436 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -84,10 +84,7 @@ impl<'c> Channels<'c> { Ok(channel) } - pub async fn all( - &mut self, - resume_point: Option, - ) -> Result, sqlx::Error> { + pub async fn all(&mut self, resume_at: Option) -> Result, sqlx::Error> { let channels = sqlx::query!( r#" select @@ -99,7 +96,7 @@ impl<'c> Channels<'c> { where coalesce(created_sequence <= $1, true) order by channel.name "#, - resume_point, + resume_at, ) .map(|row| History { channel: Channel { diff --git a/src/channel/routes.rs b/src/channel/routes.rs index bce634e..23c0602 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -7,13 +7,14 @@ use axum::{ }; use axum_extra::extract::Query; -use super::{ - app::{self, DeleteError}, - Channel, Id, -}; +use super::{app, Channel, Id}; use crate::{ - app::App, clock::RequestedAt, error::Internal, event::Sequence, login::Login, - message::app::SendError, + app::App, + clock::RequestedAt, + error::Internal, + event::{Instant, Sequence}, + login::Login, + message::{self, app::SendError}, }; #[cfg(test)] @@ -25,17 +26,18 @@ pub fn router() -> Router { .route("/api/channels", post(on_create)) .route("/api/channels/:channel", post(on_send)) .route("/api/channels/:channel", delete(on_delete)) + 
.route("/api/channels/:channel/messages", get(messages)) } #[derive(Default, serde::Deserialize)] -struct ListQuery { +struct ResumeQuery { resume_point: Option, } async fn list( State(app): State, _: Login, - Query(query): Query, + Query(query): Query, ) -> Result { let channels = app.channels().all(query.resume_point).await?; let response = Channels(channels); @@ -127,7 +129,7 @@ async fn on_delete( Path(channel): Path, RequestedAt(deleted_at): RequestedAt, _: Login, -) -> Result { +) -> Result { app.channels().delete(&channel, &deleted_at).await?; Ok(StatusCode::ACCEPTED) @@ -135,16 +137,66 @@ async fn on_delete( #[derive(Debug, thiserror::Error)] #[error(transparent)] -struct DeleteErrorResponse(#[from] DeleteError); +struct ErrorResponse(#[from] app::Error); -impl IntoResponse for DeleteErrorResponse { +impl IntoResponse for ErrorResponse { fn into_response(self) -> Response { let Self(error) = self; match error { - not_found @ DeleteError::NotFound(_) => { + not_found @ app::Error::NotFound(_) => { (StatusCode::NOT_FOUND, not_found.to_string()).into_response() } other => Internal::from(other).into_response(), } } } + +async fn messages( + State(app): State, + Path(channel): Path, + _: Login, + Query(query): Query, +) -> Result { + let messages = app + .channels() + .messages(&channel, query.resume_point) + .await?; + let response = Messages( + messages + .into_iter() + .map(|message| MessageView { + sent: message.sent, + sender: message.sender, + message: MessageInner { + id: message.id, + body: message.body, + }, + }) + .collect(), + ); + + Ok(response) +} + +struct Messages(Vec); + +#[derive(serde::Serialize)] +struct MessageView { + #[serde(flatten)] + sent: Instant, + sender: Login, + message: MessageInner, +} + +#[derive(serde::Serialize)] +struct MessageInner { + id: message::Id, + body: String, +} + +impl IntoResponse for Messages { + fn into_response(self) -> Response { + let Self(messages) = self; + Json(messages).into_response() + } +} diff --git a/src/event/app.rs b/src/event/app.rs index 32f0a97..d664ec7 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -36,7 +36,7 @@ impl<'a> Events<'a> { let channel_events = channels .iter() .map(channel::History::events) - .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .kmerge_by(Sequence::merge) .filter(Sequence::after(resume_at)) .map(Event::from); @@ -44,14 +44,12 @@ impl<'a> Events<'a> { let message_events = messages .iter() .map(message::History::events) - .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .kmerge_by(Sequence::merge) .filter(Sequence::after(resume_at)) .map(Event::from); let replay_events = channel_events - .merge_by(message_events, |a, b| { - a.instant.sequence < b.instant.sequence - }) + .merge_by(message_events, Sequence::merge) .collect::>(); let resume_live_at = replay_events.last().map(Sequenced::sequence); diff --git a/src/event/mod.rs b/src/event/mod.rs index 1503b77..1349fe6 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -38,7 +38,7 @@ impl From for Event { impl From for Event { fn from(event: message::Event) -> Self { Self { - instant: event.instant, + instant: event.instant(), kind: event.kind.into(), } } diff --git a/src/event/sequence.rs b/src/event/sequence.rs index c566156..fbe3711 100644 --- a/src/event/sequence.rs +++ b/src/event/sequence.rs @@ -59,6 +59,13 @@ impl Sequence { { move |event| resume_point <= event.sequence() } + + pub fn merge(a: &E, b: &E) -> bool + where + E: Sequenced, + { + a.sequence() < b.sequence() + } } pub trait Sequenced { diff --git 
a/src/message/app.rs b/src/message/app.rs index 1d34c14..fd6a334 100644 --- a/src/message/app.rs +++ b/src/message/app.rs @@ -7,7 +7,7 @@ use crate::{ channel::{self, repo::Provider as _}, clock::DateTime, db::NotFound as _, - event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence, Sequenced}, login::Login, }; @@ -87,7 +87,7 @@ impl<'a> Messages<'a> { self.events.broadcast( events .into_iter() - .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .kmerge_by(Sequence::merge) .map(Event::from) .collect::>(), ); diff --git a/src/message/event.rs b/src/message/event.rs index bcc2238..66db9b0 100644 --- a/src/message/event.rs +++ b/src/message/event.rs @@ -6,15 +6,13 @@ use crate::{ #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Event { - #[serde(flatten)] - pub instant: Instant, #[serde(flatten)] pub kind: Kind, } impl Sequenced for Event { fn instant(&self) -> Instant { - self.instant + self.kind.instant() } } @@ -25,12 +23,27 @@ pub enum Kind { Deleted(Deleted), } +impl Sequenced for Kind { + fn instant(&self) -> Instant { + match self { + Self::Sent(sent) => sent.instant(), + Self::Deleted(deleted) => deleted.instant(), + } + } +} + #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Sent { #[serde(flatten)] pub message: Message, } +impl Sequenced for Sent { + fn instant(&self) -> Instant { + self.message.sent + } +} + impl From for Kind { fn from(event: Sent) -> Self { Self::Sent(event) @@ -39,10 +52,18 @@ impl From for Kind { #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Deleted { + #[serde(flatten)] + pub instant: Instant, pub channel: Channel, pub message: Id, } +impl Sequenced for Deleted { + fn instant(&self) -> Instant { + self.instant + } +} + impl From for Kind { fn from(event: Deleted) -> Self { Self::Deleted(event) diff --git a/src/message/history.rs b/src/message/history.rs index 5aca47e..89fc6b1 100644 --- a/src/message/history.rs +++ b/src/message/history.rs @@ -7,14 +7,12 @@ use crate::event::Instant; #[derive(Clone, Debug, Eq, PartialEq)] pub struct History { pub message: Message, - pub sent: Instant, pub deleted: Option, } impl History { fn sent(&self) -> Event { Event { - instant: self.sent, kind: Sent { message: self.message.clone(), } @@ -24,8 +22,8 @@ impl History { fn deleted(&self) -> Option { self.deleted.map(|instant| Event { - instant, kind: Deleted { + instant, channel: self.message.channel.clone(), message: self.message.id.clone(), } diff --git a/src/message/repo.rs b/src/message/repo.rs index ae41736..fc835c8 100644 --- a/src/message/repo.rs +++ b/src/message/repo.rs @@ -48,12 +48,12 @@ impl<'c> Messages<'c> { ) .map(|row| History { message: Message { + sent: *sent, channel: channel.clone(), sender: sender.clone(), id: row.id, body: row.body, }, - sent: *sent, deleted: None, }) .fetch_one(&mut *self.0) @@ -62,18 +62,51 @@ impl<'c> Messages<'c> { Ok(message) } - pub async fn in_channel(&mut self, channel: &Channel) -> Result, sqlx::Error> { - let messages = sqlx::query_scalar!( + pub async fn in_channel( + &mut self, + channel: &Channel, + resume_at: Option, + ) -> Result, sqlx::Error> { + let messages = sqlx::query!( r#" select - message.id as "id: Id" + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: 
Sequence" from message join channel on message.channel = channel.id + join login as sender on message.sender = sender.id where channel.id = $1 + and coalesce(message.sent_sequence <= $2, true) order by message.sent_sequence "#, channel.id, + resume_at, ) + .map(|row| History { + message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + deleted: None, + }) .fetch_all(&mut *self.0) .await?; @@ -101,6 +134,10 @@ impl<'c> Messages<'c> { ) .map(|row| History { message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, channel: Channel { id: row.channel_id, name: row.channel_name, @@ -112,10 +149,6 @@ impl<'c> Messages<'c> { id: row.id, body: row.body, }, - sent: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, deleted: None, }) .fetch_one(&mut *self.0) @@ -189,6 +222,10 @@ impl<'c> Messages<'c> { ) .map(|row| History { message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, channel: Channel { id: row.channel_id, name: row.channel_name, @@ -200,10 +237,6 @@ impl<'c> Messages<'c> { id: row.id, body: row.body, }, - sent: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, deleted: None, }) .fetch_all(&mut *self.0) diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs index 3adccbe..522c1aa 100644 --- a/src/message/snapshot.rs +++ b/src/message/snapshot.rs @@ -2,11 +2,13 @@ use super::{ event::{Event, Kind, Sent}, Id, }; -use crate::{channel::Channel, login::Login}; +use crate::{channel::Channel, event::Instant, login::Login}; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] #[serde(into = "self::serialize::Message")] pub struct Message { + #[serde(skip)] + pub sent: Instant, pub channel: Channel, pub sender: Login, pub id: Id, -- cgit v1.2.3 From 7f12fd41c2941a55a6437f24e4f780104a718790 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Thu, 3 Oct 2024 21:09:26 -0400 Subject: Stray warnings --- src/channel/app.rs | 2 +- src/message/app.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/channel/app.rs b/src/channel/app.rs index b3bfbee..bb331ec 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -6,7 +6,7 @@ use super::{repo::Provider as _, Channel, Id}; use crate::{ clock::DateTime, db::NotFound, - event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence, Sequenced}, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, message::{repo::Provider as _, Message}, }; diff --git a/src/message/app.rs b/src/message/app.rs index fd6a334..33ea8ad 100644 --- a/src/message/app.rs +++ b/src/message/app.rs @@ -7,7 +7,7 @@ use crate::{ channel::{self, repo::Provider as _}, clock::DateTime, db::NotFound as _, - event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence, Sequenced}, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, login::Login, }; -- cgit v1.2.3
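A recurring change in this series replaces ad-hoc closures like `|a, b| a.instant.sequence < b.instant.sequence` with the shared `Sequence::merge` comparator. itertools' `kmerge_by` (and `merge_by`) take a function that returns `true` when its first argument should come before the second, so `a.sequence() < b.sequence()` interleaves several already-ordered event streams into one ascending stream. A runnable sketch under assumptions: `Sequenced` is reduced to a bare integer here, and the `itertools` crate is available, as it is in this repo:

```rust
use itertools::Itertools as _;

/// Simplified stand-in for the repo's `Sequenced` trait, which really
/// returns an `Instant` rather than a bare integer.
trait Sequenced {
    fn sequence(&self) -> u64;
}

struct Event(u64);

impl Sequenced for Event {
    fn sequence(&self) -> u64 {
        self.0
    }
}

/// The shape of `Sequence::merge`: `true` means `a` comes before `b`,
/// which is exactly the contract `kmerge_by` expects for an ascending merge.
fn merge<E: Sequenced>(a: &E, b: &E) -> bool {
    a.sequence() < b.sequence()
}

fn main() {
    let channel_events = vec![Event(1), Event(4)];
    let message_events = vec![Event(2), Event(3)];

    // Each input stream is already ordered; kmerge_by interleaves them
    // into a single globally ordered stream.
    let merged: Vec<u64> = [channel_events, message_events]
        .into_iter()
        .kmerge_by(merge)
        .map(|event| event.sequence())
        .collect();

    assert_eq!(merged, vec![1, 2, 3, 4]);
}
```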
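The `message::Event` refactor in the same patch stops storing an event's `Instant` twice: a `Sent` event's timestamp already lives inside its message snapshot, so only `Deleted` keeps an explicit one, and `Kind` recovers whichever applies by delegation. A stripped-down sketch of that shape (types reduced to the minimum; the real definitions are in `src/message/event.rs` and `src/event/sequence.rs`):

```rust
/// Reduced stand-in for the repo's `Instant`, which also carries a timestamp.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Instant {
    sequence: u64,
}

trait Sequenced {
    fn instant(&self) -> Instant;
}

/// A message snapshot already records when it was sent...
struct Message {
    sent: Instant,
}

/// ...so the `Sent` event needs no timestamp of its own,
struct Sent {
    message: Message,
}

/// while `Deleted` has no snapshot and must store one explicitly.
struct Deleted {
    instant: Instant,
}

enum Kind {
    Sent(Sent),
    Deleted(Deleted),
}

// The enclosing event type recovers the instant from whichever payload
// it wraps, instead of duplicating the field at the top level.
impl Sequenced for Kind {
    fn instant(&self) -> Instant {
        match self {
            Kind::Sent(sent) => sent.message.sent,
            Kind::Deleted(deleted) => deleted.instant,
        }
    }
}

fn main() {
    let kind = Kind::Deleted(Deleted {
        instant: Instant { sequence: 9 },
    });
    assert_eq!(kind.instant(), Instant { sequence: 9 });
}
```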