summaryrefslogtreecommitdiff
path: root/src/push/app.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/push/app.rs')
-rw-r--r--src/push/app.rs114
1 files changed, 108 insertions, 6 deletions
diff --git a/src/push/app.rs b/src/push/app.rs
index 358a8cc..56b9a02 100644
--- a/src/push/app.rs
+++ b/src/push/app.rs
@@ -1,17 +1,23 @@
+use futures::future::join_all;
+use itertools::Itertools as _;
use p256::ecdsa::VerifyingKey;
use sqlx::SqlitePool;
-use web_push::SubscriptionInfo;
+use web_push::{
+ ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError,
+ WebPushMessage, WebPushMessageBuilder,
+};
use super::repo::Provider as _;
-use crate::{token::extract::Identity, vapid, vapid::repo::Provider as _};
+use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _};
-pub struct Push {
+pub struct Push<P> {
db: SqlitePool,
+ webpush: P,
}
-impl Push {
- pub const fn new(db: SqlitePool) -> Self {
- Self { db }
+impl<P> Push<P> {
+ pub const fn new(db: SqlitePool, webpush: P) -> Self {
+ Self { db, webpush }
}
pub async fn subscribe(
@@ -60,6 +66,76 @@ impl Push {
}
}
+impl<P> Push<P>
+where
+ P: WebPushClient,
+{
+ fn prepare_ping(
+ signer: &PartialVapidSignatureBuilder,
+ subscription: &SubscriptionInfo,
+ ) -> Result<WebPushMessage, WebPushError> {
+ let signature = signer.clone().add_sub_info(subscription).build()?;
+
+ let payload = "ping".as_bytes();
+
+ let mut message = WebPushMessageBuilder::new(subscription);
+ message.set_payload(ContentEncoding::Aes128Gcm, payload);
+ message.set_vapid_signature(signature);
+ let message = message.build()?;
+
+ Ok(message)
+ }
+
+ pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> {
+ let mut tx = self.db.begin().await?;
+
+ let signer = tx.vapid().signer().await?;
+ let subscriptions = tx.push().by_login(recipient).await?;
+
+ let pings: Vec<_> = subscriptions
+ .into_iter()
+ .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message)))
+ .try_collect()?;
+
+ let deliveries = pings
+ .into_iter()
+ .map(async |(sub, message)| (sub, self.webpush.send(message).await));
+
+ let failures: Vec<_> = join_all(deliveries)
+ .await
+ .into_iter()
+ .filter_map(|(sub, result)| result.err().map(|err| (sub, err)))
+ .collect();
+
+ if !failures.is_empty() {
+ for (sub, err) in &failures {
+ match err {
+ // I _think_ this is the complete set of permanent failures. See
+ // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete
+ // list.
+ WebPushError::Unauthorized(_)
+ | WebPushError::InvalidUri
+ | WebPushError::EndpointNotValid(_)
+ | WebPushError::EndpointNotFound(_)
+ | WebPushError::InvalidCryptoKeys
+ | WebPushError::MissingCryptoKeys => {
+ tx.push().unsubscribe(sub).await?;
+ }
+ _ => (),
+ }
+ }
+
+ return Err(PushError::Delivery(
+ failures.into_iter().map(|(_, err)| err).collect(),
+ ));
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+}
+
#[derive(Debug, thiserror::Error)]
pub enum SubscribeError {
#[error(transparent)]
@@ -74,3 +150,29 @@ pub enum SubscribeError {
// client, which already knows the endpoint anyways and doesn't need us to tell them.
Duplicate,
}
+
+#[derive(Debug, thiserror::Error)]
+pub enum PushError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Ecdsa(#[from] p256::ecdsa::Error),
+ #[error(transparent)]
+ Pkcs8(#[from] p256::pkcs8::Error),
+ #[error(transparent)]
+ WebPush(#[from] WebPushError),
+ #[error("push message delivery failures: {0:?}")]
+ Delivery(Vec<WebPushError>),
+}
+
+impl From<vapid::repo::Error> for PushError {
+ fn from(error: vapid::repo::Error) -> Self {
+ use vapid::repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}