summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json32
-rw-r--r--Cargo.toml1
-rw-r--r--src/app.rs12
-rw-r--r--src/conversation/handlers/send/mod.rs10
-rw-r--r--src/expire.rs5
-rw-r--r--src/message/app.rs178
-rw-r--r--src/message/handlers/delete/mod.rs4
-rw-r--r--src/push/app.rs2
-rw-r--r--src/push/handlers/ping/test.rs24
-rw-r--r--src/push/publisher.rs20
-rw-r--r--src/push/repo.rs21
-rw-r--r--src/test/fixtures/message.rs6
-rw-r--r--src/test/webpush.rs55
13 files changed, 253 insertions, 117 deletions
diff --git a/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json b/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json
new file mode 100644
index 0000000..4bf6146
--- /dev/null
+++ b/.sqlx/query-99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746.json
@@ -0,0 +1,32 @@
+{
+ "db_name": "SQLite",
+ "query": "\n select\n sub.endpoint,\n sub.p256dh,\n sub.auth\n from push_subscription as sub\n join token on sub.token = token.id\n where token.login <> $1\n ",
+ "describe": {
+ "columns": [
+ {
+ "name": "endpoint",
+ "ordinal": 0,
+ "type_info": "Text"
+ },
+ {
+ "name": "p256dh",
+ "ordinal": 1,
+ "type_info": "Text"
+ },
+ {
+ "name": "auth",
+ "ordinal": 2,
+ "type_info": "Text"
+ }
+ ],
+ "parameters": {
+ "Right": 1
+ },
+ "nullable": [
+ false,
+ false,
+ false
+ ]
+ },
+ "hash": "99fc5ed0803a637ad140d6603462bc3ca82d28709a3117ce4844687098b19746"
+}
diff --git a/Cargo.toml b/Cargo.toml
index 4085a19..fa35201 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,6 +22,7 @@ assets = [
[dependencies]
argon2 = "0.5.3"
+async-trait = "0.1.89"
axum = { version = "0.8.6", features = ["form"] }
axum-extra = { version = "0.12.1", features = ["cookie", "query", "typed-header"] }
base64 = "0.22.1"
diff --git a/src/app.rs b/src/app.rs
index 6261d34..5eea20b 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -58,8 +58,11 @@ impl<P> App<P> {
Logins::new(self.db.clone(), self.token_events.clone())
}
- pub fn messages(&self) -> Messages {
- Messages::new(self.db.clone(), self.events.clone())
+ pub fn messages(&self) -> Messages<P>
+ where
+ P: Clone,
+ {
+ Messages::new(self.db.clone(), self.events.clone(), self.publisher.clone())
}
#[cfg(test)]
@@ -116,7 +119,10 @@ impl<P> FromRef<App<P>> for Logins {
}
}
-impl<P> FromRef<App<P>> for Messages {
+impl<P> FromRef<App<P>> for Messages<P>
+where
+ P: Clone,
+{
fn from_ref(app: &App<P>) -> Self {
app.messages()
}
diff --git a/src/conversation/handlers/send/mod.rs b/src/conversation/handlers/send/mod.rs
index 979dd24..67d3a70 100644
--- a/src/conversation/handlers/send/mod.rs
+++ b/src/conversation/handlers/send/mod.rs
@@ -12,19 +12,23 @@ use crate::{
Body, Message,
app::{Messages, SendError},
},
+ push::Publish,
token::extract::Identity,
};
#[cfg(test)]
mod test;
-pub async fn handler(
- State(messages): State<Messages>,
+pub async fn handler<P>(
+ State(messages): State<Messages<P>>,
Path(conversation): Path<PathInfo>,
RequestedAt(sent_at): RequestedAt,
identity: Identity,
Json(request): Json<Request>,
-) -> Result<Response, Error> {
+) -> Result<Response, Error>
+where
+ P: Publish,
+{
let message = messages
.send(&conversation, &identity.login, &sent_at, &request.body)
.await?;
diff --git a/src/expire.rs b/src/expire.rs
index c3b0117..da7ba53 100644
--- a/src/expire.rs
+++ b/src/expire.rs
@@ -12,7 +12,10 @@ pub async fn middleware<P>(
RequestedAt(expired_at): RequestedAt,
req: Request,
next: Next,
-) -> Result<Response, Internal> {
+) -> Result<Response, Internal>
+where
+ P: Clone,
+{
app.tokens().expire(&expired_at).await?;
app.invites().expire(&expired_at).await?;
app.messages().expire(&expired_at).await?;
diff --git a/src/message/app.rs b/src/message/app.rs
index b82fa83..8200650 100644
--- a/src/message/app.rs
+++ b/src/message/app.rs
@@ -1,6 +1,7 @@
use chrono::TimeDelta;
use itertools::Itertools;
use sqlx::sqlite::SqlitePool;
+use web_push::WebPushError;
use super::{Body, History, Id, Message, history, repo::Provider as _};
use crate::{
@@ -11,72 +12,24 @@ use crate::{
error::failed::{Failed, ResultExt as _},
event::{Broadcaster, Sequence, repo::Provider as _},
login::Login,
+ push::{Publish, repo::Provider as _},
user::{self, repo::Provider as _},
+ vapid::repo::Provider as _,
};
-pub struct Messages {
+pub struct Messages<P> {
db: SqlitePool,
events: Broadcaster,
+ publisher: P,
}
-impl Messages {
- pub const fn new(db: SqlitePool, events: Broadcaster) -> Self {
- Self { db, events }
- }
-
- pub async fn send(
- &self,
- conversation: &conversation::Id,
- sender: &Login,
- sent_at: &DateTime,
- body: &Body,
- ) -> Result<Message, SendError> {
- let conversation_not_found = || SendError::ConversationNotFound(conversation.clone());
- let conversation_deleted = || SendError::ConversationDeleted(conversation.clone());
- let sender_not_found = || SendError::SenderNotFound(sender.id.clone().into());
- let sender_deleted = || SendError::SenderDeleted(sender.id.clone().into());
-
- let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
- let conversation = tx
- .conversations()
- .by_id(conversation)
- .await
- .optional()
- .fail("Failed to load conversation")?
- .ok_or_else(conversation_not_found)?;
- let sender = tx
- .users()
- .by_login(sender)
- .await
- .optional()
- .fail("Failed to load sending user")?
- .ok_or_else(sender_not_found)?;
-
- // Ordering: don't bother allocating a sequence number before we know the channel might
- // exist.
- let sent = tx
- .sequence()
- .next(sent_at)
- .await
- .fail("Failed to find event sequence number")?;
- let conversation = conversation.as_of(sent).ok_or_else(conversation_deleted)?;
- let sender = sender.as_of(sent).ok_or_else(sender_deleted)?;
- let message = History::begin(&conversation, &sender, body, sent);
-
- // This filter technically includes every event in the history, but it's easier to follow if
- // the various event-manipulating app methods are consistent, and it's harmless to have an
- // always-satisfied filter.
- let events = message.events().filter(Sequence::start_from(sent));
- tx.messages()
- .record_events(events.clone())
- .await
- .fail("Failed to store events")?;
-
- tx.commit().await.fail(db::failed::COMMIT)?;
-
- self.events.broadcast_from(events);
-
- Ok(message.as_sent())
+impl<P> Messages<P> {
+ pub const fn new(db: SqlitePool, events: Broadcaster, publisher: P) -> Self {
+ Self {
+ db,
+ events,
+ publisher,
+ }
}
pub async fn delete(
@@ -163,6 +116,113 @@ impl Messages {
}
}
+impl<P> Messages<P>
+where
+ P: Publish,
+{
+ pub async fn send(
+ &self,
+ conversation: &conversation::Id,
+ sender: &Login,
+ sent_at: &DateTime,
+ body: &Body,
+ ) -> Result<Message, SendError> {
+ let conversation_not_found = || SendError::ConversationNotFound(conversation.clone());
+ let conversation_deleted = || SendError::ConversationDeleted(conversation.clone());
+ let sender_not_found = || SendError::SenderNotFound(sender.id.clone().into());
+ let sender_deleted = || SendError::SenderDeleted(sender.id.clone().into());
+
+ let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
+
+ let signer = tx
+ .vapid()
+ .signer()
+ .await
+ .fail("Failed to load VAPID signer")?;
+ let push_recipients = tx
+ .push()
+ .broadcast_from(sender)
+ .await
+ .fail("Failed to load push recipients")?;
+
+ let conversation = tx
+ .conversations()
+ .by_id(conversation)
+ .await
+ .optional()
+ .fail("Failed to load conversation")?
+ .ok_or_else(conversation_not_found)?;
+ let sender = tx
+ .users()
+ .by_login(sender)
+ .await
+ .optional()
+ .fail("Failed to load sending user")?
+ .ok_or_else(sender_not_found)?;
+
+ // Ordering: don't bother allocating a sequence number before we know the channel might
+ // exist.
+ let sent = tx
+ .sequence()
+ .next(sent_at)
+ .await
+ .fail("Failed to find event sequence number")?;
+ let conversation = conversation.as_of(sent).ok_or_else(conversation_deleted)?;
+ let sender = sender.as_of(sent).ok_or_else(sender_deleted)?;
+ let message = History::begin(&conversation, &sender, body, sent);
+
+ // This filter technically includes every event in the history, but it's easier to follow if
+ // the various event-manipulating app methods are consistent, and it's harmless to have an
+ // always-satisfied filter.
+ let events = message.events().filter(Sequence::start_from(sent));
+ tx.messages()
+ .record_events(events.clone())
+ .await
+ .fail("Failed to store events")?;
+
+ tx.commit().await.fail(db::failed::COMMIT)?;
+
+ self.events.broadcast_from(events.clone());
+ for event in events {
+ let failures = self
+ .publisher
+ .publish(event, &signer, &push_recipients)
+ .await
+ .fail("Failed to publish push events")?;
+
+ if !failures.is_empty() {
+ let mut tx = self.db.begin().await.fail(db::failed::BEGIN)?;
+ // Note that data integrity guarantees from the original transaction to read
+ // subscriptions may no longer be valid now. Time has passed. Depending on how slow
+ // delivering push notifications is, potentially a _lot_ of time has passed.
+
+ for (sub, err) in failures {
+ match err {
+ // I _think_ this is the complete set of permanent failures. See
+ // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete
+ // list.
+ WebPushError::Unauthorized(_)
+ | WebPushError::InvalidUri
+ | WebPushError::EndpointNotValid(_)
+ | WebPushError::EndpointNotFound(_)
+ | WebPushError::InvalidCryptoKeys
+ | WebPushError::MissingCryptoKeys => {
+ tx.push().unsubscribe(sub).await.fail(
+ "Failed to unsubscribe after permanent push message rejection",
+ )?;
+ }
+ _ => (),
+ }
+ }
+
+ tx.commit().await.fail(db::failed::COMMIT)?;
+ }
+ }
+
+ Ok(message.as_sent())
+ }
+}
+
#[derive(Debug, thiserror::Error)]
pub enum SendError {
#[error("conversation {0} not found")]
diff --git a/src/message/handlers/delete/mod.rs b/src/message/handlers/delete/mod.rs
index c680db1..08a7cea 100644
--- a/src/message/handlers/delete/mod.rs
+++ b/src/message/handlers/delete/mod.rs
@@ -17,8 +17,8 @@ use crate::{
#[cfg(test)]
mod test;
-pub async fn handler(
- State(messages): State<Messages>,
+pub async fn handler<P>(
+ State(messages): State<Messages<P>>,
Path(message): Path<message::Id>,
RequestedAt(deleted_at): RequestedAt,
identity: Identity,
diff --git a/src/push/app.rs b/src/push/app.rs
index ebfc220..f7846a6 100644
--- a/src/push/app.rs
+++ b/src/push/app.rs
@@ -101,7 +101,7 @@ where
let failures = self
.publisher
- .publish(Heartbeat::Heartbeat, signer, subscriptions)
+ .publish(Heartbeat::Heartbeat, &signer, &subscriptions)
.await
.fail("Failed to send push message")?;
diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs
index c985aaf..cc07ef0 100644
--- a/src/push/handlers/ping/test.rs
+++ b/src/push/handlers/ping/test.rs
@@ -26,7 +26,7 @@ async fn ping_without_subscriptions() {
.sent()
.into_iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions.is_empty())
+ && publish.recipients.is_empty())
.exactly_one()
.is_ok()
);
@@ -64,7 +64,7 @@ async fn ping() {
.sent()
.into_iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -110,7 +110,7 @@ async fn ping_multiple_subscriptions() {
.sent()
.into_iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -160,7 +160,7 @@ async fn ping_recipient_only() {
assert!(
sent.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == recipient_subscriptions)
+ && publish.recipients == recipient_subscriptions)
.exactly_one()
.is_ok()
);
@@ -169,7 +169,7 @@ async fn ping_recipient_only() {
assert!(
!sent
.iter()
- .any(|publish| publish.subscriptions.contains(&spectator_subscription))
+ .any(|publish| publish.recipients.contains(&spectator_subscription))
);
}
@@ -212,7 +212,7 @@ async fn ping_permanent_error() {
assert!(
sent.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -230,7 +230,7 @@ async fn ping_permanent_error() {
assert!(
!sent
.iter()
- .any(|publish| publish.subscriptions.contains(&subscription))
+ .any(|publish| publish.recipients.contains(&subscription))
);
}
@@ -275,7 +275,7 @@ async fn ping_temporary_error() {
assert!(
sent.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -293,7 +293,7 @@ async fn ping_temporary_error() {
assert!(
sent.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -345,7 +345,7 @@ async fn ping_multiple_subscriptions_with_failure() {
.sent()
.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
@@ -362,13 +362,13 @@ async fn ping_multiple_subscriptions_with_failure() {
assert!(
sent.iter()
.filter(|publish| publish.message_eq(&Heartbeat::Heartbeat)
- && publish.subscriptions == subscriptions)
+ && publish.recipients == subscriptions)
.exactly_one()
.is_ok()
);
assert!(
!sent
.iter()
- .any(|publish| publish.subscriptions.contains(&failing))
+ .any(|publish| publish.recipients.contains(&failing))
);
}
diff --git a/src/push/publisher.rs b/src/push/publisher.rs
index 4092724..d6227a2 100644
--- a/src/push/publisher.rs
+++ b/src/push/publisher.rs
@@ -8,13 +8,14 @@ use web_push::{
use crate::error::failed::{Failed, ResultExt as _};
+#[async_trait::async_trait]
pub trait Publish {
- fn publish<M>(
+ async fn publish<'s, M>(
&self,
message: M,
- signer: PartialVapidSignatureBuilder,
- subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
- ) -> impl Future<Output = Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>> + Send
+ signer: &PartialVapidSignatureBuilder,
+ subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+ ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
where
M: Serialize + Send + 'static;
}
@@ -50,13 +51,14 @@ impl Publisher {
}
}
+#[async_trait::async_trait]
impl Publish for Publisher {
- async fn publish<M>(
+ async fn publish<'s, M>(
&self,
message: M,
- signer: PartialVapidSignatureBuilder,
- subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
- ) -> Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>
+ signer: &PartialVapidSignatureBuilder,
+ subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+ ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
where
M: Serialize + Send + 'static,
{
@@ -65,7 +67,7 @@ impl Publish for Publisher {
let messages: Vec<_> = subscriptions
.into_iter()
- .map(|sub| Self::prepare_message(&payload, &signer, &sub).map(|message| (sub, message)))
+ .map(|sub| Self::prepare_message(&payload, signer, sub).map(|message| (sub, message)))
.try_collect()?;
let deliveries = messages
diff --git a/src/push/repo.rs b/src/push/repo.rs
index 4183489..8850059 100644
--- a/src/push/repo.rs
+++ b/src/push/repo.rs
@@ -83,6 +83,27 @@ impl Push<'_> {
Ok(info)
}
+ pub async fn broadcast_from(
+ &mut self,
+ originator: &Login,
+ ) -> Result<Vec<SubscriptionInfo>, sqlx::Error> {
+ sqlx::query!(
+ r#"
+ select
+ sub.endpoint,
+ sub.p256dh,
+ sub.auth
+ from push_subscription as sub
+ join token on sub.token = token.id
+ where token.login <> $1
+ "#,
+ originator.id,
+ )
+ .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth))
+ .fetch_all(&mut *self.0)
+ .await
+ }
+
pub async fn unsubscribe(
&mut self,
subscription: &SubscriptionInfo,
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index 0bd0b7a..39f5963 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -6,16 +6,18 @@ use crate::{
conversation::Conversation,
login::Login,
message::{self, Body, Message, app::Messages},
+ push::Publish,
};
-pub async fn send<App>(
+pub async fn send<App, P>(
app: &App,
conversation: &Conversation,
sender: &Login,
sent_at: &RequestedAt,
) -> Message
where
- Messages: FromRef<App>,
+ Messages<P>: FromRef<App>,
+ P: Publish,
{
let body = propose();
diff --git a/src/test/webpush.rs b/src/test/webpush.rs
index 96fa843..55caf19 100644
--- a/src/test/webpush.rs
+++ b/src/test/webpush.rs
@@ -2,7 +2,7 @@ use std::{
any::Any,
collections::{HashMap, HashSet},
mem,
- sync::{Arc, Mutex},
+ sync::{Arc, Mutex, MutexGuard},
};
use web_push::{PartialVapidSignatureBuilder, SubscriptionInfo, WebPushError};
@@ -10,59 +10,64 @@ use web_push::{PartialVapidSignatureBuilder, SubscriptionInfo, WebPushError};
use crate::{error::failed::Failed, push::Publish};
#[derive(Clone)]
-pub struct Client {
- sent: Arc<Mutex<Vec<Publication>>>,
- failures: Arc<Mutex<HashMap<SubscriptionInfo, WebPushError>>>,
+pub struct Client(Arc<Mutex<ClientInner>>);
+
+#[derive(Default)]
+struct ClientInner {
+ sent: Vec<Publication>,
+ planned_failures: HashMap<SubscriptionInfo, WebPushError>,
}
impl Client {
pub fn new() -> Self {
- Self {
- sent: Arc::default(),
- failures: Arc::default(),
- }
+ Self(Arc::default())
+ }
+
+ fn inner(&self) -> MutexGuard<'_, ClientInner> {
+ self.0.lock().unwrap()
}
// Clears the list of sent messages (for all clones of this Client) when called, because we
// can't clone `Publications`s, so we either need to move them or try to reconstruct them.
pub fn sent(&self) -> Vec<Publication> {
- let mut sent = self.sent.lock().unwrap();
- mem::take(&mut sent)
+ let sent = &mut self.inner().sent;
+ mem::take(sent)
}
pub fn fail_next(&self, subscription_info: &SubscriptionInfo, err: WebPushError) {
- let mut failures = self.failures.lock().unwrap();
- failures.insert(subscription_info.clone(), err);
+ let planned_failures = &mut self.inner().planned_failures;
+ planned_failures.insert(subscription_info.clone(), err);
}
}
+#[async_trait::async_trait]
impl Publish for Client {
- async fn publish<M>(
+ async fn publish<'s, M>(
&self,
message: M,
- _: PartialVapidSignatureBuilder,
- subscriptions: impl IntoIterator<Item = SubscriptionInfo> + Send,
- ) -> Result<Vec<(SubscriptionInfo, WebPushError)>, Failed>
+ _: &PartialVapidSignatureBuilder,
+ subscriptions: impl IntoIterator<Item = &'s SubscriptionInfo> + Send,
+ ) -> Result<Vec<(&'s SubscriptionInfo, WebPushError)>, Failed>
where
M: Send + 'static,
{
+ let mut inner = self.inner();
let message: Box<dyn Any + Send> = Box::new(message);
- let subscriptions = subscriptions.into_iter().collect();
+ let mut recipients = HashSet::new();
let mut failures = Vec::new();
-
- let mut planned_failures = self.failures.lock().unwrap();
- for subscription in &subscriptions {
- if let Some(err) = planned_failures.remove(subscription) {
- failures.push((subscription.clone(), err));
+ for subscription in subscriptions {
+ recipients.insert(subscription.clone());
+ if let Some(err) = inner.planned_failures.remove(subscription) {
+ failures.push((subscription, err));
}
}
let publication = Publication {
message,
- subscriptions,
+ recipients,
};
- self.sent.lock().unwrap().push(publication);
+ inner.sent.push(publication);
Ok(failures)
}
@@ -71,7 +76,7 @@ impl Publish for Client {
#[derive(Debug)]
pub struct Publication {
pub message: Box<dyn Any + Send>,
- pub subscriptions: HashSet<SubscriptionInfo>,
+ pub recipients: HashSet<SubscriptionInfo>,
}
impl Publication {