use sqlx::SqlitePool; use web_push::{ ContentEncoding, IsahcWebPushClient, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushMessageBuilder, }; use super::{Id, repo::Provider as _}; use crate::{ db::NotFound as _, user::{self, User}, }; pub struct Push<'a> { db: &'a SqlitePool, vapid_public_key: &'a str, vapid_signer: &'a PartialVapidSignatureBuilder, } impl<'a> Push<'a> { pub const fn new( db: &'a SqlitePool, vapid_public_key: &'a str, vapid_signer: &'a PartialVapidSignatureBuilder, ) -> Self { Self { db, vapid_public_key, vapid_signer, } } pub fn public_key(&self) -> &str { self.vapid_public_key } pub async fn register( &self, user: &User, subscription: &SubscriptionInfo, ) -> Result { let mut tx = self.db.begin().await?; let id = tx.subscriptions().create(user, subscription).await?; tx.commit().await?; Ok(id) } pub async fn broadcast( &self, message: &str, ) -> Result<(), EchoError> { let mut tx = self.db.begin().await?; let subscriptions = tx .subscriptions() .all() .await?; tx.commit().await?; for subscription in subscriptions { // We don't care if any of these error, for now. // Eventually, we should remove rows that cause certain error conditions. println!("Sending to {:#?}", subscription.info.endpoint); self.send(&subscription.info, message).await.unwrap_or_else(|err| { println!("Error with {:#?}: {}", subscription.info.endpoint, err); }) } Ok(()) } pub async fn echo( &self, user: &User, endpoint: &String, message: &str, ) -> Result<(), EchoError> { let mut tx = self.db.begin().await?; let subscription = tx .subscriptions() .by_endpoint(endpoint) .await .not_found(|| EchoError::NotFound(endpoint.clone()))?; if subscription.user != user.id { return Err(EchoError::NotSubscriber(subscription.id, user.id.clone())); } tx.commit().await?; self.send(&subscription.info, message).await?; Ok(()) } async fn send(&self, subscription: &SubscriptionInfo, message: &str) -> Result<(), EchoError> { let sig_builder = self .vapid_signer .clone() .add_sub_info(subscription) .build()?; let payload = message.as_bytes(); let mut message_builder = WebPushMessageBuilder::new(subscription); message_builder.set_payload(ContentEncoding::Aes128Gcm, payload); message_builder.set_vapid_signature(sig_builder); let message = message_builder.build()?; let client = IsahcWebPushClient::new()?; client.send(message).await?; Ok(()) } pub async fn unregister(&self, user: &User, endpoint: &String) -> Result<(), UnregisterError> { let mut tx = self.db.begin().await?; let subscription = tx .subscriptions() .by_endpoint(endpoint) .await .not_found(|| UnregisterError::NotFound(endpoint.clone()))?; if subscription.user != user.id { return Err(UnregisterError::NotSubscriber( subscription.id, user.id.clone(), )); } tx.subscriptions().delete(&subscription).await?; tx.commit().await?; Ok(()) } } #[derive(Debug, thiserror::Error)] pub enum RegisterError { #[error(transparent)] Database(#[from] sqlx::Error), } #[derive(Debug, thiserror::Error)] pub enum EchoError { #[error("subscription {0} not found")] NotFound(String), #[error("user {1} is not the subscriber for subscription {0}")] NotSubscriber(Id, user::Id), #[error(transparent)] WebPush(#[from] web_push::WebPushError), #[error(transparent)] Database(#[from] sqlx::Error), } #[derive(Debug, thiserror::Error)] pub enum UnregisterError { #[error("subscription {0} not found")] NotFound(String), #[error("user {1} is not the subscriber for subscription {0}")] NotSubscriber(Id, user::Id), #[error(transparent)] Database(#[from] sqlx::Error), }