diff options
| author | Owen Jacobson <owen@grimoire.ca> | 2025-11-08 16:28:10 -0500 |
|---|---|---|
| committer | Owen Jacobson <owen@grimoire.ca> | 2025-11-08 16:28:10 -0500 |
| commit | fc6914831743f6d683c59adb367479defe6f8b3a (patch) | |
| tree | 5b997adac55f47b52f30022013b8ec3b2c10bcc5 /src | |
| parent | 0ef69c7d256380e660edc45ace7f1d6151226340 (diff) | |
| parent | 6bab5b4405c9adafb2ce76540595a62eea80acc0 (diff) | |
Integrate the prototype push notification support.
We're going to move forwards with this for now, as low-utility as it is, so that we can more easily iterate on it in a real-world environment (hi.grimoire.ca).
Diffstat (limited to 'src')
38 files changed, 1635 insertions, 37 deletions
@@ -10,30 +10,34 @@ use crate::{ invite::app::Invites, login::app::Logins, message::app::Messages, + push::app::Push, setup::app::Setup, token::{self, app::Tokens}, + vapid::app::Vapid, }; #[derive(Clone)] -pub struct App { +pub struct App<P> { db: SqlitePool, + webpush: P, events: event::Broadcaster, token_events: token::Broadcaster, } -impl App { - pub fn from(db: SqlitePool) -> Self { +impl<P> App<P> { + pub fn from(db: SqlitePool, webpush: P) -> Self { let events = event::Broadcaster::default(); let token_events = token::Broadcaster::default(); Self { db, + webpush, events, token_events, } } } -impl App { +impl<P> App<P> { pub fn boot(&self) -> Boot { Boot::new(self.db.clone()) } @@ -58,6 +62,13 @@ impl App { Messages::new(self.db.clone(), self.events.clone()) } + pub fn push(&self) -> Push<P> + where + P: Clone, + { + Push::new(self.db.clone(), self.webpush.clone()) + } + pub fn setup(&self) -> Setup { Setup::new(self.db.clone(), self.events.clone()) } @@ -70,46 +81,70 @@ impl App { pub fn users(&self) -> Users { Users::new(self.db.clone(), self.events.clone()) } + + pub fn vapid(&self) -> Vapid { + Vapid::new(self.db.clone(), self.events.clone()) + } + + #[cfg(test)] + pub fn webpush(&self) -> &P { + &self.webpush + } } -impl FromRef<App> for Boot { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Boot { + fn from_ref(app: &App<P>) -> Self { app.boot() } } -impl FromRef<App> for Conversations { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Conversations { + fn from_ref(app: &App<P>) -> Self { app.conversations() } } -impl FromRef<App> for Invites { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Invites { + fn from_ref(app: &App<P>) -> Self { app.invites() } } -impl FromRef<App> for Logins { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Logins { + fn from_ref(app: &App<P>) -> Self { app.logins() } } -impl FromRef<App> for Messages { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Messages { + fn from_ref(app: &App<P>) -> Self { app.messages() } } -impl FromRef<App> for Setup { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Push<P> +where + P: Clone, +{ + fn from_ref(app: &App<P>) -> Self { + app.push() + } +} + +impl<P> FromRef<App<P>> for Setup { + fn from_ref(app: &App<P>) -> Self { app.setup() } } -impl FromRef<App> for Tokens { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Tokens { + fn from_ref(app: &App<P>) -> Self { app.tokens() } } + +impl<P> FromRef<App<P>> for Vapid { + fn from_ref(app: &App<P>) -> Self { + app.vapid() + } +} diff --git a/src/boot/app.rs b/src/boot/app.rs index 840243e..1ca8adb 100644 --- a/src/boot/app.rs +++ b/src/boot/app.rs @@ -4,10 +4,12 @@ use sqlx::sqlite::SqlitePool; use super::Snapshot; use crate::{ conversation::{self, repo::Provider as _}, + db::NotFound, event::{Event, Sequence, repo::Provider as _}, message::{self, repo::Provider as _}, name, user::{self, repo::Provider as _}, + vapid::{self, repo::Provider as _}, }; pub struct Boot { @@ -26,6 +28,7 @@ impl Boot { let users = tx.users().all(resume_point).await?; let conversations = tx.conversations().all(resume_point).await?; let messages = tx.messages().all(resume_point).await?; + let vapid = tx.vapid().current().await.optional()?; tx.commit().await?; @@ -50,9 +53,16 @@ impl Boot { .filter(Sequence::up_to(resume_point)) .map(Event::from); + let vapid_events = vapid + .iter() + .flat_map(vapid::History::events) + .filter(Sequence::up_to(resume_point)) + .map(Event::from); + let events = user_events .merge_by(conversation_events, Sequence::merge) .merge_by(message_events, Sequence::merge) + .merge_by(vapid_events, Sequence::merge) .collect(); Ok(Snapshot { @@ -65,8 +75,11 @@ impl Boot { #[derive(Debug, thiserror::Error)] #[error(transparent)] pub enum Error { - Name(#[from] name::Error), Database(#[from] sqlx::Error), + Name(#[from] name::Error), + Ecdsa(#[from] p256::ecdsa::Error), + Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), } impl From<user::repo::LoadError> for Error { @@ -88,3 +101,15 @@ impl From<conversation::repo::LoadError> for Error { } } } + +impl From<vapid::repo::Error> for Error { + fn from(error: vapid::repo::Error) -> Self { + use vapid::repo::Error; + match error { + Error::Database(error) => error.into(), + Error::Ecdsa(error) => error.into(), + Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), + } + } +} diff --git a/src/boot/handlers/boot/test.rs b/src/boot/handlers/boot/test.rs index a9891eb..f192478 100644 --- a/src/boot/handlers/boot/test.rs +++ b/src/boot/handlers/boot/test.rs @@ -81,6 +81,74 @@ async fn includes_messages() { } #[tokio::test] +async fn includes_vapid_key() { + let app = fixtures::scratch_app().await; + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("key rotation always succeeds"); + + let viewer = fixtures::identity::fictitious(); + let response = super::handler(State(app.boot()), viewer) + .await + .expect("boot always succeeds"); + + response + .snapshot + .events + .into_iter() + .filter_map(fixtures::event::vapid) + .filter_map(fixtures::event::vapid::changed) + .exactly_one() + .expect("only one vapid key has been created"); +} + +#[tokio::test] +async fn includes_only_latest_vapid_key() { + let app = fixtures::scratch_app().await; + + app.vapid() + .refresh_key(&fixtures::ancient()) + .await + .expect("key rotation always succeeds"); + + let viewer = fixtures::identity::fictitious(); + let response = super::handler(State(app.boot()), viewer.clone()) + .await + .expect("boot always succeeds"); + + let original_key = response + .snapshot + .events + .into_iter() + .filter_map(fixtures::event::vapid) + .filter_map(fixtures::event::vapid::changed) + .exactly_one() + .expect("only one vapid key has been created"); + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("key rotation always succeeds"); + + let response = super::handler(State(app.boot()), viewer) + .await + .expect("boot always succeeds"); + + let rotated_key = response + .snapshot + .events + .into_iter() + .filter_map(fixtures::event::vapid) + .filter_map(fixtures::event::vapid::changed) + .exactly_one() + .expect("only one vapid key should be returned"); + + assert_ne!(original_key, rotated_key); +} + +#[tokio::test] async fn includes_expired_messages() { let app = fixtures::scratch_app().await; let sender = fixtures::user::create(&app, &fixtures::ancient()).await; @@ -10,9 +10,10 @@ use axum::{ middleware, response::{IntoResponse, Response}, }; -use clap::{CommandFactory, Parser}; +use clap::{CommandFactory, Parser, Subcommand}; use sqlx::sqlite::SqlitePool; use tokio::net; +use web_push::{IsahcWebPushClient, WebPushClient}; use crate::{ app::App, @@ -65,6 +66,15 @@ pub struct Args { /// upgrades #[arg(short = 'D', long, env, default_value = "sqlite://pilcrow.db.backup")] backup_database_url: String, + + #[command(subcommand)] + command: Option<Command>, +} + +#[derive(Subcommand)] +enum Command { + /// Immediately rotate the server's VAPID (Web Push) application key. + RotateVapidKey, } impl Args { @@ -88,7 +98,21 @@ impl Args { self.umask.set(); let pool = self.pool().await?; - let app = App::from(pool); + let webpush = IsahcWebPushClient::new()?; + let app = App::from(pool, webpush); + + match self.command { + None => self.serve(app).await?, + Some(Command::RotateVapidKey) => app.vapid().rotate_key().await?, + } + + Result::<_, Error>::Ok(()) + } + + async fn serve<P>(self, app: App<P>) -> Result<(), Error> + where + P: WebPushClient + Clone + Send + Sync + 'static, + { let app = routes::routes(&app) .route_layer(middleware::from_fn(clock::middleware)) .route_layer(middleware::map_response(Self::server_info())) @@ -101,7 +125,7 @@ impl Args { println!("{started_msg}"); serve.await?; - Result::<_, Error>::Ok(()) + Ok(()) } async fn listener(&self) -> io::Result<net::TcpListener> { @@ -140,5 +164,7 @@ fn started_msg(listener: &net::TcpListener) -> io::Result<String> { enum Error { Io(#[from] io::Error), Database(#[from] db::Error), + Sqlx(#[from] sqlx::Error), Umask(#[from] umask::Error), + Webpush(#[from] web_push::WebPushError), } diff --git a/src/event/app.rs b/src/event/app.rs index 8fa760a..e422de9 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -8,9 +8,12 @@ use sqlx::sqlite::SqlitePool; use super::{Event, Sequence, Sequenced, broadcaster::Broadcaster}; use crate::{ conversation::{self, repo::Provider as _}, + db::NotFound, message::{self, repo::Provider as _}, name, user::{self, repo::Provider as _}, + vapid, + vapid::repo::Provider as _, }; pub struct Events { @@ -57,9 +60,17 @@ impl Events { .filter(Sequence::after(resume_at)) .map(Event::from); + let vapid = tx.vapid().current().await.optional()?; + let vapid_events = vapid + .iter() + .flat_map(vapid::History::events) + .filter(Sequence::after(resume_at)) + .map(Event::from); + let replay_events = user_events .merge_by(conversation_events, Sequence::merge) .merge_by(message_events, Sequence::merge) + .merge_by(vapid_events, Sequence::merge) .collect::<Vec<_>>(); let resume_live_at = replay_events.last().map_or(resume_at, Sequenced::sequence); @@ -86,6 +97,9 @@ impl Events { pub enum Error { Database(#[from] sqlx::Error), Name(#[from] name::Error), + Ecdsa(#[from] p256::ecdsa::Error), + Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), } impl From<user::repo::LoadError> for Error { @@ -107,3 +121,15 @@ impl From<conversation::repo::LoadError> for Error { } } } + +impl From<vapid::repo::Error> for Error { + fn from(error: vapid::repo::Error) -> Self { + use vapid::repo::Error; + match error { + Error::Database(error) => error.into(), + Error::Ecdsa(error) => error.into(), + Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), + } + } +} diff --git a/src/event/handlers/stream/mod.rs b/src/event/handlers/stream/mod.rs index 63bfff3..8b89c31 100644 --- a/src/event/handlers/stream/mod.rs +++ b/src/event/handlers/stream/mod.rs @@ -18,8 +18,8 @@ use crate::{ #[cfg(test)] mod test; -pub async fn handler( - State(app): State<App>, +pub async fn handler<P>( + State(app): State<App<P>>, identity: Identity, last_event_id: Option<LastEventId<Sequence>>, Query(query): Query<QueryParams>, diff --git a/src/event/handlers/stream/test/mod.rs b/src/event/handlers/stream/test/mod.rs index 3bc634f..c3a6ce6 100644 --- a/src/event/handlers/stream/test/mod.rs +++ b/src/event/handlers/stream/test/mod.rs @@ -4,5 +4,6 @@ mod message; mod resume; mod setup; mod token; +mod vapid; use super::{QueryParams, Response, handler}; diff --git a/src/event/handlers/stream/test/vapid.rs b/src/event/handlers/stream/test/vapid.rs new file mode 100644 index 0000000..dbc3929 --- /dev/null +++ b/src/event/handlers/stream/test/vapid.rs @@ -0,0 +1,111 @@ +use axum::extract::State; +use axum_extra::extract::Query; +use futures::StreamExt as _; + +use crate::test::{fixtures, fixtures::future::Expect as _}; + +#[tokio::test] +async fn live_vapid_key_changes() { + // Set up the context + let app = fixtures::scratch_app().await; + let resume_point = fixtures::boot::resume_point(&app).await; + + // Subscribe to events + + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + let super::Response(events) = super::handler( + State(app.clone()), + subscriber, + None, + Query(super::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); + + // Rotate the VAPID key + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the vapid key always succeeds"); + + // Verify that there's a key rotation event + + events + .filter_map(fixtures::event::stream::vapid) + .filter_map(fixtures::event::stream::vapid::changed) + .next() + .expect_some("a vapid key change event is sent") + .await; +} + +#[tokio::test] +async fn stored_vapid_key_changes() { + // Set up the context + let app = fixtures::scratch_app().await; + let resume_point = fixtures::boot::resume_point(&app).await; + + // Rotate the VAPID key + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the vapid key always succeeds"); + + // Subscribe to events + + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + let super::Response(events) = super::handler( + State(app.clone()), + subscriber, + None, + Query(super::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); + + // Verify that there's a key rotation event + + events + .filter_map(fixtures::event::stream::vapid) + .filter_map(fixtures::event::stream::vapid::changed) + .next() + .expect_some("a vapid key change event is sent") + .await; +} + +#[tokio::test] +async fn no_past_vapid_key_changes() { + // Set up the context + let app = fixtures::scratch_app().await; + + // Rotate the VAPID key + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the vapid key always succeeds"); + + // Subscribe to events + + let resume_point = fixtures::boot::resume_point(&app).await; + + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + let super::Response(events) = super::handler( + State(app.clone()), + subscriber, + None, + Query(super::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); + + // Verify that there's a key rotation event + + events + .filter_map(fixtures::event::stream::vapid) + .filter_map(fixtures::event::stream::vapid::changed) + .next() + .expect_wait("a vapid key change event is not sent") + .await; +} diff --git a/src/event/mod.rs b/src/event/mod.rs index f41dc9c..83b0ce7 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -2,7 +2,7 @@ use std::time::Duration; use axum::response::sse::{self, KeepAlive}; -use crate::{conversation, message, user}; +use crate::{conversation, message, user, vapid}; pub mod app; mod broadcaster; @@ -22,6 +22,7 @@ pub enum Event { User(user::Event), Conversation(conversation::Event), Message(message::Event), + Vapid(vapid::Event), } // Serialized representation is intended to look like the serialized representation of `Event`, @@ -40,6 +41,7 @@ impl Sequenced for Event { Self::User(event) => event.instant(), Self::Conversation(event) => event.instant(), Self::Message(event) => event.instant(), + Self::Vapid(event) => event.instant(), } } } @@ -62,6 +64,12 @@ impl From<message::Event> for Event { } } +impl From<vapid::Event> for Event { + fn from(event: vapid::Event) -> Self { + Self::Vapid(event) + } +} + impl Heartbeat { // The following values are a first-rough-guess attempt to balance noticing connection problems // quickly with managing the (modest) costs of delivering and processing heartbeats. Feel diff --git a/src/expire.rs b/src/expire.rs index 4177a53..c3b0117 100644 --- a/src/expire.rs +++ b/src/expire.rs @@ -7,8 +7,8 @@ use axum::{ use crate::{app::App, clock::RequestedAt, error::Internal}; // Expires messages and conversations before each request. -pub async fn middleware( - State(app): State<App>, +pub async fn middleware<P>( + State(app): State<App<P>>, RequestedAt(expired_at): RequestedAt, req: Request, next: Next, @@ -20,6 +20,7 @@ mod message; mod name; mod normalize; mod password; +mod push; mod routes; mod setup; #[cfg(test)] @@ -28,3 +29,4 @@ mod token; mod ui; mod umask; mod user; +mod vapid; diff --git a/src/login/app.rs b/src/login/app.rs index a2f9636..8cc8cd0 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -6,6 +6,7 @@ use crate::{ login::{self, Login, repo::Provider as _}, name::{self, Name}, password::Password, + push::repo::Provider as _, token::{Broadcaster, Event as TokenEvent, Secret, Token, repo::Provider as _}, }; @@ -76,6 +77,7 @@ impl Logins { let mut tx = self.db.begin().await?; tx.logins().set_password(&login, &to_hash).await?; + tx.push().unsubscribe_login(&login).await?; let revoked = tx.tokens().revoke_all(&login).await?; tx.tokens().create(&token, &secret).await?; tx.commit().await?; diff --git a/src/push/app.rs b/src/push/app.rs new file mode 100644 index 0000000..56b9a02 --- /dev/null +++ b/src/push/app.rs @@ -0,0 +1,178 @@ +use futures::future::join_all; +use itertools::Itertools as _; +use p256::ecdsa::VerifyingKey; +use sqlx::SqlitePool; +use web_push::{ + ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError, + WebPushMessage, WebPushMessageBuilder, +}; + +use super::repo::Provider as _; +use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _}; + +pub struct Push<P> { + db: SqlitePool, + webpush: P, +} + +impl<P> Push<P> { + pub const fn new(db: SqlitePool, webpush: P) -> Self { + Self { db, webpush } + } + + pub async fn subscribe( + &self, + subscriber: &Identity, + subscription: &SubscriptionInfo, + vapid: &VerifyingKey, + ) -> Result<(), SubscribeError> { + let mut tx = self.db.begin().await?; + + let current = tx.vapid().current().await?; + if vapid != ¤t.key { + return Err(SubscribeError::StaleVapidKey(current.key)); + } + + match tx.push().create(&subscriber.token, subscription).await { + Ok(()) => (), + Err(err) => { + if let Some(err) = err.as_database_error() + && err.is_unique_violation() + { + let current = tx + .push() + .by_endpoint(&subscriber.login, &subscription.endpoint) + .await?; + // If we already have a subscription for this endpoint, with _different_ + // parameters, then this is a client error. They shouldn't reuse endpoint URLs, + // per the various RFCs. + // + // However, if we have a subscription for this endpoint with the same parameters + // then we accept it and silently do nothing. This may happen if, for example, + // the subscribe request is retried due to a network interruption where it's + // not clear whether the original request succeeded. + if ¤t != subscription { + return Err(SubscribeError::Duplicate); + } + } else { + return Err(SubscribeError::Database(err)); + } + } + } + + tx.commit().await?; + + Ok(()) + } +} + +impl<P> Push<P> +where + P: WebPushClient, +{ + fn prepare_ping( + signer: &PartialVapidSignatureBuilder, + subscription: &SubscriptionInfo, + ) -> Result<WebPushMessage, WebPushError> { + let signature = signer.clone().add_sub_info(subscription).build()?; + + let payload = "ping".as_bytes(); + + let mut message = WebPushMessageBuilder::new(subscription); + message.set_payload(ContentEncoding::Aes128Gcm, payload); + message.set_vapid_signature(signature); + let message = message.build()?; + + Ok(message) + } + + pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> { + let mut tx = self.db.begin().await?; + + let signer = tx.vapid().signer().await?; + let subscriptions = tx.push().by_login(recipient).await?; + + let pings: Vec<_> = subscriptions + .into_iter() + .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message))) + .try_collect()?; + + let deliveries = pings + .into_iter() + .map(async |(sub, message)| (sub, self.webpush.send(message).await)); + + let failures: Vec<_> = join_all(deliveries) + .await + .into_iter() + .filter_map(|(sub, result)| result.err().map(|err| (sub, err))) + .collect(); + + if !failures.is_empty() { + for (sub, err) in &failures { + match err { + // I _think_ this is the complete set of permanent failures. See + // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete + // list. + WebPushError::Unauthorized(_) + | WebPushError::InvalidUri + | WebPushError::EndpointNotValid(_) + | WebPushError::EndpointNotFound(_) + | WebPushError::InvalidCryptoKeys + | WebPushError::MissingCryptoKeys => { + tx.push().unsubscribe(sub).await?; + } + _ => (), + } + } + + return Err(PushError::Delivery( + failures.into_iter().map(|(_, err)| err).collect(), + )); + } + + tx.commit().await?; + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum SubscribeError { + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Vapid(#[from] vapid::repo::Error), + #[error("subscription created with stale VAPID key")] + StaleVapidKey(VerifyingKey), + #[error("subscription already exists for endpoint")] + // The endpoint URL is not included in the error, as it is a bearer credential in its own right + // and we want to limit its proliferation. The only intended recipient of this message is the + // client, which already knows the endpoint anyways and doesn't need us to tell them. + Duplicate, +} + +#[derive(Debug, thiserror::Error)] +pub enum PushError { + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Ecdsa(#[from] p256::ecdsa::Error), + #[error(transparent)] + Pkcs8(#[from] p256::pkcs8::Error), + #[error(transparent)] + WebPush(#[from] WebPushError), + #[error("push message delivery failures: {0:?}")] + Delivery(Vec<WebPushError>), +} + +impl From<vapid::repo::Error> for PushError { + fn from(error: vapid::repo::Error) -> Self { + use vapid::repo::Error; + match error { + Error::Database(error) => error.into(), + Error::Ecdsa(error) => error.into(), + Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), + } + } +} diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs new file mode 100644 index 0000000..bb58774 --- /dev/null +++ b/src/push/handlers/mod.rs @@ -0,0 +1,5 @@ +mod ping; +mod subscribe; + +pub use ping::handler as ping; +pub use subscribe::handler as subscribe; diff --git a/src/push/handlers/ping/mod.rs b/src/push/handlers/ping/mod.rs new file mode 100644 index 0000000..db828fa --- /dev/null +++ b/src/push/handlers/ping/mod.rs @@ -0,0 +1,23 @@ +use axum::{Json, extract::State, http::StatusCode}; +use web_push::WebPushClient; + +use crate::{error::Internal, push::app::Push, token::extract::Identity}; + +#[cfg(test)] +mod test; + +#[derive(serde::Deserialize)] +pub struct Request {} + +pub async fn handler<P>( + State(push): State<Push<P>>, + identity: Identity, + Json(_): Json<Request>, +) -> Result<StatusCode, Internal> +where + P: WebPushClient, +{ + push.ping(&identity.login).await?; + + Ok(StatusCode::ACCEPTED) +} diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs new file mode 100644 index 0000000..5725131 --- /dev/null +++ b/src/push/handlers/ping/test.rs @@ -0,0 +1,40 @@ +use axum::{ + extract::{Json, State}, + http::StatusCode, +}; + +use crate::test::fixtures; + +#[tokio::test] +async fn ping_without_subscriptions() { + let app = fixtures::scratch_app().await; + + let recipient = fixtures::identity::create(&app, &fixtures::now()).await; + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + let response = super::handler(State(app.push()), recipient, Json(super::Request {})) + .await + .expect("sending a ping with no subscriptions always succeeds"); + + assert_eq!(StatusCode::ACCEPTED, response); + + assert!(app.webpush().sent().is_empty()); +} + +// More complete testing requires that we figure out how to generate working p256 ECDH keys for +// testing _with_, as `web_push` will actually parse and use those keys even if push messages are +// ultimately never serialized or sent over HTTP. +// +// Tests that are missing: +// +// * Verify that subscribing and sending a ping causes a ping to be delivered to that subscription. +// * Verify that two subscriptions both get pings. +// * Verify that other users' subscriptions are not pinged. +// * Verify that a ping that causes a permanent error causes the subscription to be deleted. +// * Verify that a ping that causes a non-permanent error does not cause the subscription to be +// deleted. +// * Verify that a failure on one subscription doesn't affect delivery on other subscriptions. diff --git a/src/push/handlers/subscribe/mod.rs b/src/push/handlers/subscribe/mod.rs new file mode 100644 index 0000000..a1a5899 --- /dev/null +++ b/src/push/handlers/subscribe/mod.rs @@ -0,0 +1,94 @@ +use axum::{ + extract::{Json, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use p256::ecdsa::VerifyingKey; +use web_push::SubscriptionInfo; + +use crate::{ + error::Internal, + push::{app, app::Push}, + token::extract::Identity, +}; + +#[cfg(test)] +mod test; + +#[derive(Clone, serde::Deserialize)] +pub struct Request { + subscription: Subscription, + #[serde(with = "crate::vapid::ser::key")] + vapid: VerifyingKey, +} + +// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>. +#[derive(Clone, serde::Deserialize)] +pub struct Subscription { + endpoint: String, + keys: Keys, +} + +// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>. +#[derive(Clone, serde::Deserialize)] +pub struct Keys { + p256dh: String, + auth: String, +} + +pub async fn handler<P>( + State(push): State<Push<P>>, + identity: Identity, + Json(request): Json<Request>, +) -> Result<StatusCode, Error> { + let Request { + subscription, + vapid, + } = request; + + push.subscribe(&identity, &subscription.into(), &vapid) + .await?; + + Ok(StatusCode::CREATED) +} + +impl From<Subscription> for SubscriptionInfo { + fn from(request: Subscription) -> Self { + let Subscription { + endpoint, + keys: Keys { p256dh, auth }, + } = request; + SubscriptionInfo::new(endpoint, p256dh, auth) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(#[from] app::SubscribeError); + +impl IntoResponse for Error { + fn into_response(self) -> Response { + let Self(err) = self; + + match err { + app::SubscribeError::StaleVapidKey(key) => { + let body = StaleVapidKey { + message: err.to_string(), + key, + }; + (StatusCode::BAD_REQUEST, Json(body)).into_response() + } + app::SubscribeError::Duplicate => { + (StatusCode::CONFLICT, err.to_string()).into_response() + } + other => Internal::from(other).into_response(), + } + } +} + +#[derive(serde::Serialize)] +struct StaleVapidKey { + message: String, + #[serde(with = "crate::vapid::ser::key")] + key: VerifyingKey, +} diff --git a/src/push/handlers/subscribe/test.rs b/src/push/handlers/subscribe/test.rs new file mode 100644 index 0000000..b72624d --- /dev/null +++ b/src/push/handlers/subscribe/test.rs @@ -0,0 +1,236 @@ +use axum::{ + extract::{Json, State}, + http::StatusCode, +}; + +use crate::{ + push::app::SubscribeError, + test::{fixtures, fixtures::event}, +}; + +#[tokio::test] +async fn accepts_new_subscription() { + let app = fixtures::scratch_app().await; + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + + // Issue a VAPID key. + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + // Find out what that VAPID key is. + + let boot = app.boot().snapshot().await.expect("boot always succeeds"); + let vapid = boot + .events + .into_iter() + .filter_map(event::vapid) + .filter_map(event::vapid::changed) + .next_back() + .expect("the application will have a vapid key after a refresh"); + + // Create a dummy subscription with that key. + + let request = super::Request { + subscription: super::Subscription { + endpoint: String::from("https://push.example.com/endpoint"), + keys: super::Keys { + p256dh: String::from("test-p256dh-value"), + auth: String::from("test-auth-value"), + }, + }, + vapid: vapid.key, + }; + let response = super::handler(State(app.push()), subscriber, Json(request)) + .await + .expect("test request will succeed on a fresh app"); + + // Check that the response looks as expected. + + assert_eq!(StatusCode::CREATED, response); +} + +#[tokio::test] +async fn accepts_repeat_subscription() { + let app = fixtures::scratch_app().await; + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + + // Issue a VAPID key. + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + // Find out what that VAPID key is. + + let boot = app.boot().snapshot().await.expect("boot always succeeds"); + let vapid = boot + .events + .into_iter() + .filter_map(event::vapid) + .filter_map(event::vapid::changed) + .next_back() + .expect("the application will have a vapid key after a refresh"); + + // Create a dummy subscription with that key. + + let request = super::Request { + subscription: super::Subscription { + endpoint: String::from("https://push.example.com/endpoint"), + keys: super::Keys { + p256dh: String::from("test-p256dh-value"), + auth: String::from("test-auth-value"), + }, + }, + vapid: vapid.key, + }; + let response = super::handler(State(app.push()), subscriber.clone(), Json(request.clone())) + .await + .expect("test request will succeed on a fresh app"); + + // Check that the response looks as expected. + + assert_eq!(StatusCode::CREATED, response); + + // Repeat the request + + let response = super::handler(State(app.push()), subscriber, Json(request)) + .await + .expect("test request will succeed twice on a fresh app"); + + // Check that the second response also looks as expected. + + assert_eq!(StatusCode::CREATED, response); +} + +#[tokio::test] +async fn rejects_duplicate_subscription() { + let app = fixtures::scratch_app().await; + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + + // Issue a VAPID key. + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + // Find out what that VAPID key is. + + let boot = app.boot().snapshot().await.expect("boot always succeeds"); + let vapid = boot + .events + .into_iter() + .filter_map(event::vapid) + .filter_map(event::vapid::changed) + .next_back() + .expect("the application will have a vapid key after a refresh"); + + // Create a dummy subscription with that key. + + let request = super::Request { + subscription: super::Subscription { + endpoint: String::from("https://push.example.com/endpoint"), + keys: super::Keys { + p256dh: String::from("test-p256dh-value"), + auth: String::from("test-auth-value"), + }, + }, + vapid: vapid.key, + }; + super::handler(State(app.push()), subscriber.clone(), Json(request)) + .await + .expect("test request will succeed on a fresh app"); + + // Repeat the request with different keys + + let request = super::Request { + subscription: super::Subscription { + endpoint: String::from("https://push.example.com/endpoint"), + keys: super::Keys { + p256dh: String::from("different-test-p256dh-value"), + auth: String::from("different-test-auth-value"), + }, + }, + vapid: vapid.key, + }; + let response = super::handler(State(app.push()), subscriber, Json(request)) + .await + .expect_err("request with duplicate endpoint should fail"); + + // Make sure we got the error we expected. + + assert!(matches!(response, super::Error(SubscribeError::Duplicate))); +} + +#[tokio::test] +async fn rejects_stale_vapid_key() { + let app = fixtures::scratch_app().await; + let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; + + // Issue a VAPID key. + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + // Find out what that VAPID key is. + + let boot = app.boot().snapshot().await.expect("boot always succeeds"); + let vapid = boot + .events + .into_iter() + .filter_map(event::vapid) + .filter_map(event::vapid::changed) + .next_back() + .expect("the application will have a vapid key after a refresh"); + + // Change the VAPID key. + + app.vapid() + .rotate_key() + .await + .expect("key rotation always succeeds"); + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + // Find out what the new VAPID key is. + + let boot = app.boot().snapshot().await.expect("boot always succeeds"); + let fresh_vapid = boot + .events + .into_iter() + .filter_map(event::vapid) + .filter_map(event::vapid::changed) + .next_back() + .expect("the application will have a vapid key after a refresh"); + + // Create a dummy subscription with the original key. + + let request = super::Request { + subscription: super::Subscription { + endpoint: String::from("https://push.example.com/endpoint"), + keys: super::Keys { + p256dh: String::from("test-p256dh-value"), + auth: String::from("test-auth-value"), + }, + }, + vapid: vapid.key, + }; + let response = super::handler(State(app.push()), subscriber, Json(request)) + .await + .expect_err("test request has a stale vapid key"); + + // Check that the response looks as expected. + + assert!(matches!( + response, + super::Error(SubscribeError::StaleVapidKey(key)) if key == fresh_vapid.key + )); +} diff --git a/src/push/mod.rs b/src/push/mod.rs new file mode 100644 index 0000000..1394ea4 --- /dev/null +++ b/src/push/mod.rs @@ -0,0 +1,3 @@ +pub mod app; +pub mod handlers; +pub mod repo; diff --git a/src/push/repo.rs b/src/push/repo.rs new file mode 100644 index 0000000..4183489 --- /dev/null +++ b/src/push/repo.rs @@ -0,0 +1,149 @@ +use sqlx::{Sqlite, SqliteConnection, Transaction}; +use web_push::SubscriptionInfo; + +use crate::{login::Login, token::Token}; + +pub trait Provider { + fn push(&mut self) -> Push<'_>; +} + +impl Provider for Transaction<'_, Sqlite> { + fn push(&mut self) -> Push<'_> { + Push(self) + } +} + +pub struct Push<'t>(&'t mut SqliteConnection); + +impl Push<'_> { + pub async fn create( + &mut self, + token: &Token, + subscription: &SubscriptionInfo, + ) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + insert into push_subscription (token, endpoint, p256dh, auth) + values ($1, $2, $3, $4) + "#, + token.id, + subscription.endpoint, + subscription.keys.p256dh, + subscription.keys.auth, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn by_login(&mut self, login: &Login) -> Result<Vec<SubscriptionInfo>, sqlx::Error> { + sqlx::query!( + r#" + select + subscription.endpoint, + subscription.p256dh, + subscription.auth + from push_subscription as subscription + join token on subscription.token = token.id + where token.login = $1 + "#, + login.id, + ) + .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth)) + .fetch_all(&mut *self.0) + .await + } + + pub async fn by_endpoint( + &mut self, + subscriber: &Login, + endpoint: &str, + ) -> Result<SubscriptionInfo, sqlx::Error> { + let row = sqlx::query!( + r#" + select + subscription.endpoint, + subscription.p256dh, + subscription.auth + from push_subscription as subscription + join token on subscription.token = token.id + join login as subscriber on token.login = subscriber.id + where subscriber.id = $1 + and subscription.endpoint = $2 + "#, + subscriber.id, + endpoint, + ) + .fetch_one(&mut *self.0) + .await?; + + let info = SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth); + + Ok(info) + } + + pub async fn unsubscribe( + &mut self, + subscription: &SubscriptionInfo, + ) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + delete from push_subscription + where endpoint = $1 + "#, + subscription.endpoint, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn unsubscribe_token(&mut self, token: &Token) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + delete from push_subscription + where token = $1 + "#, + token.id, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn unsubscribe_login(&mut self, login: &Login) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + with tokens as ( + select id from token + where login = $1 + ) + delete from push_subscription + where token in tokens + "#, + login.id, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + // Unsubscribe logic for token expiry lives in the `tokens` repository, for maintenance reasons. + + pub async fn clear(&mut self) -> Result<(), sqlx::Error> { + // We assume that _all_ stored subscriptions are for a VAPID key we're about to delete. + sqlx::query!( + r#" + delete from push_subscription + "#, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } +} diff --git a/src/routes.rs b/src/routes.rs index b848afb..1c07e78 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -3,10 +3,16 @@ use axum::{ response::Redirect, routing::{delete, get, post}, }; +use web_push::WebPushClient; -use crate::{app::App, boot, conversation, event, expire, invite, login, message, setup, ui}; +use crate::{ + app::App, boot, conversation, event, expire, invite, login, message, push, setup, ui, vapid, +}; -pub fn routes(app: &App) -> Router<App> { +pub fn routes<P>(app: &App<P>) -> Router<App<P>> +where + P: WebPushClient + Clone + Send + Sync + 'static, +{ // UI routes that can be accessed before the administrator completes setup. let ui_bootstrap = Router::new() .route("/{*path}", get(ui::handlers::asset)) @@ -44,6 +50,8 @@ pub fn routes(app: &App) -> Router<App> { .route("/api/invite/{invite}", get(invite::handlers::get)) .route("/api/invite/{invite}", post(invite::handlers::accept)) .route("/api/messages/{message}", delete(message::handlers::delete)) + .route("/api/push/ping", post(push::handlers::ping)) + .route("/api/push/subscribe", post(push::handlers::subscribe)) .route("/api/password", post(login::handlers::change_password)) // Run expiry whenever someone accesses the API. This was previously a blanket middleware // affecting the whole service, but loading the client makes a several requests before the @@ -56,6 +64,10 @@ pub fn routes(app: &App) -> Router<App> { app.clone(), expire::middleware, )) + .route_layer(middleware::from_fn_with_state( + app.clone(), + vapid::middleware, + )) .route_layer(setup::Required(app.clone())); [ diff --git a/src/test/fixtures/event/mod.rs b/src/test/fixtures/event/mod.rs index 08b17e7..f8651ba 100644 --- a/src/test/fixtures/event/mod.rs +++ b/src/test/fixtures/event/mod.rs @@ -23,6 +23,13 @@ pub fn user(event: Event) -> Option<crate::user::Event> { } } +pub fn vapid(event: Event) -> Option<crate::vapid::Event> { + match event { + Event::Vapid(event) => Some(event), + _ => None, + } +} + pub mod conversation { use crate::conversation::{Event, event}; @@ -72,3 +79,17 @@ pub mod user { } } } + +pub mod vapid { + use crate::vapid::{Event, event}; + + // This could be defined as `-> event::Changed`. However, I want the interface to be consistent + // with the event stream transformers for other types, and we'd have to refactor the return type + // to `-> Option<event::Changed>` the instant VAPID keys sprout a second event. + #[allow(clippy::unnecessary_wraps)] + pub fn changed(event: Event) -> Option<event::Changed> { + match event { + Event::Changed(changed) => Some(changed), + } + } +} diff --git a/src/test/fixtures/event/stream.rs b/src/test/fixtures/event/stream.rs index 5b3621d..bb83d0d 100644 --- a/src/test/fixtures/event/stream.rs +++ b/src/test/fixtures/event/stream.rs @@ -14,6 +14,10 @@ pub fn user(event: Event) -> Ready<Option<crate::user::Event>> { future::ready(event::user(event)) } +pub fn vapid(event: Event) -> Ready<Option<crate::vapid::Event>> { + future::ready(event::vapid(event)) +} + pub mod conversation { use std::future::{self, Ready}; @@ -60,3 +64,16 @@ pub mod user { future::ready(user::created(event)) } } + +pub mod vapid { + use std::future::{self, Ready}; + + use crate::{ + test::fixtures::event::vapid, + vapid::{Event, event}, + }; + + pub fn changed(event: Event) -> Ready<Option<event::Changed>> { + future::ready(vapid::changed(event)) + } +} diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 20929f9..adc3e73 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -14,7 +14,7 @@ use crate::{ }, }; -pub async fn create(app: &App, created_at: &RequestedAt) -> Identity { +pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Identity { let credentials = fixtures::user::create_with_password(app, created_at).await; logged_in(app, &credentials, created_at).await } diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index d9aca81..839a412 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -5,7 +5,7 @@ use crate::{ test::fixtures::user::{propose, propose_name}, }; -pub async fn create(app: &App, created_at: &DateTime) -> Login { +pub async fn create<P>(app: &App<P>, created_at: &DateTime) -> Login { let (name, password) = propose(); app.users() .create(&name, &password, created_at) diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index 3d69cfa..53bf31b 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -1,6 +1,6 @@ use chrono::{TimeDelta, Utc}; -use crate::{app::App, clock::RequestedAt, db}; +use crate::{app::App, clock::RequestedAt, db, test::webpush::Client}; pub mod boot; pub mod conversation; @@ -13,11 +13,11 @@ pub mod login; pub mod message; pub mod user; -pub async fn scratch_app() -> App { +pub async fn scratch_app() -> App<Client> { let pool = db::prepare("sqlite::memory:", "sqlite::memory:") .await .expect("setting up in-memory sqlite database"); - App::from(pool) + App::from(pool, Client::new()) } pub fn now() -> RequestedAt { diff --git a/src/test/fixtures/user.rs b/src/test/fixtures/user.rs index d4d8db4..3ad4436 100644 --- a/src/test/fixtures/user.rs +++ b/src/test/fixtures/user.rs @@ -3,7 +3,7 @@ use uuid::Uuid; use crate::{app::App, clock::RequestedAt, login::Login, name::Name, password::Password}; -pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name, Password) { +pub async fn create_with_password<P>(app: &App<P>, created_at: &RequestedAt) -> (Name, Password) { let (name, password) = propose(); let user = app .users() @@ -14,7 +14,7 @@ pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name, (user.name, password) } -pub async fn create(app: &App, created_at: &RequestedAt) -> Login { +pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Login { super::login::create(app, created_at).await } diff --git a/src/test/mod.rs b/src/test/mod.rs index ebbbfef..f798b9c 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,2 +1,3 @@ pub mod fixtures; pub mod verify; +pub mod webpush; diff --git a/src/test/webpush.rs b/src/test/webpush.rs new file mode 100644 index 0000000..c86d03f --- /dev/null +++ b/src/test/webpush.rs @@ -0,0 +1,37 @@ +use std::{ + mem, + sync::{Arc, Mutex}, +}; + +use web_push::{WebPushClient, WebPushError, WebPushMessage}; + +#[derive(Clone)] +pub struct Client { + sent: Arc<Mutex<Vec<WebPushMessage>>>, +} + +impl Client { + pub fn new() -> Self { + Self { + sent: Arc::default(), + } + } + + // Clears the list of sent messages (for all clones of this Client) when called, because we + // can't clone `WebPushMessage`s so we either need to move them or try to reconstruct them, + // either of which sucks but moving them sucks less. + pub fn sent(&self) -> Vec<WebPushMessage> { + let mut sent = self.sent.lock().unwrap(); + mem::replace(&mut *sent, Vec::new()) + } +} + +#[async_trait::async_trait] +impl WebPushClient for Client { + async fn send(&self, message: WebPushMessage) -> Result<(), WebPushError> { + let mut sent = self.sent.lock().unwrap(); + sent.push(message); + + Ok(()) + } +} diff --git a/src/token/app.rs b/src/token/app.rs index 332473d..4a08877 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -10,7 +10,7 @@ use super::{ extract::Identity, repo::{self, Provider as _}, }; -use crate::{clock::DateTime, db::NotFound as _, name}; +use crate::{clock::DateTime, db::NotFound as _, name, push::repo::Provider as _}; pub struct Tokens { db: SqlitePool, @@ -112,6 +112,7 @@ impl Tokens { pub async fn logout(&self, token: &Token) -> Result<(), ValidateError> { let mut tx = self.db.begin().await?; + tx.push().unsubscribe_token(token).await?; tx.tokens().revoke(token).await?; tx.commit().await?; diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs index 52a3987..33c33af 100644 --- a/src/token/repo/token.rs +++ b/src/token/repo/token.rs @@ -89,6 +89,23 @@ impl Tokens<'_> { // Expire and delete all tokens that haven't been used more recently than // `expire_at`. pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { + // This lives here, rather than in the `push` repository, to ensure that the criteria for + // stale tokens don't drift apart between the two queries. That would be a larger risk if + // the queries lived in very separate parts of the codebase. + sqlx::query!( + r#" + with stale_tokens as ( + select id from token + where last_used_at < $1 + ) + delete from push_subscription + where token in stale_tokens + "#, + expire_at, + ) + .execute(&mut *self.0) + .await?; + let tokens = sqlx::query_scalar!( r#" delete diff --git a/src/vapid/app.rs b/src/vapid/app.rs new file mode 100644 index 0000000..9949aa5 --- /dev/null +++ b/src/vapid/app.rs @@ -0,0 +1,117 @@ +use chrono::TimeDelta; +use sqlx::SqlitePool; + +use super::{History, repo, repo::Provider as _}; +use crate::{ + clock::DateTime, + db::NotFound as _, + event::{Broadcaster, Sequence, repo::Provider}, + push::repo::Provider as _, +}; + +pub struct Vapid { + db: SqlitePool, + events: Broadcaster, +} + +impl Vapid { + pub const fn new(db: SqlitePool, events: Broadcaster) -> Self { + Self { db, events } + } + + pub async fn rotate_key(&self) -> Result<(), sqlx::Error> { + let mut tx = self.db.begin().await?; + // This is called from a separate CLI utility (see `cli.rs`), and we _can't_ deliver events + // to active clients from another process, so don't do anything that would require us to + // send events, like generating a new key. + // + // Instead, the server's next `refresh_key` call will generate a key and notify clients + // of the change. All we have to do is remove the existing key, so that the server can know + // to do so. + tx.vapid().clear().await?; + // Delete outstanding subscriptions for the existing VAPID key, as well. They're + // unserviceable once we lose the key. Clients can resubscribe when they process the next + // key rotation event, which will be quite quickly once the running server notices that the + // VAPID key has been removed. + tx.push().clear().await?; + tx.commit().await?; + + Ok(()) + } + + pub async fn refresh_key(&self, ensure_at: &DateTime) -> Result<(), Error> { + let mut tx = self.db.begin().await?; + let key = tx.vapid().current().await.optional()?; + if key.is_none() { + let changed_at = tx.sequence().next(ensure_at).await?; + let (key, secret) = History::begin(&changed_at); + + tx.vapid().clear().await?; + tx.vapid().store_signing_key(&secret).await?; + + let events = key.events().filter(Sequence::start_from(changed_at)); + tx.vapid().record_events(events.clone()).await?; + + tx.commit().await?; + + self.events.broadcast_from(events); + } else if let Some(key) = key + // Somewhat arbitrarily, rotate keys every 30 days. + && key.older_than(ensure_at.to_owned() - TimeDelta::days(30)) + { + // If you can think of a way to factor out this duplication, be my guest. I tried. + // The only approach I could think of mirrors `crate::user::create::Create`, encoding + // the process in a state machine made of types, and that's a very complex solution + // to a problem that doesn't seem to merit it. -o + let changed_at = tx.sequence().next(ensure_at).await?; + let (key, secret) = key.rotate(&changed_at); + + // This will delete _all_ stored subscriptions. This is fine; they're all for the + // current VAPID key, and we won't be able to use them anyways once the key is rotated. + // We have no way to inform the push broker services of that, unfortunately. + tx.push().clear().await?; + tx.vapid().clear().await?; + tx.vapid().store_signing_key(&secret).await?; + + // Refactoring constraint: this `events` iterator borrows `key`. Anything that moves + // `key` has to give it back, but it can't give both `key` back and an event iterator + // borrowing from `key` because Rust doesn't support types that borrow from other + // parts of themselves. + let events = key.events().filter(Sequence::start_from(changed_at)); + tx.vapid().record_events(events.clone()).await?; + + // Refactoring constraint: we _really_ want to commit the transaction before we send + // out events, so that anything acting on those events is guaranteed to see the state + // of the service at some point at or after the side effects of this. I'd also prefer + // to keep the commit in the same method that the transaction is begun in, for clarity. + tx.commit().await?; + + self.events.broadcast_from(events); + } + // else, the key exists and is not stale. Don't bother allocating a sequence number, and + // in fact throw away the whole transaction. + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + Database(#[from] sqlx::Error), + Ecdsa(#[from] p256::ecdsa::Error), + Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), +} + +impl From<repo::Error> for Error { + fn from(error: repo::Error) -> Self { + use repo::Error; + match error { + Error::Database(error) => error.into(), + Error::Ecdsa(error) => error.into(), + Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), + } + } +} diff --git a/src/vapid/event.rs b/src/vapid/event.rs new file mode 100644 index 0000000..cf3be77 --- /dev/null +++ b/src/vapid/event.rs @@ -0,0 +1,37 @@ +use p256::ecdsa::VerifyingKey; + +use crate::event::{Instant, Sequenced}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "event", rename_all = "snake_case")] +pub enum Event { + Changed(Changed), +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + match self { + Self::Changed(event) => event.instant(), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Changed { + #[serde(flatten)] + pub instant: Instant, + #[serde(with = "crate::vapid::ser::key")] + pub key: VerifyingKey, +} + +impl From<Changed> for Event { + fn from(event: Changed) -> Self { + Self::Changed(event) + } +} + +impl Sequenced for Changed { + fn instant(&self) -> Instant { + self.instant + } +} diff --git a/src/vapid/history.rs b/src/vapid/history.rs new file mode 100644 index 0000000..42f062b --- /dev/null +++ b/src/vapid/history.rs @@ -0,0 +1,55 @@ +use p256::ecdsa::{SigningKey, VerifyingKey}; +use rand::thread_rng; + +use super::event::{Changed, Event}; +use crate::{clock::DateTime, event::Instant}; + +#[derive(Debug)] +pub struct History { + pub key: VerifyingKey, + pub changed: Instant, +} + +// Lifecycle interface +impl History { + pub fn begin(changed: &Instant) -> (Self, SigningKey) { + let key = SigningKey::random(&mut thread_rng()); + ( + Self { + key: VerifyingKey::from(&key), + changed: *changed, + }, + key, + ) + } + + // `self` _is_ unused here, clippy is right about that. This choice is deliberate, however - it + // makes it harder to inadvertently reuse a rotated key via its history, and it makes the + // lifecycle interface more obviously consistent between this and other History types. + #[allow(clippy::unused_self)] + pub fn rotate(self, changed: &Instant) -> (Self, SigningKey) { + Self::begin(changed) + } +} + +// State interface +impl History { + pub fn older_than(&self, when: DateTime) -> bool { + self.changed.at < when + } +} + +// Events interface +impl History { + pub fn events(&self) -> impl Iterator<Item = Event> + Clone { + [self.changed()].into_iter() + } + + fn changed(&self) -> Event { + Changed { + key: self.key, + instant: self.changed, + } + .into() + } +} diff --git a/src/vapid/middleware.rs b/src/vapid/middleware.rs new file mode 100644 index 0000000..3129aa7 --- /dev/null +++ b/src/vapid/middleware.rs @@ -0,0 +1,17 @@ +use axum::{ + extract::{Request, State}, + middleware::Next, + response::Response, +}; + +use crate::{clock::RequestedAt, error::Internal, vapid::app::Vapid}; + +pub async fn middleware( + State(vapid): State<Vapid>, + RequestedAt(now): RequestedAt, + request: Request, + next: Next, +) -> Result<Response, Internal> { + vapid.refresh_key(&now).await?; + Ok(next.run(request).await) +} diff --git a/src/vapid/mod.rs b/src/vapid/mod.rs new file mode 100644 index 0000000..364f602 --- /dev/null +++ b/src/vapid/mod.rs @@ -0,0 +1,10 @@ +pub mod app; +pub mod event; +mod history; +mod middleware; +pub mod repo; +pub mod ser; + +pub use event::Event; +pub use history::History; +pub use middleware::middleware; diff --git a/src/vapid/repo.rs b/src/vapid/repo.rs new file mode 100644 index 0000000..9db61e1 --- /dev/null +++ b/src/vapid/repo.rs @@ -0,0 +1,161 @@ +use std::io::Cursor; + +use p256::{ + ecdsa::SigningKey, + pkcs8::{DecodePrivateKey as _, EncodePrivateKey as _, LineEnding}, +}; +use sqlx::{Sqlite, SqliteConnection, Transaction}; +use web_push::{PartialVapidSignatureBuilder, VapidSignatureBuilder}; + +use super::{ + History, + event::{Changed, Event}, +}; +use crate::{ + clock::DateTime, + db::NotFound, + event::{Instant, Sequence}, +}; + +pub trait Provider { + fn vapid(&mut self) -> Vapid<'_>; +} + +impl Provider for Transaction<'_, Sqlite> { + fn vapid(&mut self) -> Vapid<'_> { + Vapid(self) + } +} + +pub struct Vapid<'a>(&'a mut SqliteConnection); + +impl Vapid<'_> { + pub async fn record_events( + &mut self, + events: impl IntoIterator<Item = Event>, + ) -> Result<(), sqlx::Error> { + for event in events { + self.record_event(&event).await?; + } + Ok(()) + } + + pub async fn record_event(&mut self, event: &Event) -> Result<(), sqlx::Error> { + match event { + Event::Changed(changed) => self.record_changed(changed).await, + } + } + + async fn record_changed(&mut self, changed: &Changed) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + insert into vapid_key (changed_at, changed_sequence) + values ($1, $2) + "#, + changed.instant.at, + changed.instant.sequence, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn clear(&mut self) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + delete from vapid_key + "# + ) + .execute(&mut *self.0) + .await?; + + sqlx::query!( + r#" + delete from vapid_signing_key + "# + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn store_signing_key(&mut self, key: &SigningKey) -> Result<(), Error> { + let key = key.to_pkcs8_pem(LineEnding::CRLF)?; + let key = key.as_str(); + sqlx::query!( + r#" + insert into vapid_signing_key (key) + values ($1) + "#, + key, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn current(&mut self) -> Result<History, Error> { + let key = sqlx::query!( + r#" + select + key.changed_at as "changed_at: DateTime", + key.changed_sequence as "changed_sequence: Sequence", + signing.key + from vapid_key as key + join vapid_signing_key as signing + "# + ) + .map(|row| { + let key = SigningKey::from_pkcs8_pem(&row.key)?; + let key = key.verifying_key().to_owned(); + + let changed = Instant::new(row.changed_at, row.changed_sequence); + + Ok::<_, Error>(History { key, changed }) + }) + .fetch_one(&mut *self.0) + .await??; + + Ok(key) + } + + pub async fn signer(&mut self) -> Result<PartialVapidSignatureBuilder, Error> { + let key = sqlx::query_scalar!( + r#" + select key + from vapid_signing_key + "# + ) + .fetch_one(&mut *self.0) + .await?; + let key = Cursor::new(&key); + let signer = VapidSignatureBuilder::from_pem_no_sub(key)?; + + Ok(signer) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + Ecdsa(#[from] p256::ecdsa::Error), + Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), + Database(#[from] sqlx::Error), +} + +impl<T> NotFound for Result<T, Error> { + type Ok = T; + type Error = Error; + + fn optional(self) -> Result<Option<T>, Error> { + match self { + Ok(value) => Ok(Some(value)), + Err(Error::Database(sqlx::Error::RowNotFound)) => Ok(None), + Err(other) => Err(other), + } + } +} diff --git a/src/vapid/ser.rs b/src/vapid/ser.rs new file mode 100644 index 0000000..02c77e1 --- /dev/null +++ b/src/vapid/ser.rs @@ -0,0 +1,63 @@ +pub mod key { + use std::fmt; + + use base64::{Engine as _, engine::general_purpose::URL_SAFE}; + use p256::ecdsa::VerifyingKey; + use serde::{Deserializer, Serialize as _, de}; + + // This serialization - to a URL-safe base-64-encoded string and back - is based on my best + // understanding of RFC 8292 and the corresponding browser APIs. Particularly, it's based on + // section 3.2: + // + // > The "k" parameter includes an ECDSA public key [FIPS186] in uncompressed form [X9.62] that + // > is encoded using base64url encoding [RFC7515]. + // + // <https://datatracker.ietf.org/doc/html/rfc8292#section-3.2> + // + // I believe this is also supported by MDN's explanation: + // + // > `applicationServerKey` + // > + // > A Base64-encoded string or ArrayBuffer containing an ECDSA P-256 public key that the push + // > server will use to authenticate your application server. If specified, all messages from + // > your application server must use the VAPID authentication scheme, and include a JWT signed + // > with the corresponding private key. This key IS NOT the same ECDH key that you use to + // > encrypt the data. For more information, see "Using VAPID with WebPush". + // + // <https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe#applicationserverkey> + + pub fn serialize<S>(key: &VerifyingKey, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let key = key.to_sec1_bytes(); + let key = URL_SAFE.encode(key); + key.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result<VerifyingKey, D::Error> + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(Visitor) + } + + struct Visitor; + impl de::Visitor<'_> for Visitor { + type Value = VerifyingKey; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string containing a VAPID key") + } + + fn visit_str<E>(self, key: &str) -> Result<Self::Value, E> + where + E: de::Error, + { + let key = URL_SAFE.decode(key).map_err(E::custom)?; + let key = VerifyingKey::from_sec1_bytes(&key).map_err(E::custom)?; + + Ok(key) + } + } +} |
