-rw-r--r--  .sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json |  32
-rw-r--r--  Cargo.toml                            |   1
-rw-r--r--  src/app.rs                            |  12
-rw-r--r--  src/conversation/handlers/send/mod.rs |  10
-rw-r--r--  src/expire.rs                         |   5
-rw-r--r--  src/message/app.rs                    | 178
-rw-r--r--  src/message/handlers/delete/mod.rs    |   4
-rw-r--r--  src/push/app.rs                       |   2
-rw-r--r--  src/push/handlers/ping/test.rs        |  24
-rw-r--r--  src/push/publisher.rs                 |  20
-rw-r--r--  src/push/repo.rs                      |  21
-rw-r--r--  src/test/fixtures/message.rs          |   6
-rw-r--r--  src/test/webpush.rs                   |  55
13 files changed, 253 insertions, 117 deletions
diff --git a/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json b/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json
new file mode 100644
index 0000000..4bf6146
--- /dev/null
+++ b/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json
@@ -0,0 +1,32 @@
+{
+  "db_name": "SQLite",
+  "query": "\n select\n sub.endpoint,\n sub.p256dh,\n sub.auth\n from push_subscription as sub\n join token on sub.token = token.id\n where token.login <> $1\n ",
+  "describe": {
+    "columns": [
+      {
+        "name": "endpoint",
+        "ordinal": 0,
+        "type_info": "Text"
+      },
+      {
+        "name": "p256dh",
+        "ordinal": 1,
+        "type_info": "Text"
+      },
+      {
+        "name": "auth",
+        "ordinal": 2,
+        "type_info": "Text"
+      }
+    ],
+    "parameters": {
+      "Right": 1
+    },
+    "nullable": [
+      false,
+      false,
+      false
+    ]
+  },
+  "hash": "99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746"
+}
diff --git a/Cargo.toml b/Cargo.toml
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,7 @@ assets = [
 
 [dependencies]
 argon2 = "0.5.3"
+async-trait = "0.1.89"
 axum = { version = "0.8.6", features = ["form"] }
 axum-extra = { version = "0.12.1", features = ["cookie", "query", "typed-header"] }
 base64 = "0.22.1"
diff --git a/src/app.rs b/src/app.rs
--- a/src/app.rs
+++ b/src/app.rs
@@ -58,8 +58,11 @@ impl<P> App<P> {
         Logins::new(self.db.clone(), self.token_events.clone())
     }
 
-    pub fn messages(&self) -> Messages {
-        Messages::new(self.db.clone(), self.events.clone())
+    pub fn messages(&self) -> Messages<P>
+    where
+        P: Clone,
+    {
+        Messages::new(self.db.clone(), self.events.clone(), self.publisher.clone())
     }
 
     #[cfg(test)]
@@ -116,7 +119,10 @@ impl<P> FromRef<App<P>> for Logins {
     }
 }
 
-impl<P> FromRef<App<P>> for Messages {
+impl<P> FromRef<App<P>> for Messages<P>
+where
+    P: Clone,
+{
     fn from_ref(app: &App<P>) -> Self {
         app.messages()
     }
diff --git a/src/conversation/handlers/send/mod.rs b/src/conversation/handlers/send/mod.rs
index 979dd24..67d3a70 100644
--- a/src/conversation/handlers/send/mod.rs
+++ b/src/conversation/handlers/send/mod.rs
@@ -12,19 +12,23 @@ use crate::{
         Body, Message,
         app::{Messages, SendError},
     },
+    push::Publish,
     token::extract::Identity,
 };
 
 #[cfg(test)]
 mod test;
 
-pub async fn handler(
-    State(messages): State<Messages>,
+pub async fn handler<P>(
+    State(messages): State<Messages<P>>,
     Path(conversation): Path<PathInfo>,
     RequestedAt(sent_at): RequestedAt,
     identity: Identity,
     Json(request): Json<Request>,
-) -> Result<Response, Error> {
+) -> Result<Response, Error>
+where
+    P: Publish,
+{
     let message = messages
         .send(&conversation, &identity.login, &sent_at, &request.body)
         .await?;
diff --git a/src/expire.rs b/src/expire.rs
index c3b0117..da7ba53 100644
--- a/src/expire.rs
+++ b/src/expire.rs
@@ -12,7 +12,10 @@ pub async fn middleware<P>(
     RequestedAt(expired_at): RequestedAt,
     req: Request,
     next: Next,
-) -> Result<Response, Internal> {
+) -> Result<Response, Internal>
+where
+    P: Clone,
+{
     app.tokens().expire(&expired_at).await?;
     app.invites().expire(&expired_at).await?;
     app.messages().expire(&expired_at).await?;
diff --git a/src/message/app.rs b/src/message/app.rs
index b82fa83..8200650 100644
--- a/src/message/app.rs
+++ b/src/message/app.rs
@@ -1,6 +1,7 @@
 use chrono::TimeDelta;
 use itertools::Itertools;
 use sqlx::sqlite::SqlitePool;
+use web_push::WebPushError;
 
 use super::{Body, History, Id, Message, history, repo::Provider as _};
 use crate::{
@@ -11,72 +12,24 @@ use crate::{
     error::failed::{Failed, ResultExt as _},
     event::{Broadcaster, Sequence, repo::Provider as _},
     login::Login,
+    push::{Publish, repo::Provider as _},
     user::{self, repo::Provider as _},
+    vapid::repo::Provider as _,
 };
 
-pub struct Messages {
+pub struct Messages<P> {
     db: SqlitePool,
     events: Broadcaster,
+    publisher: P,
 }
 
-impl Messages {
-    pub const fn new(db: SqlitePool, events: Broadcaster) -> Self {
-        Self { db, events }
-    }
-
-    pub async fn send(
-        &self,
-        conversation: &conversation::Id,
-        sender: &Login,
-        sent_at: &DateTime,
-        body: &Body,
-    ) -> Result<Message, SendError> {
-        let conversation_not_found = || SendError::ConversationNotFound(conversation.clone());
-        let conversation_deleted = || SendError::ConversationDeleted(conversation.clone());
-        let sender_not_found = || SendError::SenderNotFound(sender.id.clone().into());
-        let sender_deleted = || SendError::SenderDeleted(sender.id.clone().into());
-
-        let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
-        let conversation = tx
-            .conversations()
-            .by_id(conversation)
-            .await
-            .optional()
-            .fail("Failed to load conversation")?
-            .ok_or_else(conversation_not_found)?;
-        let sender = tx
-            .users()
-            .by_login(sender)
-            .await
-            .optional()
-            .fail("Failed to load sending user")?
-            .ok_or_else(sender_not_found)?;
-
-        // Ordering: don't bother allocating a sequence number before we know the channel might
-        // exist.
-        let sent = tx
-            .sequence()
-            .next(sent_at)
-            .await
-            .fail("Failed to find event sequence number")?;
-        let conversation = conversation.as_of(sent).ok_or_else(conversation_deleted)?;
-        let sender = sender.as_of(sent).ok_or_else(sender_deleted)?;
-        let message = History::begin(&conversation, &sender, body, sent);
-
-        // This filter technically includes every event in the history, but it's easier to follow if
-        // the various event-manipulating app methods are consistent, and it's harmless to have an
-        // always-satisfied filter.
-        let events = message.events().filter(Sequence::start_from(sent));
-        tx.messages()
-            .record_events(events.clone())
-            .await
-            .fail("Failed to store events")?;
-
-        tx.commit().await.fail(db::failed::COMMIT)?;
-
-        self.events.broadcast_from(events);
-
-        Ok(message.as_sent())
+impl<P> Messages<P> {
+    pub const fn new(db: SqlitePool, events: Broadcaster, publisher: P) -> Self {
+        Self {
+            db,
+            events,
+            publisher,
+        }
     }
 
     pub async fn delete(
@@ -163,6 +116,113 @@ impl Messages {
     }
 }
 
+impl<P> Messages<P>
+where
+    P: Publish,
+{
+    pub async fn send(
+        &self,
+        conversation: &conversation::Id,
+        sender: &Login,
+        sent_at: &DateTime,
+        body: &Body,
+    ) -> Result<Message, SendError> {
+        let conversation_not_found = || SendError::ConversationNotFound(conversation.clone());
+        let conversation_deleted = || SendError::ConversationDeleted(conversation.clone());
+        let sender_not_found = || SendError::SenderNotFound(sender.id.clone().into());
+        let sender_deleted = || SendError::SenderDeleted(sender.id.clone().into());
+
+        let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
+
+        let signer = tx
+            .vapid()
+            .signer()
+            .await
+            .fail("Failed to load VAPID signer")?;
+        let push_recipients = tx
+            .push()
+            .broadcast_from(sender)
+            .await
+            .fail("Failed to load push recipients")?;
+
+        let conversation = tx
+            .conversations()
+            .by_id(conversation)
+            .await
+            .optional()
+            .fail("Failed to load conversation")?
+            .ok_or_else(conversation_not_found)?;
+        let sender = tx
+            .users()
+            .by_login(sender)
+            .await
+            .optional()
+            .fail("Failed to load sending user")?
+            .ok_or_else(sender_not_found)?;
+
+        // Ordering: don't bother allocating a sequence number before we know the channel might
+        // exist.
+        let sent = tx
+            .sequence()
+            .next(sent_at)
+            .await
+            .fail("Failed to find event sequence number")?;
+        let conversation = conversation.as_of(sent).ok_or_else(conversation_deleted)?;
+        let sender = sender.as_of(sent).ok_or_else(sender_deleted)?;
+        let message = History::begin(&conversation, &sender, body, sent);
+
+        // This filter technically includes every event in the history, but it's easier to follow if
+        // the various event-manipulating app methods are consistent, and it's harmless to have an
+        // always-satisfied filter.
+        let events = message.events().filter(Sequence::start_from(sent));
+        tx.messages()
+            .record_events(events.clone())
+            .await
+            .fail("Failed to store events")?;
+
+        tx.commit().await.fail(db::failed::COMMIT)?;
+
+        self.events.broadcast_from(events.clone());
+        for event in events {
+            let failures = self
+                .publisher
+                .publish(event, &signer, &push_recipients)
+                .await
+                .fail("Failed to publish push events")?;
+
+            if !failures.is_empty() {
+                let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
+                // Note that data integrity guarantees from the original transaction to read
+                // subscriptions may no longer be valid now. Time has passed. Depending on how slow
+                // delivering push notifications is, potentially a _lot_ of time has passed.
+
+                for (sub, err) in failures {
+                    match err {
+                        // I _think_ this is the complete set of permanent failures. See
+                        // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete
+                        // list.
+                        WebPushError::Unauthorized(_)
+                        | WebPushError::InvalidUri
+                        | WebPushError::EndpointNotValid(_)
+                        | WebPushError::EndpointNotFound(_)
+                        | WebPushError::InvalidCryptoKeys
+                        | WebPushError::MissingCryptoKeys => {
+                            tx.push().unsubscribe(sub).await.fail(
+                                "Failed to unsubscribe after permanent push message rejection",
+                            )?;
+                        }
+                        _ => (),
+                    }
+                }
+
+                tx.commit().await.fail(db::failed::COMMIT)?;
+            }
+        }
+
+        Ok(message.as_sent())
+    }
+}
+
 #[derive(Debug, thiserror::Error)]
 pub enum SendError {
     #[error("conversation {0} not found")]
diff --git a/src/message/handlers/delete/mod.rs b/src/message/handlers/delete/mod.rs
index c680db1..08a7cea 100644
--- a/src/message/handlers/delete/mod.rs
+++ b/src/message/handlers/delete/mod.rs
@@ -17,8 +17,8 @@ use crate::{
 #[cfg(test)]
 mod test;
 
-pub async fn handler(
-    State(messages): State<Messages>,
+pub async fn handler<P>(
+    State(messages): State<Messages<P>>,
     Path(message): Path<message::Id>,
     RequestedAt(deleted_at): RequestedAt,
     identity: Identity,
diff --git a/src/push/app.rs b/src/push/app.rs
index ebfc220..f7846a6 100644
--- a/src/push/app.rs
+++ b/src/push/app.rs
@@ -101,7 +101,7 @@ where
 
         let failures = self
             .publisher
-            .publish(Heartbeat::Heartbeat, signer, subscriptions)
+            .publish(Heartbeat::Heartbeat, &signer, &subscriptions)
             .await
            .fail("Failed to send push message")?;
 
diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs
index c985aaf..cc07ef0 100644
--- a/src/push/handlers/ping/test.rs
+++ b/src/push/handlers/ping/test.rs
@@ -26,7 +26,7 @@ async fn ping_without_subscriptions() {
             .sent()
             .into_iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions.is_empty())
+                && publish.recipients.is_empty())
             .exactly_one()
             .is_ok()
     );
@@ -64,7 +64,7 @@ async fn ping() {
             .sent()
             .into_iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -110,7 +110,7 @@ async fn ping_multiple_subscriptions() {
             .sent()
             .into_iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -160,7 +160,7 @@ async fn ping_recipient_only() {
     assert!(
         sent.iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == recipient_subscriptions)
+                && publish.recipients == recipient_subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -169,7 +169,7 @@ async fn ping_recipient_only() {
     assert!(
         !sent
             .iter()
-            .any(|publish| publish.subscriptions.contains(&spectator_subscription))
+            .any(|publish| publish.recipients.contains(&spectator_subscription))
     );
 }
 
@@ -212,7 +212,7 @@ async fn ping_permanent_error() {
     assert!(
         sent.iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -230,7 +230,7 @@ async fn ping_permanent_error() {
     assert!(
         !sent
             .iter()
-            .any(|publish| publish.subscriptions.contains(&subscription))
+            .any(|publish| publish.recipients.contains(&subscription))
     );
 }
 
@@ -275,7 +275,7 @@ async fn ping_temporary_error() {
     assert!(
         sent.iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -293,7 +293,7 @@ async fn ping_temporary_error() {
     assert!(
         sent.iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -345,7 +345,7 @@ async fn ping_multiple_subscriptions_with_failure() {
             .sent()
             .iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
@@ -362,13 +362,13 @@ async fn ping_multiple_subscriptions_with_failure() {
     assert!(
         sent.iter()
             .filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
-                && publish.subscriptions == subscriptions)
+                && publish.recipients == subscriptions)
             .exactly_one()
             .is_ok()
     );
     assert!(
         !sent
             .iter()
-            .any(|publish| publish.subscriptions.contains(&failing))
+            .any(|publish| publish.recipients.contains(&failing))
     );
 }
diff --git a/src/push/publisher.rs b/src/push/publisher.rs
index 4092724..d6227a2 100644
--- a/src/push/publisher.rs
+++ b/src/push/publisher.rs
@@ -8,13 +8,14 @@ use web_push::{
 
 use crate::error::failed::{Failed, ResultExt as _};
 
+#[async_trait::async_trait]
 pub trait Publish {
-    fn publish<M>(
+    async fn publish<'s, M>(
         &self,
         message: M,
-        signer: PartialVapidSignatureBuilder,
-        subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
-    ) -> impl Future<Output = Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>> + Send
+        signer: &PartialVapidSignatureBuilder,
+        subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+    ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
     where
         M: Serialize + Send + 'static;
 }
@@ -50,13 +51,14 @@ impl Publisher {
     }
 }
 
+#[async_trait::async_trait]
 impl Publish for Publisher {
-    async fn publish<M>(
+    async fn publish<'s, M>(
         &self,
         message: M,
-        signer: PartialVapidSignatureBuilder,
-        subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
-    ) -> Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>
+        signer: &PartialVapidSignatureBuilder,
+        subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+    ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
     where
         M: Serialize + Send + 'static,
     {
@@ -65,7 +67,7 @@ impl Publish for Publisher {
 
         let messages: Vec<_> = subscriptions
             .into_iter()
-            .map(|sub| Self::prepare_message(&payload, &signer, &sub).map(|message| (sub, message)))
+            .map(|sub| Self::prepare_message(&payload, signer, sub).map(|message| (sub, message)))
             .try_collect()?;
 
         let deliveries = messages
diff --git a/src/push/repo.rs b/src/push/repo.rs
index 4183489..8850059 100644
--- a/src/push/repo.rs
+++ b/src/push/repo.rs
@@ -83,6 +83,27 @@ impl Push<'_> {
         Ok(info)
     }
 
+    pub async fn broadcast_from(
+        &mut self,
+        originator: &Login,
+    ) -> Result<Vec<SubscriptionInfo>, sqlx::Error> {
+        sqlx::query!(
+            r#"
+            select
+                sub.endpoint,
+                sub.p256dh,
+                sub.auth
+            from push_subscription as sub
+            join token on sub.token = token.id
+            where token.login <> $1
+            "#,
+            originator.id,
+        )
+        .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth))
+        .fetch_all(&mut *self.0)
+        .await
+    }
+
     pub async fn unsubscribe(
         &mut self,
         subscription: &SubscriptionInfo,
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index 0bd0b7a..39f5963 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -6,16 +6,18 @@ use crate::{
     conversation::Conversation,
     login::Login,
     message::{self, Body, Message, app::Messages},
+    push::Publish,
 };
 
-pub async fn send<App>(
+pub async fn send<App, P>(
     app: &App,
     conversation: &Conversation,
     sender: &Login,
     sent_at: &RequestedAt,
 ) -> Message
 where
-    Messages: FromRef<App>,
+    Messages<P>: FromRef<App>,
+    P: Publish,
 {
     let body = propose();
diff --git a/src/test/webpush.rs b/src/test/webpush.rs
index 96fa843..55caf19 100644
--- a/src/test/webpush.rs
+++ b/src/test/webpush.rs
@@ -2,7 +2,7 @@ use std::{
     any::Any,
     collections::{HashMap, HashSet},
     mem,
-    sync::{Arc, Mutex},
+    sync::{Arc, Mutex, MutexGuard},
 };
 
 use web_push::{PartialVapidSignatureBuilder, SubscriptionInfo, WebPushError};
@@ -10,59 +10,64 @@ use web_push::{PartialVapidSignatureBuilder, SubscriptionInfo, WebPushError};
 use crate::{error::failed::Failed, push::Publish};
 
 #[derive(Clone)]
-pub struct Client {
-    sent: Arc<Mutex<Vec<Publication>>>,
-    failures: Arc<Mutex<HashMap<SubscriptionInfo, WebPushError>>>,
+pub struct Client(Arc<Mutex<ClientInner>>);
+
+#[derive(Default)]
+struct ClientInner {
+    sent: Vec<Publication>,
+    planned_failures: HashMap<SubscriptionInfo, WebPushError>,
 }
 
 impl Client {
     pub fn new() -> Self {
-        Self {
-            sent: Arc::default(),
-            failures: Arc::default(),
-        }
+        Self(Arc::default())
+    }
+
+    fn inner(&self) -> MutexGuard<'_, ClientInner> {
+        self.0.lock().unwrap()
     }
 
     // Clears the list of sent messages (for all clones of this Client) when called, because we
     // can't clone `Publication`s, so we either need to move them or try to reconstruct them.
     pub fn sent(&self) -> Vec<Publication> {
-        let mut sent = self.sent.lock().unwrap();
-        mem::take(&mut sent)
+        let sent = &mut self.inner().sent;
+        mem::take(sent)
     }
 
     pub fn fail_next(&self, subscription_info: &SubscriptionInfo, err: WebPushError) {
-        let mut failures = self.failures.lock().unwrap();
-        failures.insert(subscription_info.clone(), err);
+        let planned_failures = &mut self.inner().planned_failures;
+        planned_failures.insert(subscription_info.clone(), err);
     }
 }
 
+#[async_trait::async_trait]
 impl Publish for Client {
-    async fn publish<M>(
+    async fn publish<'s, M>(
         &self,
         message: M,
-        _: PartialVapidSignatureBuilder,
-        subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
-    ) -> Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>
+        _: &PartialVapidSignatureBuilder,
+        subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+    ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
     where
         M: Send + 'static,
     {
+        let mut inner = self.inner();
         let message: Box<dyn Any + Send> = Box::new(message);
-        let subscriptions = subscriptions.into_iter().collect();
+        let mut recipients = HashSet::new();
 
         let mut failures = Vec::new();
-
-        let mut planned_failures = self.failures.lock().unwrap();
-        for subscription in &subscriptions {
-            if let Some(err) = planned_failures.remove(subscription) {
-                failures.push((subscription.clone(), err));
+        for subscription in subscriptions {
+            recipients.insert(subscription.clone());
+            if let Some(err) = inner.planned_failures.remove(subscription) {
+                failures.push((subscription, err));
             }
         }
 
         let publication = Publication {
             message,
-            subscriptions,
+            recipients,
         };
-        self.sent.lock().unwrap().push(publication);
+        inner.sent.push(publication);
 
         Ok(failures)
     }
@@ -71,7 +76,7 @@ impl Publish for Client {
 #[derive(Debug)]
 pub struct Publication {
     pub message: Box<dyn Any + Send>,
-    pub subscriptions: HashSet<SubscriptionInfo>,
+    pub recipients: HashSet<SubscriptionInfo>,
 }
 
 impl Publication {
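
Note on the new trait shape: after this patch, `Publish` (src/push/publisher.rs) is an `#[async_trait::async_trait]` trait whose `publish` borrows the VAPID signer and the subscriptions and returns the (subscription, error) pairs that failed. As an illustration only, a minimal sketch of a further implementation that drops every notification, for example when running without a push service. The `NullPublisher` name is hypothetical and not part of this patch; the signature is copied from the trait as declared above, and the imports assume the crate paths shown in the diff.

use serde::Serialize;
use web_push::{PartialVapidSignatureBuilder, SubscriptionInfo, WebPushError};

use crate::{error::failed::Failed, push::Publish};

// Hypothetical no-op publisher (illustration only, not part of the patch).
pub struct NullPublisher;

#[async_trait::async_trait]
impl Publish for NullPublisher {
    async fn publish<'s, M>(
        &self,
        _message: M,
        _signer: &PartialVapidSignatureBuilder,
        _subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
    ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
    where
        M: Serialize + Send + 'static,
    {
        // No deliveries are attempted, so there are no per-subscription failures to
        // report, and Messages::send never reaches its unsubscribe path.
        Ok(Vec::new())
    }
}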
