summaryrefslogtreecommitdiff
path: root/src/push
diff options
context:
space:
mode:
Diffstat (limited to 'src/push')
-rw-r--r--src/push/app.rs114
-rw-r--r--src/push/handlers/mod.rs2
-rw-r--r--src/push/handlers/ping/mod.rs23
-rw-r--r--src/push/handlers/ping/test.rs40
-rw-r--r--src/push/handlers/subscribe/mod.rs7
-rw-r--r--src/push/repo.rs35
6 files changed, 211 insertions, 10 deletions
diff --git a/src/push/app.rs b/src/push/app.rs
index 358a8cc..56b9a02 100644
--- a/src/push/app.rs
+++ b/src/push/app.rs
@@ -1,17 +1,23 @@
+use futures::future::join_all;
+use itertools::Itertools as _;
use p256::ecdsa::VerifyingKey;
use sqlx::SqlitePool;
-use web_push::SubscriptionInfo;
+use web_push::{
+ ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError,
+ WebPushMessage, WebPushMessageBuilder,
+};
use super::repo::Provider as _;
-use crate::{token::extract::Identity, vapid, vapid::repo::Provider as _};
+use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _};
-pub struct Push {
+pub struct Push<P> {
db: SqlitePool,
+ webpush: P,
}
-impl Push {
- pub const fn new(db: SqlitePool) -> Self {
- Self { db }
+impl<P> Push<P> {
+ pub const fn new(db: SqlitePool, webpush: P) -> Self {
+ Self { db, webpush }
}
pub async fn subscribe(
@@ -60,6 +66,76 @@ impl Push {
}
}
+impl<P> Push<P>
+where
+ P: WebPushClient,
+{
+ fn prepare_ping(
+ signer: &PartialVapidSignatureBuilder,
+ subscription: &SubscriptionInfo,
+ ) -> Result<WebPushMessage, WebPushError> {
+ let signature = signer.clone().add_sub_info(subscription).build()?;
+
+ let payload = "ping".as_bytes();
+
+ let mut message = WebPushMessageBuilder::new(subscription);
+ message.set_payload(ContentEncoding::Aes128Gcm, payload);
+ message.set_vapid_signature(signature);
+ let message = message.build()?;
+
+ Ok(message)
+ }
+
+ pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> {
+ let mut tx = self.db.begin().await?;
+
+ let signer = tx.vapid().signer().await?;
+ let subscriptions = tx.push().by_login(recipient).await?;
+
+ let pings: Vec<_> = subscriptions
+ .into_iter()
+ .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message)))
+ .try_collect()?;
+
+ let deliveries = pings
+ .into_iter()
+ .map(async |(sub, message)| (sub, self.webpush.send(message).await));
+
+ let failures: Vec<_> = join_all(deliveries)
+ .await
+ .into_iter()
+ .filter_map(|(sub, result)| result.err().map(|err| (sub, err)))
+ .collect();
+
+ if !failures.is_empty() {
+ for (sub, err) in &failures {
+ match err {
+ // I _think_ this is the complete set of permanent failures. See
+ // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete
+ // list.
+ WebPushError::Unauthorized(_)
+ | WebPushError::InvalidUri
+ | WebPushError::EndpointNotValid(_)
+ | WebPushError::EndpointNotFound(_)
+ | WebPushError::InvalidCryptoKeys
+ | WebPushError::MissingCryptoKeys => {
+ tx.push().unsubscribe(sub).await?;
+ }
+ _ => (),
+ }
+ }
+
+ return Err(PushError::Delivery(
+ failures.into_iter().map(|(_, err)| err).collect(),
+ ));
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+}
+
#[derive(Debug, thiserror::Error)]
pub enum SubscribeError {
#[error(transparent)]
@@ -74,3 +150,29 @@ pub enum SubscribeError {
// client, which already knows the endpoint anyways and doesn't need us to tell them.
Duplicate,
}
+
+#[derive(Debug, thiserror::Error)]
+pub enum PushError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Ecdsa(#[from] p256::ecdsa::Error),
+ #[error(transparent)]
+ Pkcs8(#[from] p256::pkcs8::Error),
+ #[error(transparent)]
+ WebPush(#[from] WebPushError),
+ #[error("push message delivery failures: {0:?}")]
+ Delivery(Vec<WebPushError>),
+}
+
+impl From<vapid::repo::Error> for PushError {
+ fn from(error: vapid::repo::Error) -> Self {
+ use vapid::repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}
diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs
index 86eeea0..bb58774 100644
--- a/src/push/handlers/mod.rs
+++ b/src/push/handlers/mod.rs
@@ -1,3 +1,5 @@
+mod ping;
mod subscribe;
+pub use ping::handler as ping;
pub use subscribe::handler as subscribe;
diff --git a/src/push/handlers/ping/mod.rs b/src/push/handlers/ping/mod.rs
new file mode 100644
index 0000000..db828fa
--- /dev/null
+++ b/src/push/handlers/ping/mod.rs
@@ -0,0 +1,23 @@
+use axum::{Json, extract::State, http::StatusCode};
+use web_push::WebPushClient;
+
+use crate::{error::Internal, push::app::Push, token::extract::Identity};
+
+#[cfg(test)]
+mod test;
+
+#[derive(serde::Deserialize)]
+pub struct Request {}
+
+pub async fn handler<P>(
+ State(push): State<Push<P>>,
+ identity: Identity,
+ Json(_): Json<Request>,
+) -> Result<StatusCode, Internal>
+where
+ P: WebPushClient,
+{
+ push.ping(&identity.login).await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs
new file mode 100644
index 0000000..5725131
--- /dev/null
+++ b/src/push/handlers/ping/test.rs
@@ -0,0 +1,40 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+};
+
+use crate::test::fixtures;
+
+#[tokio::test]
+async fn ping_without_subscriptions() {
+ let app = fixtures::scratch_app().await;
+
+ let recipient = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ let response = super::handler(State(app.push()), recipient, Json(super::Request {}))
+ .await
+ .expect("sending a ping with no subscriptions always succeeds");
+
+ assert_eq!(StatusCode::ACCEPTED, response);
+
+ assert!(app.webpush().sent().is_empty());
+}
+
+// More complete testing requires that we figure out how to generate working p256 ECDH keys for
+// testing _with_, as `web_push` will actually parse and use those keys even if push messages are
+// ultimately never serialized or sent over HTTP.
+//
+// Tests that are missing:
+//
+// * Verify that subscribing and sending a ping causes a ping to be delivered to that subscription.
+// * Verify that two subscriptions both get pings.
+// * Verify that other users' subscriptions are not pinged.
+// * Verify that a ping that causes a permanent error causes the subscription to be deleted.
+// * Verify that a ping that causes a non-permanent error does not cause the subscription to be
+// deleted.
+// * Verify that a failure on one subscription doesn't affect delivery on other subscriptions.
diff --git a/src/push/handlers/subscribe/mod.rs b/src/push/handlers/subscribe/mod.rs
index d142df6..a1a5899 100644
--- a/src/push/handlers/subscribe/mod.rs
+++ b/src/push/handlers/subscribe/mod.rs
@@ -36,8 +36,8 @@ pub struct Keys {
auth: String,
}
-pub async fn handler(
- State(push): State<Push>,
+pub async fn handler<P>(
+ State(push): State<Push<P>>,
identity: Identity,
Json(request): Json<Request>,
) -> Result<StatusCode, Error> {
@@ -58,8 +58,7 @@ impl From<Subscription> for SubscriptionInfo {
endpoint,
keys: Keys { p256dh, auth },
} = request;
- let info = SubscriptionInfo::new(endpoint, p256dh, auth);
- info
+ SubscriptionInfo::new(endpoint, p256dh, auth)
}
}
diff --git a/src/push/repo.rs b/src/push/repo.rs
index 6c18c6e..4183489 100644
--- a/src/push/repo.rs
+++ b/src/push/repo.rs
@@ -37,6 +37,24 @@ impl Push<'_> {
Ok(())
}
+ pub async fn by_login(&mut self, login: &Login) -> Result<Vec<SubscriptionInfo>, sqlx::Error> {
+ sqlx::query!(
+ r#"
+ select
+ subscription.endpoint,
+ subscription.p256dh,
+ subscription.auth
+ from push_subscription as subscription
+ join token on subscription.token = token.id
+ where token.login = $1
+ "#,
+ login.id,
+ )
+ .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth))
+ .fetch_all(&mut *self.0)
+ .await
+ }
+
pub async fn by_endpoint(
&mut self,
subscriber: &Login,
@@ -65,6 +83,23 @@ impl Push<'_> {
Ok(info)
}
+ pub async fn unsubscribe(
+ &mut self,
+ subscription: &SubscriptionInfo,
+ ) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ delete from push_subscription
+ where endpoint = $1
+ "#,
+ subscription.endpoint,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
pub async fn unsubscribe_token(&mut self, token: &Token) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"