use futures::future::join_all;
use itertools::Itertools as _;
use p256::ecdsa::VerifyingKey;
use sqlx::SqlitePool;
use web_push::{
ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError,
WebPushMessage, WebPushMessageBuilder,
};
use super::repo::Provider as _;
use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _};
pub struct Push
{
db: SqlitePool,
webpush: P,
}
impl<P> Push<P> {
    /// Creates the service from a database pool and a Web Push client.
    pub const fn new(db: SqlitePool, webpush: P) -> Self {
        Self { db, webpush }
    }

    /// Stores `subscription` for `subscriber`'s token.
    ///
    /// `vapid` is the VAPID public key the client subscribed with. It must equal the server's
    /// current key.
    ///
    /// If the insert hits a unique constraint (presumably on the endpoint; confirm against the
    /// schema), the subscription already stored for this login and endpoint is compared with
    /// `subscription`. An identical one is accepted as a no-op, so retried requests succeed.
    ///
    /// # Errors
    ///
    /// - [`SubscribeError::StaleVapidKey`], carrying the current key, if `vapid` is not current.
    /// - [`SubscribeError::Duplicate`] if a subscription with the same endpoint but different
    ///   parameters already exists.
    /// - [`SubscribeError::Database`] or [`SubscribeError::Vapid`] if storage access fails.
    pub async fn subscribe(
        &self,
        subscriber: &Identity,
        subscription: &SubscriptionInfo,
        vapid: &VerifyingKey,
    ) -> Result<(), SubscribeError> {
        let mut tx = self.db.begin().await?;

        let current = tx.vapid().current().await?;
        if vapid != &current.key {
            return Err(SubscribeError::StaleVapidKey(current.key));
        }

        if let Err(err) = tx.push().create(&subscriber.token, subscription).await {
            let is_unique_violation = err
                .as_database_error()
                .is_some_and(|db_err| db_err.is_unique_violation());
            if !is_unique_violation {
                return Err(SubscribeError::Database(err));
            }

            let existing = tx
                .push()
                .by_endpoint(&subscriber.login, &subscription.endpoint)
                .await?;
            // If we already have a subscription for this endpoint with _different_ parameters,
            // then this is a client error. Clients shouldn't reuse endpoint URLs, per the
            // various RFCs.
            //
            // However, if we have a subscription for this endpoint with the same parameters,
            // then we accept it and silently do nothing. This may happen if, for example, the
            // subscribe request is retried after a network interruption that left it unclear
            // whether the original request succeeded.
            if &existing != subscription {
                return Err(SubscribeError::Duplicate);
            }
        }

        tx.commit().await?;
        Ok(())
    }
}
impl
Push
where
P: WebPushClient,
{
fn prepare_ping(
signer: &PartialVapidSignatureBuilder,
subscription: &SubscriptionInfo,
) -> Result {
let signature = signer.clone().add_sub_info(subscription).build()?;
let payload = "ping".as_bytes();
let mut message = WebPushMessageBuilder::new(subscription);
message.set_payload(ContentEncoding::Aes128Gcm, payload);
message.set_vapid_signature(signature);
let message = message.build()?;
Ok(message)
}
pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> {
let mut tx = self.db.begin().await?;
let signer = tx.vapid().signer().await?;
let subscriptions = tx.push().by_login(recipient).await?;
let pings: Vec<_> = subscriptions
.into_iter()
.map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message)))
.try_collect()?;
let deliveries = pings
.into_iter()
.map(async |(sub, message)| (sub, self.webpush.send(message).await));
let failures: Vec<_> = join_all(deliveries)
.await
.into_iter()
.filter_map(|(sub, result)| result.err().map(|err| (sub, err)))
.collect();
if !failures.is_empty() {
for (sub, err) in &failures {
match err {
// I _think_ this is the complete set of permanent failures. See
// for a complete
// list.
WebPushError::Unauthorized(_)
| WebPushError::InvalidUri
| WebPushError::EndpointNotValid(_)
| WebPushError::EndpointNotFound(_)
| WebPushError::InvalidCryptoKeys
| WebPushError::MissingCryptoKeys => {
tx.push().unsubscribe(sub).await?;
}
_ => (),
}
}
return Err(PushError::Delivery(
failures.into_iter().map(|(_, err)| err).collect(),
));
}
tx.commit().await?;
Ok(())
}
}
/// Ways registering a push subscription can fail.
#[derive(Debug, thiserror::Error)]
pub enum SubscribeError {
    /// A database operation failed.
    #[error(transparent)]
    Database(#[from] sqlx::Error),

    /// The VAPID key repository returned an error.
    #[error(transparent)]
    Vapid(#[from] vapid::repo::Error),

    /// The client subscribed with a VAPID key that is not the current one. Carries the current
    /// key (presumably so the client can resubscribe with it).
    #[error("subscription created with stale VAPID key")]
    StaleVapidKey(VerifyingKey),

    /// A subscription with different parameters already exists for this endpoint.
    //
    // The endpoint URL is deliberately not included: it is a bearer credential in its own right,
    // and we want to limit its proliferation. The only intended recipient of this error is the
    // client, which already knows the endpoint and doesn't need us to repeat it.
    #[error("subscription already exists for endpoint")]
    Duplicate,
}
#[derive(Debug, thiserror::Error)]
pub enum PushError {
#[error(transparent)]
Database(#[from] sqlx::Error),
#[error(transparent)]
Ecdsa(#[from] p256::ecdsa::Error),
#[error(transparent)]
Pkcs8(#[from] p256::pkcs8::Error),
#[error(transparent)]
WebPush(#[from] WebPushError),
#[error("push message delivery failures: {0:?}")]
Delivery(Vec),
}
impl From for PushError {
fn from(error: vapid::repo::Error) -> Self {
use vapid::repo::Error;
match error {
Error::Database(error) => error.into(),
Error::Ecdsa(error) => error.into(),
Error::Pkcs8(error) => error.into(),
Error::WebPush(error) => error.into(),
}
}
}