use futures::future::join_all; use itertools::Itertools as _; use p256::ecdsa::VerifyingKey; use sqlx::SqlitePool; use web_push::{ ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError, WebPushMessage, WebPushMessageBuilder, }; use super::repo::Provider as _; use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _}; pub struct Push

{ db: SqlitePool, webpush: P, } impl

Push

{ pub const fn new(db: SqlitePool, webpush: P) -> Self { Self { db, webpush } } pub async fn subscribe( &self, subscriber: &Identity, subscription: &SubscriptionInfo, vapid: &VerifyingKey, ) -> Result<(), SubscribeError> { let mut tx = self.db.begin().await?; let current = tx.vapid().current().await?; if vapid != ¤t.key { return Err(SubscribeError::StaleVapidKey(current.key)); } match tx.push().create(&subscriber.token, subscription).await { Ok(()) => (), Err(err) => { if let Some(err) = err.as_database_error() && err.is_unique_violation() { let current = tx .push() .by_endpoint(&subscriber.login, &subscription.endpoint) .await?; // If we already have a subscription for this endpoint, with _different_ // parameters, then this is a client error. They shouldn't reuse endpoint URLs, // per the various RFCs. // // However, if we have a subscription for this endpoint with the same parameters // then we accept it and silently do nothing. This may happen if, for example, // the subscribe request is retried due to a network interruption where it's // not clear whether the original request succeeded. if ¤t != subscription { return Err(SubscribeError::Duplicate); } } else { return Err(SubscribeError::Database(err)); } } } tx.commit().await?; Ok(()) } } impl

Push

where P: WebPushClient, { fn prepare_ping( signer: &PartialVapidSignatureBuilder, subscription: &SubscriptionInfo, ) -> Result { let signature = signer.clone().add_sub_info(subscription).build()?; let payload = "ping".as_bytes(); let mut message = WebPushMessageBuilder::new(subscription); message.set_payload(ContentEncoding::Aes128Gcm, payload); message.set_vapid_signature(signature); let message = message.build()?; Ok(message) } pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> { let mut tx = self.db.begin().await?; let signer = tx.vapid().signer().await?; let subscriptions = tx.push().by_login(recipient).await?; let pings: Vec<_> = subscriptions .into_iter() .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message))) .try_collect()?; let deliveries = pings .into_iter() .map(async |(sub, message)| (sub, self.webpush.send(message).await)); let failures: Vec<_> = join_all(deliveries) .await .into_iter() .filter_map(|(sub, result)| result.err().map(|err| (sub, err))) .collect(); if !failures.is_empty() { for (sub, err) in &failures { match err { // I _think_ this is the complete set of permanent failures. See // for a complete // list. WebPushError::Unauthorized(_) | WebPushError::InvalidUri | WebPushError::EndpointNotValid(_) | WebPushError::EndpointNotFound(_) | WebPushError::InvalidCryptoKeys | WebPushError::MissingCryptoKeys => { tx.push().unsubscribe(sub).await?; } _ => (), } } return Err(PushError::Delivery( failures.into_iter().map(|(_, err)| err).collect(), )); } tx.commit().await?; Ok(()) } } #[derive(Debug, thiserror::Error)] pub enum SubscribeError { #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] Vapid(#[from] vapid::repo::Error), #[error("subscription created with stale VAPID key")] StaleVapidKey(VerifyingKey), #[error("subscription already exists for endpoint")] // The endpoint URL is not included in the error, as it is a bearer credential in its own right // and we want to limit its proliferation. The only intended recipient of this message is the // client, which already knows the endpoint anyways and doesn't need us to tell them. Duplicate, } #[derive(Debug, thiserror::Error)] pub enum PushError { #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] Ecdsa(#[from] p256::ecdsa::Error), #[error(transparent)] Pkcs8(#[from] p256::pkcs8::Error), #[error(transparent)] WebPush(#[from] WebPushError), #[error("push message delivery failures: {0:?}")] Delivery(Vec), } impl From for PushError { fn from(error: vapid::repo::Error) -> Self { use vapid::repo::Error; match error { Error::Database(error) => error.into(), Error::Ecdsa(error) => error.into(), Error::Pkcs8(error) => error.into(), Error::WebPush(error) => error.into(), } } }