diff options
Diffstat (limited to 'src/push')
| -rw-r--r-- | src/push/app.rs | 138 | ||||
| -rw-r--r-- | src/push/handlers/echo.rs | 20 | ||||
| -rw-r--r-- | src/push/handlers/mod.rs | 79 | ||||
| -rw-r--r-- | src/push/handlers/register.rs | 40 | ||||
| -rw-r--r-- | src/push/handlers/unregister.rs | 16 | ||||
| -rw-r--r-- | src/push/id.rs | 27 | ||||
| -rw-r--r-- | src/push/mod.rs | 6 | ||||
| -rw-r--r-- | src/push/repo.rs | 81 | ||||
| -rw-r--r-- | src/push/vapid.rs | 28 |
9 files changed, 322 insertions, 113 deletions
diff --git a/src/push/app.rs b/src/push/app.rs new file mode 100644 index 0000000..2d6e15c --- /dev/null +++ b/src/push/app.rs @@ -0,0 +1,138 @@ +use sqlx::SqlitePool; +use web_push::{ + ContentEncoding, IsahcWebPushClient, PartialVapidSignatureBuilder, SubscriptionInfo, + WebPushClient, WebPushMessageBuilder, +}; + +use super::{Id, repo::Provider as _}; + +use crate::{ + db::NotFound as _, + user::{self, User}, +}; + +pub struct Push<'a> { + db: &'a SqlitePool, + vapid_public_key: &'a str, + vapid_signer: &'a PartialVapidSignatureBuilder, +} + +impl<'a> Push<'a> { + pub const fn new( + db: &'a SqlitePool, + vapid_public_key: &'a str, + vapid_signer: &'a PartialVapidSignatureBuilder, + ) -> Self { + Self { + db, + vapid_public_key, + vapid_signer, + } + } + + pub fn public_key(&self) -> &str { + self.vapid_public_key + } + + pub async fn register( + &self, + user: &User, + subscription: &SubscriptionInfo, + ) -> Result<Id, RegisterError> { + let mut tx = self.db.begin().await?; + let id = tx.subscriptions().create(user, subscription).await?; + tx.commit().await?; + + Ok(id) + } + + pub async fn echo( + &self, + user: &User, + subscription: &Id, + message: &str, + ) -> Result<(), EchoError> { + let mut tx = self.db.begin().await?; + let subscription = tx + .subscriptions() + .by_id(subscription) + .await + .not_found(|| EchoError::NotFound(subscription.clone()))?; + if subscription.user != user.id { + return Err(EchoError::NotSubscriber(subscription.id, user.id.clone())); + } + + tx.commit().await?; + + self.send(&subscription.info, message).await?; + + Ok(()) + } + + async fn send(&self, subscription: &SubscriptionInfo, message: &str) -> Result<(), EchoError> { + let sig_builder = self + .vapid_signer + .clone() + .add_sub_info(subscription) + .build()?; + + let payload = message.as_bytes(); + + let mut message_builder = WebPushMessageBuilder::new(subscription); + message_builder.set_payload(ContentEncoding::Aes128Gcm, payload); + message_builder.set_vapid_signature(sig_builder); + let message = message_builder.build()?; + + let client = IsahcWebPushClient::new()?; + client.send(message).await?; + + Ok(()) + } + + pub async fn unregister(&self, user: &User, subscription: &Id) -> Result<(), UnregisterError> { + let mut tx = self.db.begin().await?; + let subscription = tx + .subscriptions() + .by_id(subscription) + .await + .not_found(|| UnregisterError::NotFound(subscription.clone()))?; + if subscription.user != user.id { + return Err(UnregisterError::NotSubscriber( + subscription.id, + user.id.clone(), + )); + } + tx.subscriptions().delete(&subscription).await?; + tx.commit().await?; + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RegisterError { + #[error(transparent)] + Database(#[from] sqlx::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum EchoError { + #[error("subscription {0} not found")] + NotFound(Id), + #[error("user {1} is not the subscriber for subscription {0}")] + NotSubscriber(Id, user::Id), + #[error(transparent)] + WebPush(#[from] web_push::WebPushError), + #[error(transparent)] + Database(#[from] sqlx::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum UnregisterError { + #[error("subscription {0} not found")] + NotFound(Id), + #[error("user {1} is not the subscriber for subscription {0}")] + NotSubscriber(Id, user::Id), + #[error(transparent)] + Database(#[from] sqlx::Error), +} diff --git a/src/push/handlers/echo.rs b/src/push/handlers/echo.rs new file mode 100644 index 0000000..4b4de57 --- /dev/null +++ b/src/push/handlers/echo.rs @@ -0,0 +1,20 @@ +use axum::extract::{Json, State}; + +use crate::{app::App, push::Id, token::extract::Identity}; + +#[derive(serde::Deserialize)] +pub struct Request { + subscription: Id, + msg: String, +} + +pub async fn handler( + State(app): State<App>, + identity: Identity, + Json(request): Json<Request>, +) -> Result<(), crate::error::Internal> { + let Request { subscription, msg } = request; + app.push().echo(&identity.user, &subscription, &msg).await?; + + Ok(()) +} diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs index e4a531b..90edaa7 100644 --- a/src/push/handlers/mod.rs +++ b/src/push/handlers/mod.rs @@ -1,74 +1,15 @@ -use std::env; +use axum::extract::State; -use axum::{ - extract::{Json}, -}; +use crate::app::App; -use web_push::{ - SubscriptionInfo, - VapidSignatureBuilder, - WebPushMessageBuilder, - ContentEncoding, - WebPushClient, - IsahcWebPushClient, -}; +mod echo; +mod register; +mod unregister; +pub use echo::handler as echo; +pub use register::handler as register; +pub use unregister::handler as unregister; -pub async fn vapid() -> String { - let vapid_public_key = env::var("VAPID_PUBLIC_KEY").unwrap_or_default(); - String::from(vapid_public_key) -} - - -pub async fn register() -> String { - String::from("OK") -} - - -pub async fn unregister() -> String { - String::from("OK") -} - -async fn push_message( - endpoint: String, - keys: Keys, - message: &String, -) -> Result<(), crate::error::Internal> { - let content = message.as_bytes(); - - let subscription_info = SubscriptionInfo::new(endpoint, keys.p256dh, keys.auth); - // This will need to come from the DB eventually: - let private_key = String::from(env::var("VAPID_PRIVATE_KEY").unwrap_or_default()); - let sig_builder = VapidSignatureBuilder::from_base64(&private_key, &subscription_info)?.build()?; - let mut builder = WebPushMessageBuilder::new(&subscription_info); - builder.set_payload(ContentEncoding::Aes128Gcm, content); - builder.set_vapid_signature(sig_builder); - let client = IsahcWebPushClient::new()?; - client.send(builder.build()?).await?; - - Ok(()) -} - - -#[axum::debug_handler] -pub async fn echo( - Json(payload): Json<PushPayload>, -) -> Result<(), crate::error::Internal> { - push_message(payload.endpoint, payload.keys, &payload.msg).await?; - - Ok(()) -} - - -#[derive(serde::Deserialize)] -pub struct Keys { - pub p256dh: String, - pub auth: String, -} - -#[derive(serde::Deserialize)] -pub struct PushPayload { - pub msg: String, - pub endpoint: String, - pub keys: Keys, +pub async fn vapid(State(app): State<App>) -> String { + app.push().public_key().to_owned() } diff --git a/src/push/handlers/register.rs b/src/push/handlers/register.rs new file mode 100644 index 0000000..201928b --- /dev/null +++ b/src/push/handlers/register.rs @@ -0,0 +1,40 @@ +use axum::extract::{Json, State}; +use web_push::SubscriptionInfo; + +use crate::{app::App, error::Internal, push::Id, token::extract::Identity}; + +#[derive(serde::Deserialize)] +pub struct Request { + endpoint: String, + p256dh: String, + auth: String, +} + +#[derive(serde::Serialize)] +pub struct Response { + id: Id, +} + +pub async fn handler( + State(app): State<App>, + identity: Identity, + Json(request): Json<Request>, +) -> Result<Json<Response>, Internal> { + let subscription = request.into(); + + let id = app.push().register(&identity.user, &subscription).await?; + + Ok(Json(Response { id })) +} + +impl From<Request> for SubscriptionInfo { + fn from(request: Request) -> Self { + let Request { + endpoint, + p256dh, + auth, + } = request; + let info = Self::new(endpoint, p256dh, auth); + info + } +} diff --git a/src/push/handlers/unregister.rs b/src/push/handlers/unregister.rs new file mode 100644 index 0000000..a00ee92 --- /dev/null +++ b/src/push/handlers/unregister.rs @@ -0,0 +1,16 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, +}; + +use crate::{app::App, error::Internal, push::Id, token::extract::Identity}; + +pub async fn handler( + State(app): State<App>, + identity: Identity, + Path(subscription): Path<Id>, +) -> Result<StatusCode, Internal> { + app.push().unregister(&identity.user, &subscription).await?; + + Ok(StatusCode::NO_CONTENT) +} diff --git a/src/push/id.rs b/src/push/id.rs new file mode 100644 index 0000000..b28d6ab --- /dev/null +++ b/src/push/id.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a push subscription. Prefixed with `S`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("S") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/push/mod.rs b/src/push/mod.rs index c3d4495..c32cb27 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1 +1,7 @@ +pub mod app; pub mod handlers; +mod id; +mod repo; +pub mod vapid; + +use id::Id; diff --git a/src/push/repo.rs b/src/push/repo.rs index 2d492ea..ddef706 100644 --- a/src/push/repo.rs +++ b/src/push/repo.rs @@ -1,9 +1,8 @@ use sqlx::{SqliteConnection, Transaction, sqlite::Sqlite}; +use web_push::SubscriptionInfo; -use super::{Subscription, Id}; -use crate::{ - user::{self, User}, -} +use super::Id; +use crate::user::{self, User}; pub trait Provider { fn subscriptions(&mut self) -> Subscriptions; @@ -21,74 +20,68 @@ impl Subscriptions<'_> { pub async fn create( &mut self, user: &User, - endpoint: &String, - key_p256dh: &String, - key_auth: &String, - expiration_time: &String, - ) -> Result<Subscription, sqlx::Error> { + info: &SubscriptionInfo, + ) -> Result<Id, sqlx::Error> { let id = Id::generate(); - let subscription = sqlx::query!( + sqlx::query!( r#" - insert into subscription - (id, user, endpoint, key_p256dh, key_auth, expiration_time) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: Id", - user as "user: user::Id", - endpoint as "endpoint: String", - key_p256dh as "key_p256dh: String", - key_auth as "key_auth: String", - expiration_time as "expiration_time: String" + insert into subscription (id, user, endpoint, key_p256dh, key_auth) + values ($1, $2, $3, $4, $5) "#, id, user.id, - endpoint, - key_p256dh, - key_auth, - expiration_time, + info.endpoint, + info.keys.p256dh, + info.keys.auth, ) - .fetch_one(&mut *self.0) + .execute(&mut *self.0) .await?; - Ok(subscription) + Ok(id) } - pub async fn for_user(&mut self, user: &User) -> Result<vec<Subscription>, sqlx::Error> { - let subscriptions = sqlx::query!( + pub async fn by_id(&mut self, id: &Id) -> Result<Subscription, sqlx::Error> { + let subscription = sqlx::query!( r#" select id as "id: Id", user as "user: user::Id", - endpoint as "endpoint: String", - key_p256dh as "key_p256dh: String", - key_auth as "key_auth: String", + endpoint, + key_p256dh, + key_auth from subscription - where user = $1 + where id = $1 "#, - user.id, + id, ) - .fetch_all(&mut *self.0) + .map(|row| Subscription { + id: row.id, + user: row.user, + info: SubscriptionInfo::new(row.endpoint, row.key_p256dh, row.key_auth), + }) + .fetch_one(&mut *self.0) .await?; - Ok(subscriptions) + Ok(subscription) } - pub async fn delete( - &mut self, - subscription: &Subscription, - deleted: &Instant, - ) -> Result<(), sqlx::Error> { - let id = subscription.id(); - + pub async fn delete(&mut self, subscription: &Subscription) -> Result<(), sqlx::Error> { sqlx::query!( r#" - delete from subscription where id = $1 + delete from subscription + where id = $1 "#, - id, + subscription.id, ) .execute(&mut *self.0) .await?; Ok(()) } } + +pub struct Subscription { + pub id: Id, + pub user: user::Id, + pub info: SubscriptionInfo, +} diff --git a/src/push/vapid.rs b/src/push/vapid.rs new file mode 100644 index 0000000..b13a00c --- /dev/null +++ b/src/push/vapid.rs @@ -0,0 +1,28 @@ +use std::fmt; + +use web_push::{PartialVapidSignatureBuilder, VapidSignatureBuilder, WebPushError}; + +#[derive(Clone)] +pub struct PrivateKey(String); + +impl PrivateKey { + pub fn as_signature_builder(&self) -> Result<PartialVapidSignatureBuilder, WebPushError> { + let Self(key) = self; + VapidSignatureBuilder::from_base64_no_sub(key) + } +} + +impl fmt::Debug for PrivateKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PrivateKey").field(&"********").finish() + } +} + +impl<S> From<S> for PrivateKey +where + S: Into<String>, +{ + fn from(value: S) -> Self { + Self(value.into()) + } +} |
