summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2025-11-08 16:28:10 -0500
committerOwen Jacobson <owen@grimoire.ca>2025-11-08 16:28:10 -0500
commitfc6914831743f6d683c59adb367479defe6f8b3a (patch)
tree5b997adac55f47b52f30022013b8ec3b2c10bcc5 /src
parent0ef69c7d256380e660edc45ace7f1d6151226340 (diff)
parent6bab5b4405c9adafb2ce76540595a62eea80acc0 (diff)
Integrate the prototype push notification support.
We're going to move forwards with this for now, as low-utility as it is, so that we can more easily iterate on it in a real-world environment (hi.grimoire.ca).
Diffstat (limited to 'src')
-rw-r--r--src/app.rs71
-rw-r--r--src/boot/app.rs27
-rw-r--r--src/boot/handlers/boot/test.rs68
-rw-r--r--src/cli.rs32
-rw-r--r--src/event/app.rs26
-rw-r--r--src/event/handlers/stream/mod.rs4
-rw-r--r--src/event/handlers/stream/test/mod.rs1
-rw-r--r--src/event/handlers/stream/test/vapid.rs111
-rw-r--r--src/event/mod.rs10
-rw-r--r--src/expire.rs4
-rw-r--r--src/lib.rs2
-rw-r--r--src/login/app.rs2
-rw-r--r--src/push/app.rs178
-rw-r--r--src/push/handlers/mod.rs5
-rw-r--r--src/push/handlers/ping/mod.rs23
-rw-r--r--src/push/handlers/ping/test.rs40
-rw-r--r--src/push/handlers/subscribe/mod.rs94
-rw-r--r--src/push/handlers/subscribe/test.rs236
-rw-r--r--src/push/mod.rs3
-rw-r--r--src/push/repo.rs149
-rw-r--r--src/routes.rs16
-rw-r--r--src/test/fixtures/event/mod.rs21
-rw-r--r--src/test/fixtures/event/stream.rs17
-rw-r--r--src/test/fixtures/identity.rs2
-rw-r--r--src/test/fixtures/login.rs2
-rw-r--r--src/test/fixtures/mod.rs6
-rw-r--r--src/test/fixtures/user.rs4
-rw-r--r--src/test/mod.rs1
-rw-r--r--src/test/webpush.rs37
-rw-r--r--src/token/app.rs3
-rw-r--r--src/token/repo/token.rs17
-rw-r--r--src/vapid/app.rs117
-rw-r--r--src/vapid/event.rs37
-rw-r--r--src/vapid/history.rs55
-rw-r--r--src/vapid/middleware.rs17
-rw-r--r--src/vapid/mod.rs10
-rw-r--r--src/vapid/repo.rs161
-rw-r--r--src/vapid/ser.rs63
38 files changed, 1635 insertions, 37 deletions
diff --git a/src/app.rs b/src/app.rs
index ad19bc0..098ae9f 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -10,30 +10,34 @@ use crate::{
invite::app::Invites,
login::app::Logins,
message::app::Messages,
+ push::app::Push,
setup::app::Setup,
token::{self, app::Tokens},
+ vapid::app::Vapid,
};
#[derive(Clone)]
-pub struct App {
+pub struct App<P> {
db: SqlitePool,
+ webpush: P,
events: event::Broadcaster,
token_events: token::Broadcaster,
}
-impl App {
- pub fn from(db: SqlitePool) -> Self {
+impl<P> App<P> {
+ pub fn from(db: SqlitePool, webpush: P) -> Self {
let events = event::Broadcaster::default();
let token_events = token::Broadcaster::default();
Self {
db,
+ webpush,
events,
token_events,
}
}
}
-impl App {
+impl<P> App<P> {
pub fn boot(&self) -> Boot {
Boot::new(self.db.clone())
}
@@ -58,6 +62,13 @@ impl App {
Messages::new(self.db.clone(), self.events.clone())
}
+ pub fn push(&self) -> Push<P>
+ where
+ P: Clone,
+ {
+ Push::new(self.db.clone(), self.webpush.clone())
+ }
+
pub fn setup(&self) -> Setup {
Setup::new(self.db.clone(), self.events.clone())
}
@@ -70,46 +81,70 @@ impl App {
pub fn users(&self) -> Users {
Users::new(self.db.clone(), self.events.clone())
}
+
+ pub fn vapid(&self) -> Vapid {
+ Vapid::new(self.db.clone(), self.events.clone())
+ }
+
+ #[cfg(test)]
+ pub fn webpush(&self) -> &P {
+ &self.webpush
+ }
}
-impl FromRef<App> for Boot {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Boot {
+ fn from_ref(app: &App<P>) -> Self {
app.boot()
}
}
-impl FromRef<App> for Conversations {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Conversations {
+ fn from_ref(app: &App<P>) -> Self {
app.conversations()
}
}
-impl FromRef<App> for Invites {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Invites {
+ fn from_ref(app: &App<P>) -> Self {
app.invites()
}
}
-impl FromRef<App> for Logins {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Logins {
+ fn from_ref(app: &App<P>) -> Self {
app.logins()
}
}
-impl FromRef<App> for Messages {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Messages {
+ fn from_ref(app: &App<P>) -> Self {
app.messages()
}
}
-impl FromRef<App> for Setup {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Push<P>
+where
+ P: Clone,
+{
+ fn from_ref(app: &App<P>) -> Self {
+ app.push()
+ }
+}
+
+impl<P> FromRef<App<P>> for Setup {
+ fn from_ref(app: &App<P>) -> Self {
app.setup()
}
}
-impl FromRef<App> for Tokens {
- fn from_ref(app: &App) -> Self {
+impl<P> FromRef<App<P>> for Tokens {
+ fn from_ref(app: &App<P>) -> Self {
app.tokens()
}
}
+
+impl<P> FromRef<App<P>> for Vapid {
+ fn from_ref(app: &App<P>) -> Self {
+ app.vapid()
+ }
+}
diff --git a/src/boot/app.rs b/src/boot/app.rs
index 840243e..1ca8adb 100644
--- a/src/boot/app.rs
+++ b/src/boot/app.rs
@@ -4,10 +4,12 @@ use sqlx::sqlite::SqlitePool;
use super::Snapshot;
use crate::{
conversation::{self, repo::Provider as _},
+ db::NotFound,
event::{Event, Sequence, repo::Provider as _},
message::{self, repo::Provider as _},
name,
user::{self, repo::Provider as _},
+ vapid::{self, repo::Provider as _},
};
pub struct Boot {
@@ -26,6 +28,7 @@ impl Boot {
let users = tx.users().all(resume_point).await?;
let conversations = tx.conversations().all(resume_point).await?;
let messages = tx.messages().all(resume_point).await?;
+ let vapid = tx.vapid().current().await.optional()?;
tx.commit().await?;
@@ -50,9 +53,16 @@ impl Boot {
.filter(Sequence::up_to(resume_point))
.map(Event::from);
+ let vapid_events = vapid
+ .iter()
+ .flat_map(vapid::History::events)
+ .filter(Sequence::up_to(resume_point))
+ .map(Event::from);
+
let events = user_events
.merge_by(conversation_events, Sequence::merge)
.merge_by(message_events, Sequence::merge)
+ .merge_by(vapid_events, Sequence::merge)
.collect();
Ok(Snapshot {
@@ -65,8 +75,11 @@ impl Boot {
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum Error {
- Name(#[from] name::Error),
Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
+ Ecdsa(#[from] p256::ecdsa::Error),
+ Pkcs8(#[from] p256::pkcs8::Error),
+ WebPush(#[from] web_push::WebPushError),
}
impl From<user::repo::LoadError> for Error {
@@ -88,3 +101,15 @@ impl From<conversation::repo::LoadError> for Error {
}
}
}
+
+impl From<vapid::repo::Error> for Error {
+ fn from(error: vapid::repo::Error) -> Self {
+ use vapid::repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}
diff --git a/src/boot/handlers/boot/test.rs b/src/boot/handlers/boot/test.rs
index a9891eb..f192478 100644
--- a/src/boot/handlers/boot/test.rs
+++ b/src/boot/handlers/boot/test.rs
@@ -81,6 +81,74 @@ async fn includes_messages() {
}
#[tokio::test]
+async fn includes_vapid_key() {
+ let app = fixtures::scratch_app().await;
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("key rotation always succeeds");
+
+ let viewer = fixtures::identity::fictitious();
+ let response = super::handler(State(app.boot()), viewer)
+ .await
+ .expect("boot always succeeds");
+
+ response
+ .snapshot
+ .events
+ .into_iter()
+ .filter_map(fixtures::event::vapid)
+ .filter_map(fixtures::event::vapid::changed)
+ .exactly_one()
+ .expect("only one vapid key has been created");
+}
+
+#[tokio::test]
+async fn includes_only_latest_vapid_key() {
+ let app = fixtures::scratch_app().await;
+
+ app.vapid()
+ .refresh_key(&fixtures::ancient())
+ .await
+ .expect("key rotation always succeeds");
+
+ let viewer = fixtures::identity::fictitious();
+ let response = super::handler(State(app.boot()), viewer.clone())
+ .await
+ .expect("boot always succeeds");
+
+ let original_key = response
+ .snapshot
+ .events
+ .into_iter()
+ .filter_map(fixtures::event::vapid)
+ .filter_map(fixtures::event::vapid::changed)
+ .exactly_one()
+ .expect("only one vapid key has been created");
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("key rotation always succeeds");
+
+ let response = super::handler(State(app.boot()), viewer)
+ .await
+ .expect("boot always succeeds");
+
+ let rotated_key = response
+ .snapshot
+ .events
+ .into_iter()
+ .filter_map(fixtures::event::vapid)
+ .filter_map(fixtures::event::vapid::changed)
+ .exactly_one()
+ .expect("only one vapid key should be returned");
+
+ assert_ne!(original_key, rotated_key);
+}
+
+#[tokio::test]
async fn includes_expired_messages() {
let app = fixtures::scratch_app().await;
let sender = fixtures::user::create(&app, &fixtures::ancient()).await;
diff --git a/src/cli.rs b/src/cli.rs
index 378686b..154771b 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -10,9 +10,10 @@ use axum::{
middleware,
response::{IntoResponse, Response},
};
-use clap::{CommandFactory, Parser};
+use clap::{CommandFactory, Parser, Subcommand};
use sqlx::sqlite::SqlitePool;
use tokio::net;
+use web_push::{IsahcWebPushClient, WebPushClient};
use crate::{
app::App,
@@ -65,6 +66,15 @@ pub struct Args {
/// upgrades
#[arg(short = 'D', long, env, default_value = "sqlite://pilcrow.db.backup")]
backup_database_url: String,
+
+ #[command(subcommand)]
+ command: Option<Command>,
+}
+
+#[derive(Subcommand)]
+enum Command {
+ /// Immediately rotate the server's VAPID (Web Push) application key.
+ RotateVapidKey,
}
impl Args {
@@ -88,7 +98,21 @@ impl Args {
self.umask.set();
let pool = self.pool().await?;
- let app = App::from(pool);
+ let webpush = IsahcWebPushClient::new()?;
+ let app = App::from(pool, webpush);
+
+ match self.command {
+ None => self.serve(app).await?,
+ Some(Command::RotateVapidKey) => app.vapid().rotate_key().await?,
+ }
+
+ Result::<_, Error>::Ok(())
+ }
+
+ async fn serve<P>(self, app: App<P>) -> Result<(), Error>
+ where
+ P: WebPushClient + Clone + Send + Sync + 'static,
+ {
let app = routes::routes(&app)
.route_layer(middleware::from_fn(clock::middleware))
.route_layer(middleware::map_response(Self::server_info()))
@@ -101,7 +125,7 @@ impl Args {
println!("{started_msg}");
serve.await?;
- Result::<_, Error>::Ok(())
+ Ok(())
}
async fn listener(&self) -> io::Result<net::TcpListener> {
@@ -140,5 +164,7 @@ fn started_msg(listener: &net::TcpListener) -> io::Result<String> {
enum Error {
Io(#[from] io::Error),
Database(#[from] db::Error),
+ Sqlx(#[from] sqlx::Error),
Umask(#[from] umask::Error),
+ Webpush(#[from] web_push::WebPushError),
}
diff --git a/src/event/app.rs b/src/event/app.rs
index 8fa760a..e422de9 100644
--- a/src/event/app.rs
+++ b/src/event/app.rs
@@ -8,9 +8,12 @@ use sqlx::sqlite::SqlitePool;
use super::{Event, Sequence, Sequenced, broadcaster::Broadcaster};
use crate::{
conversation::{self, repo::Provider as _},
+ db::NotFound,
message::{self, repo::Provider as _},
name,
user::{self, repo::Provider as _},
+ vapid,
+ vapid::repo::Provider as _,
};
pub struct Events {
@@ -57,9 +60,17 @@ impl Events {
.filter(Sequence::after(resume_at))
.map(Event::from);
+ let vapid = tx.vapid().current().await.optional()?;
+ let vapid_events = vapid
+ .iter()
+ .flat_map(vapid::History::events)
+ .filter(Sequence::after(resume_at))
+ .map(Event::from);
+
let replay_events = user_events
.merge_by(conversation_events, Sequence::merge)
.merge_by(message_events, Sequence::merge)
+ .merge_by(vapid_events, Sequence::merge)
.collect::<Vec<_>>();
let resume_live_at = replay_events.last().map_or(resume_at, Sequenced::sequence);
@@ -86,6 +97,9 @@ impl Events {
pub enum Error {
Database(#[from] sqlx::Error),
Name(#[from] name::Error),
+ Ecdsa(#[from] p256::ecdsa::Error),
+ Pkcs8(#[from] p256::pkcs8::Error),
+ WebPush(#[from] web_push::WebPushError),
}
impl From<user::repo::LoadError> for Error {
@@ -107,3 +121,15 @@ impl From<conversation::repo::LoadError> for Error {
}
}
}
+
+impl From<vapid::repo::Error> for Error {
+ fn from(error: vapid::repo::Error) -> Self {
+ use vapid::repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}
diff --git a/src/event/handlers/stream/mod.rs b/src/event/handlers/stream/mod.rs
index 63bfff3..8b89c31 100644
--- a/src/event/handlers/stream/mod.rs
+++ b/src/event/handlers/stream/mod.rs
@@ -18,8 +18,8 @@ use crate::{
#[cfg(test)]
mod test;
-pub async fn handler(
- State(app): State<App>,
+pub async fn handler<P>(
+ State(app): State<App<P>>,
identity: Identity,
last_event_id: Option<LastEventId<Sequence>>,
Query(query): Query<QueryParams>,
diff --git a/src/event/handlers/stream/test/mod.rs b/src/event/handlers/stream/test/mod.rs
index 3bc634f..c3a6ce6 100644
--- a/src/event/handlers/stream/test/mod.rs
+++ b/src/event/handlers/stream/test/mod.rs
@@ -4,5 +4,6 @@ mod message;
mod resume;
mod setup;
mod token;
+mod vapid;
use super::{QueryParams, Response, handler};
diff --git a/src/event/handlers/stream/test/vapid.rs b/src/event/handlers/stream/test/vapid.rs
new file mode 100644
index 0000000..dbc3929
--- /dev/null
+++ b/src/event/handlers/stream/test/vapid.rs
@@ -0,0 +1,111 @@
+use axum::extract::State;
+use axum_extra::extract::Query;
+use futures::StreamExt as _;
+
+use crate::test::{fixtures, fixtures::future::Expect as _};
+
+#[tokio::test]
+async fn live_vapid_key_changes() {
+ // Set up the context
+ let app = fixtures::scratch_app().await;
+ let resume_point = fixtures::boot::resume_point(&app).await;
+
+ // Subscribe to events
+
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+ let super::Response(events) = super::handler(
+ State(app.clone()),
+ subscriber,
+ None,
+ Query(super::QueryParams { resume_point }),
+ )
+ .await
+ .expect("subscribe never fails");
+
+ // Rotate the VAPID key
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the vapid key always succeeds");
+
+ // Verify that there's a key rotation event
+
+ events
+ .filter_map(fixtures::event::stream::vapid)
+ .filter_map(fixtures::event::stream::vapid::changed)
+ .next()
+ .expect_some("a vapid key change event is sent")
+ .await;
+}
+
+#[tokio::test]
+async fn stored_vapid_key_changes() {
+ // Set up the context
+ let app = fixtures::scratch_app().await;
+ let resume_point = fixtures::boot::resume_point(&app).await;
+
+ // Rotate the VAPID key
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the vapid key always succeeds");
+
+ // Subscribe to events
+
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+ let super::Response(events) = super::handler(
+ State(app.clone()),
+ subscriber,
+ None,
+ Query(super::QueryParams { resume_point }),
+ )
+ .await
+ .expect("subscribe never fails");
+
+ // Verify that there's a key rotation event
+
+ events
+ .filter_map(fixtures::event::stream::vapid)
+ .filter_map(fixtures::event::stream::vapid::changed)
+ .next()
+ .expect_some("a vapid key change event is sent")
+ .await;
+}
+
+#[tokio::test]
+async fn no_past_vapid_key_changes() {
+ // Set up the context
+ let app = fixtures::scratch_app().await;
+
+ // Rotate the VAPID key
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the vapid key always succeeds");
+
+ // Subscribe to events
+
+ let resume_point = fixtures::boot::resume_point(&app).await;
+
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+ let super::Response(events) = super::handler(
+ State(app.clone()),
+ subscriber,
+ None,
+ Query(super::QueryParams { resume_point }),
+ )
+ .await
+ .expect("subscribe never fails");
+
+ // Verify that there's a key rotation event
+
+ events
+ .filter_map(fixtures::event::stream::vapid)
+ .filter_map(fixtures::event::stream::vapid::changed)
+ .next()
+ .expect_wait("a vapid key change event is not sent")
+ .await;
+}
diff --git a/src/event/mod.rs b/src/event/mod.rs
index f41dc9c..83b0ce7 100644
--- a/src/event/mod.rs
+++ b/src/event/mod.rs
@@ -2,7 +2,7 @@ use std::time::Duration;
use axum::response::sse::{self, KeepAlive};
-use crate::{conversation, message, user};
+use crate::{conversation, message, user, vapid};
pub mod app;
mod broadcaster;
@@ -22,6 +22,7 @@ pub enum Event {
User(user::Event),
Conversation(conversation::Event),
Message(message::Event),
+ Vapid(vapid::Event),
}
// Serialized representation is intended to look like the serialized representation of `Event`,
@@ -40,6 +41,7 @@ impl Sequenced for Event {
Self::User(event) => event.instant(),
Self::Conversation(event) => event.instant(),
Self::Message(event) => event.instant(),
+ Self::Vapid(event) => event.instant(),
}
}
}
@@ -62,6 +64,12 @@ impl From<message::Event> for Event {
}
}
+impl From<vapid::Event> for Event {
+ fn from(event: vapid::Event) -> Self {
+ Self::Vapid(event)
+ }
+}
+
impl Heartbeat {
// The following values are a first-rough-guess attempt to balance noticing connection problems
// quickly with managing the (modest) costs of delivering and processing heartbeats. Feel
diff --git a/src/expire.rs b/src/expire.rs
index 4177a53..c3b0117 100644
--- a/src/expire.rs
+++ b/src/expire.rs
@@ -7,8 +7,8 @@ use axum::{
use crate::{app::App, clock::RequestedAt, error::Internal};
// Expires messages and conversations before each request.
-pub async fn middleware(
- State(app): State<App>,
+pub async fn middleware<P>(
+ State(app): State<App<P>>,
RequestedAt(expired_at): RequestedAt,
req: Request,
next: Next,
diff --git a/src/lib.rs b/src/lib.rs
index f05cce3..38e6bc5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,6 +20,7 @@ mod message;
mod name;
mod normalize;
mod password;
+mod push;
mod routes;
mod setup;
#[cfg(test)]
@@ -28,3 +29,4 @@ mod token;
mod ui;
mod umask;
mod user;
+mod vapid;
diff --git a/src/login/app.rs b/src/login/app.rs
index a2f9636..8cc8cd0 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -6,6 +6,7 @@ use crate::{
login::{self, Login, repo::Provider as _},
name::{self, Name},
password::Password,
+ push::repo::Provider as _,
token::{Broadcaster, Event as TokenEvent, Secret, Token, repo::Provider as _},
};
@@ -76,6 +77,7 @@ impl Logins {
let mut tx = self.db.begin().await?;
tx.logins().set_password(&login, &to_hash).await?;
+ tx.push().unsubscribe_login(&login).await?;
let revoked = tx.tokens().revoke_all(&login).await?;
tx.tokens().create(&token, &secret).await?;
tx.commit().await?;
diff --git a/src/push/app.rs b/src/push/app.rs
new file mode 100644
index 0000000..56b9a02
--- /dev/null
+++ b/src/push/app.rs
@@ -0,0 +1,178 @@
+use futures::future::join_all;
+use itertools::Itertools as _;
+use p256::ecdsa::VerifyingKey;
+use sqlx::SqlitePool;
+use web_push::{
+ ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError,
+ WebPushMessage, WebPushMessageBuilder,
+};
+
+use super::repo::Provider as _;
+use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _};
+
+pub struct Push<P> {
+ db: SqlitePool,
+ webpush: P,
+}
+
+impl<P> Push<P> {
+ pub const fn new(db: SqlitePool, webpush: P) -> Self {
+ Self { db, webpush }
+ }
+
+ pub async fn subscribe(
+ &self,
+ subscriber: &Identity,
+ subscription: &SubscriptionInfo,
+ vapid: &VerifyingKey,
+ ) -> Result<(), SubscribeError> {
+ let mut tx = self.db.begin().await?;
+
+ let current = tx.vapid().current().await?;
+ if vapid != &current.key {
+ return Err(SubscribeError::StaleVapidKey(current.key));
+ }
+
+ match tx.push().create(&subscriber.token, subscription).await {
+ Ok(()) => (),
+ Err(err) => {
+ if let Some(err) = err.as_database_error()
+ && err.is_unique_violation()
+ {
+ let current = tx
+ .push()
+ .by_endpoint(&subscriber.login, &subscription.endpoint)
+ .await?;
+ // If we already have a subscription for this endpoint, with _different_
+ // parameters, then this is a client error. They shouldn't reuse endpoint URLs,
+ // per the various RFCs.
+ //
+ // However, if we have a subscription for this endpoint with the same parameters
+ // then we accept it and silently do nothing. This may happen if, for example,
+ // the subscribe request is retried due to a network interruption where it's
+ // not clear whether the original request succeeded.
+ if &current != subscription {
+ return Err(SubscribeError::Duplicate);
+ }
+ } else {
+ return Err(SubscribeError::Database(err));
+ }
+ }
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+}
+
+impl<P> Push<P>
+where
+ P: WebPushClient,
+{
+ fn prepare_ping(
+ signer: &PartialVapidSignatureBuilder,
+ subscription: &SubscriptionInfo,
+ ) -> Result<WebPushMessage, WebPushError> {
+ let signature = signer.clone().add_sub_info(subscription).build()?;
+
+ let payload = "ping".as_bytes();
+
+ let mut message = WebPushMessageBuilder::new(subscription);
+ message.set_payload(ContentEncoding::Aes128Gcm, payload);
+ message.set_vapid_signature(signature);
+ let message = message.build()?;
+
+ Ok(message)
+ }
+
+ pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> {
+ let mut tx = self.db.begin().await?;
+
+ let signer = tx.vapid().signer().await?;
+ let subscriptions = tx.push().by_login(recipient).await?;
+
+ let pings: Vec<_> = subscriptions
+ .into_iter()
+ .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message)))
+ .try_collect()?;
+
+ let deliveries = pings
+ .into_iter()
+ .map(async |(sub, message)| (sub, self.webpush.send(message).await));
+
+ let failures: Vec<_> = join_all(deliveries)
+ .await
+ .into_iter()
+ .filter_map(|(sub, result)| result.err().map(|err| (sub, err)))
+ .collect();
+
+ if !failures.is_empty() {
+ for (sub, err) in &failures {
+ match err {
+ // I _think_ this is the complete set of permanent failures. See
+ // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete
+ // list.
+ WebPushError::Unauthorized(_)
+ | WebPushError::InvalidUri
+ | WebPushError::EndpointNotValid(_)
+ | WebPushError::EndpointNotFound(_)
+ | WebPushError::InvalidCryptoKeys
+ | WebPushError::MissingCryptoKeys => {
+ tx.push().unsubscribe(sub).await?;
+ }
+ _ => (),
+ }
+ }
+
+ return Err(PushError::Delivery(
+ failures.into_iter().map(|(_, err)| err).collect(),
+ ));
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SubscribeError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Vapid(#[from] vapid::repo::Error),
+ #[error("subscription created with stale VAPID key")]
+ StaleVapidKey(VerifyingKey),
+ #[error("subscription already exists for endpoint")]
+ // The endpoint URL is not included in the error, as it is a bearer credential in its own right
+ // and we want to limit its proliferation. The only intended recipient of this message is the
+ // client, which already knows the endpoint anyways and doesn't need us to tell them.
+ Duplicate,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum PushError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Ecdsa(#[from] p256::ecdsa::Error),
+ #[error(transparent)]
+ Pkcs8(#[from] p256::pkcs8::Error),
+ #[error(transparent)]
+ WebPush(#[from] WebPushError),
+ #[error("push message delivery failures: {0:?}")]
+ Delivery(Vec<WebPushError>),
+}
+
+impl From<vapid::repo::Error> for PushError {
+ fn from(error: vapid::repo::Error) -> Self {
+ use vapid::repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}
diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs
new file mode 100644
index 0000000..bb58774
--- /dev/null
+++ b/src/push/handlers/mod.rs
@@ -0,0 +1,5 @@
+mod ping;
+mod subscribe;
+
+pub use ping::handler as ping;
+pub use subscribe::handler as subscribe;
diff --git a/src/push/handlers/ping/mod.rs b/src/push/handlers/ping/mod.rs
new file mode 100644
index 0000000..db828fa
--- /dev/null
+++ b/src/push/handlers/ping/mod.rs
@@ -0,0 +1,23 @@
+use axum::{Json, extract::State, http::StatusCode};
+use web_push::WebPushClient;
+
+use crate::{error::Internal, push::app::Push, token::extract::Identity};
+
+#[cfg(test)]
+mod test;
+
+#[derive(serde::Deserialize)]
+pub struct Request {}
+
+pub async fn handler<P>(
+ State(push): State<Push<P>>,
+ identity: Identity,
+ Json(_): Json<Request>,
+) -> Result<StatusCode, Internal>
+where
+ P: WebPushClient,
+{
+ push.ping(&identity.login).await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs
new file mode 100644
index 0000000..5725131
--- /dev/null
+++ b/src/push/handlers/ping/test.rs
@@ -0,0 +1,40 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+};
+
+use crate::test::fixtures;
+
+#[tokio::test]
+async fn ping_without_subscriptions() {
+ let app = fixtures::scratch_app().await;
+
+ let recipient = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ let response = super::handler(State(app.push()), recipient, Json(super::Request {}))
+ .await
+ .expect("sending a ping with no subscriptions always succeeds");
+
+ assert_eq!(StatusCode::ACCEPTED, response);
+
+ assert!(app.webpush().sent().is_empty());
+}
+
+// More complete testing requires that we figure out how to generate working p256 ECDH keys for
+// testing _with_, as `web_push` will actually parse and use those keys even if push messages are
+// ultimately never serialized or sent over HTTP.
+//
+// Tests that are missing:
+//
+// * Verify that subscribing and sending a ping causes a ping to be delivered to that subscription.
+// * Verify that two subscriptions both get pings.
+// * Verify that other users' subscriptions are not pinged.
+// * Verify that a ping that causes a permanent error causes the subscription to be deleted.
+// * Verify that a ping that causes a non-permanent error does not cause the subscription to be
+// deleted.
+// * Verify that a failure on one subscription doesn't affect delivery on other subscriptions.
diff --git a/src/push/handlers/subscribe/mod.rs b/src/push/handlers/subscribe/mod.rs
new file mode 100644
index 0000000..a1a5899
--- /dev/null
+++ b/src/push/handlers/subscribe/mod.rs
@@ -0,0 +1,94 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+ response::{IntoResponse, Response},
+};
+use p256::ecdsa::VerifyingKey;
+use web_push::SubscriptionInfo;
+
+use crate::{
+ error::Internal,
+ push::{app, app::Push},
+ token::extract::Identity,
+};
+
+#[cfg(test)]
+mod test;
+
+#[derive(Clone, serde::Deserialize)]
+pub struct Request {
+ subscription: Subscription,
+ #[serde(with = "crate::vapid::ser::key")]
+ vapid: VerifyingKey,
+}
+
+// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>.
+#[derive(Clone, serde::Deserialize)]
+pub struct Subscription {
+ endpoint: String,
+ keys: Keys,
+}
+
+// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>.
+#[derive(Clone, serde::Deserialize)]
+pub struct Keys {
+ p256dh: String,
+ auth: String,
+}
+
+pub async fn handler<P>(
+ State(push): State<Push<P>>,
+ identity: Identity,
+ Json(request): Json<Request>,
+) -> Result<StatusCode, Error> {
+ let Request {
+ subscription,
+ vapid,
+ } = request;
+
+ push.subscribe(&identity, &subscription.into(), &vapid)
+ .await?;
+
+ Ok(StatusCode::CREATED)
+}
+
+impl From<Subscription> for SubscriptionInfo {
+ fn from(request: Subscription) -> Self {
+ let Subscription {
+ endpoint,
+ keys: Keys { p256dh, auth },
+ } = request;
+ SubscriptionInfo::new(endpoint, p256dh, auth)
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub struct Error(#[from] app::SubscribeError);
+
+impl IntoResponse for Error {
+ fn into_response(self) -> Response {
+ let Self(err) = self;
+
+ match err {
+ app::SubscribeError::StaleVapidKey(key) => {
+ let body = StaleVapidKey {
+ message: err.to_string(),
+ key,
+ };
+ (StatusCode::BAD_REQUEST, Json(body)).into_response()
+ }
+ app::SubscribeError::Duplicate => {
+ (StatusCode::CONFLICT, err.to_string()).into_response()
+ }
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
+
+#[derive(serde::Serialize)]
+struct StaleVapidKey {
+ message: String,
+ #[serde(with = "crate::vapid::ser::key")]
+ key: VerifyingKey,
+}
diff --git a/src/push/handlers/subscribe/test.rs b/src/push/handlers/subscribe/test.rs
new file mode 100644
index 0000000..b72624d
--- /dev/null
+++ b/src/push/handlers/subscribe/test.rs
@@ -0,0 +1,236 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+};
+
+use crate::{
+ push::app::SubscribeError,
+ test::{fixtures, fixtures::event},
+};
+
+#[tokio::test]
+async fn accepts_new_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Check that the response looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+}
+
+#[tokio::test]
+async fn accepts_repeat_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber.clone(), Json(request.clone()))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Check that the response looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+
+ // Repeat the request
+
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect("test request will succeed twice on a fresh app");
+
+ // Check that the second response also looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+}
+
+#[tokio::test]
+async fn rejects_duplicate_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ super::handler(State(app.push()), subscriber.clone(), Json(request))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Repeat the request with different keys
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("different-test-p256dh-value"),
+ auth: String::from("different-test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect_err("request with duplicate endpoint should fail");
+
+ // Make sure we got the error we expected.
+
+ assert!(matches!(response, super::Error(SubscribeError::Duplicate)));
+}
+
+#[tokio::test]
+async fn rejects_stale_vapid_key() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Change the VAPID key.
+
+ app.vapid()
+ .rotate_key()
+ .await
+ .expect("key rotation always succeeds");
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what the new VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let fresh_vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with the original key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect_err("test request has a stale vapid key");
+
+ // Check that the response looks as expected.
+
+ assert!(matches!(
+ response,
+ super::Error(SubscribeError::StaleVapidKey(key)) if key == fresh_vapid.key
+ ));
+}
diff --git a/src/push/mod.rs b/src/push/mod.rs
new file mode 100644
index 0000000..1394ea4
--- /dev/null
+++ b/src/push/mod.rs
@@ -0,0 +1,3 @@
+pub mod app;
+pub mod handlers;
+pub mod repo;
diff --git a/src/push/repo.rs b/src/push/repo.rs
new file mode 100644
index 0000000..4183489
--- /dev/null
+++ b/src/push/repo.rs
@@ -0,0 +1,149 @@
+use sqlx::{Sqlite, SqliteConnection, Transaction};
+use web_push::SubscriptionInfo;
+
+use crate::{login::Login, token::Token};
+
/// Access to the push-subscription repository from a database connection.
pub trait Provider {
    fn push(&mut self) -> Push<'_>;
}

impl Provider for Transaction<'_, Sqlite> {
    fn push(&mut self) -> Push<'_> {
        Push(self)
    }
}

/// Repository of Web Push subscriptions, borrowing a database connection.
pub struct Push<'t>(&'t mut SqliteConnection);
+
impl Push<'_> {
    /// Stores a new push subscription, tied to the session token that created it.
    pub async fn create(
        &mut self,
        token: &Token,
        subscription: &SubscriptionInfo,
    ) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
            insert into push_subscription (token, endpoint, p256dh, auth)
            values ($1, $2, $3, $4)
            "#,
            token.id,
            subscription.endpoint,
            subscription.keys.p256dh,
            subscription.keys.auth,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Returns every subscription belonging to any of the login's tokens.
    pub async fn by_login(&mut self, login: &Login) -> Result<Vec<SubscriptionInfo>, sqlx::Error> {
        sqlx::query!(
            r#"
            select
                subscription.endpoint,
                subscription.p256dh,
                subscription.auth
            from push_subscription as subscription
            join token on subscription.token = token.id
            where token.login = $1
            "#,
            login.id,
        )
        .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth))
        .fetch_all(&mut *self.0)
        .await
    }

    /// Looks up one of the subscriber's subscriptions by its endpoint URL.
    ///
    /// Fails with `sqlx::Error::RowNotFound` (`fetch_one` semantics) if the subscriber has no
    /// subscription at that endpoint.
    pub async fn by_endpoint(
        &mut self,
        subscriber: &Login,
        endpoint: &str,
    ) -> Result<SubscriptionInfo, sqlx::Error> {
        let row = sqlx::query!(
            r#"
            select
                subscription.endpoint,
                subscription.p256dh,
                subscription.auth
            from push_subscription as subscription
            join token on subscription.token = token.id
            join login as subscriber on token.login = subscriber.id
            where subscriber.id = $1
            and subscription.endpoint = $2
            "#,
            subscriber.id,
            endpoint,
        )
        .fetch_one(&mut *self.0)
        .await?;

        let info = SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth);

        Ok(info)
    }

    /// Removes a subscription by endpoint, regardless of which token or login created it.
    pub async fn unsubscribe(
        &mut self,
        subscription: &SubscriptionInfo,
    ) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
            delete from push_subscription
            where endpoint = $1
            "#,
            subscription.endpoint,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Removes all subscriptions created by the given token (used on logout).
    pub async fn unsubscribe_token(&mut self, token: &Token) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
            delete from push_subscription
            where token = $1
            "#,
            token.id,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Removes all subscriptions for any of the login's tokens.
    pub async fn unsubscribe_login(&mut self, login: &Login) -> Result<(), sqlx::Error> {
        // Note: SQLite permits `expr IN table-name`, treating the named CTE as a subquery.
        sqlx::query!(
            r#"
            with tokens as (
                select id from token
                where login = $1
            )
            delete from push_subscription
            where token in tokens
            "#,
            login.id,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    // Unsubscribe logic for token expiry lives in the `tokens` repository, for maintenance reasons.

    /// Removes every stored subscription, for all tokens and logins.
    pub async fn clear(&mut self) -> Result<(), sqlx::Error> {
        // We assume that _all_ stored subscriptions are for a VAPID key we're about to delete.
        sqlx::query!(
            r#"
            delete from push_subscription
            "#,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }
}
diff --git a/src/routes.rs b/src/routes.rs
index b848afb..1c07e78 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -3,10 +3,16 @@ use axum::{
response::Redirect,
routing::{delete, get, post},
};
+use web_push::WebPushClient;
-use crate::{app::App, boot, conversation, event, expire, invite, login, message, setup, ui};
+use crate::{
+ app::App, boot, conversation, event, expire, invite, login, message, push, setup, ui, vapid,
+};
-pub fn routes(app: &App) -> Router<App> {
+pub fn routes<P>(app: &App<P>) -> Router<App<P>>
+where
+ P: WebPushClient + Clone + Send + Sync + 'static,
+{
// UI routes that can be accessed before the administrator completes setup.
let ui_bootstrap = Router::new()
.route("/{*path}", get(ui::handlers::asset))
@@ -44,6 +50,8 @@ pub fn routes(app: &App) -> Router<App> {
.route("/api/invite/{invite}", get(invite::handlers::get))
.route("/api/invite/{invite}", post(invite::handlers::accept))
.route("/api/messages/{message}", delete(message::handlers::delete))
+ .route("/api/push/ping", post(push::handlers::ping))
+ .route("/api/push/subscribe", post(push::handlers::subscribe))
.route("/api/password", post(login::handlers::change_password))
// Run expiry whenever someone accesses the API. This was previously a blanket middleware
// affecting the whole service, but loading the client makes a several requests before the
@@ -56,6 +64,10 @@ pub fn routes(app: &App) -> Router<App> {
app.clone(),
expire::middleware,
))
+ .route_layer(middleware::from_fn_with_state(
+ app.clone(),
+ vapid::middleware,
+ ))
.route_layer(setup::Required(app.clone()));
[
diff --git a/src/test/fixtures/event/mod.rs b/src/test/fixtures/event/mod.rs
index 08b17e7..f8651ba 100644
--- a/src/test/fixtures/event/mod.rs
+++ b/src/test/fixtures/event/mod.rs
@@ -23,6 +23,13 @@ pub fn user(event: Event) -> Option<crate::user::Event> {
}
}
+pub fn vapid(event: Event) -> Option<crate::vapid::Event> {
+ match event {
+ Event::Vapid(event) => Some(event),
+ _ => None,
+ }
+}
+
pub mod conversation {
use crate::conversation::{Event, event};
@@ -72,3 +79,17 @@ pub mod user {
}
}
}
+
/// Event-stream transformers for VAPID events, mirroring the sibling modules above.
pub mod vapid {
    use crate::vapid::{Event, event};

    /// Narrows a VAPID event to its `Changed` payload.
    //
    // This could be defined as `-> event::Changed`. However, I want the interface to be consistent
    // with the event stream transformers for other types, and we'd have to refactor the return type
    // to `-> Option<event::Changed>` the instant VAPID keys sprout a second event.
    #[allow(clippy::unnecessary_wraps)]
    pub fn changed(event: Event) -> Option<event::Changed> {
        match event {
            Event::Changed(changed) => Some(changed),
        }
    }
}
diff --git a/src/test/fixtures/event/stream.rs b/src/test/fixtures/event/stream.rs
index 5b3621d..bb83d0d 100644
--- a/src/test/fixtures/event/stream.rs
+++ b/src/test/fixtures/event/stream.rs
@@ -14,6 +14,10 @@ pub fn user(event: Event) -> Ready<Option<crate::user::Event>> {
future::ready(event::user(event))
}
/// Future-wrapped form of [`event::vapid`], matching the other stream transformers in this file.
pub fn vapid(event: Event) -> Ready<Option<crate::vapid::Event>> {
    future::ready(event::vapid(event))
}
+
pub mod conversation {
use std::future::{self, Ready};
@@ -60,3 +64,16 @@ pub mod user {
future::ready(user::created(event))
}
}
+
/// Future-wrapped VAPID transformers, matching the sibling stream modules above.
pub mod vapid {
    use std::future::{self, Ready};

    use crate::{
        test::fixtures::event::vapid,
        vapid::{Event, event},
    };

    /// Future-wrapped form of [`vapid::changed`].
    pub fn changed(event: Event) -> Ready<Option<event::Changed>> {
        future::ready(vapid::changed(event))
    }
}
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 20929f9..adc3e73 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -14,7 +14,7 @@ use crate::{
},
};
-pub async fn create(app: &App, created_at: &RequestedAt) -> Identity {
/// Creates a fresh user with a password and logs them in, returning the resulting identity.
///
/// `P` (the app's web-push client type) is unconstrained: nothing here touches push delivery.
pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Identity {
    let credentials = fixtures::user::create_with_password(app, created_at).await;
    logged_in(app, &credentials, created_at).await
}
diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs
index d9aca81..839a412 100644
--- a/src/test/fixtures/login.rs
+++ b/src/test/fixtures/login.rs
@@ -5,7 +5,7 @@ use crate::{
test::fixtures::user::{propose, propose_name},
};
-pub async fn create(app: &App, created_at: &DateTime) -> Login {
+pub async fn create<P>(app: &App<P>, created_at: &DateTime) -> Login {
let (name, password) = propose();
app.users()
.create(&name, &password, created_at)
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index 3d69cfa..53bf31b 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -1,6 +1,6 @@
use chrono::{TimeDelta, Utc};
-use crate::{app::App, clock::RequestedAt, db};
+use crate::{app::App, clock::RequestedAt, db, test::webpush::Client};
pub mod boot;
pub mod conversation;
@@ -13,11 +13,11 @@ pub mod login;
pub mod message;
pub mod user;
-pub async fn scratch_app() -> App {
/// Builds a throwaway app backed by in-memory SQLite databases and the fake web-push client.
pub async fn scratch_app() -> App<Client> {
    let pool = db::prepare("sqlite::memory:", "sqlite::memory:")
        .await
        .expect("setting up in-memory sqlite database");
    App::from(pool, Client::new())
}
pub fn now() -> RequestedAt {
diff --git a/src/test/fixtures/user.rs b/src/test/fixtures/user.rs
index d4d8db4..3ad4436 100644
--- a/src/test/fixtures/user.rs
+++ b/src/test/fixtures/user.rs
@@ -3,7 +3,7 @@ use uuid::Uuid;
use crate::{app::App, clock::RequestedAt, login::Login, name::Name, password::Password};
-pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name, Password) {
+pub async fn create_with_password<P>(app: &App<P>, created_at: &RequestedAt) -> (Name, Password) {
let (name, password) = propose();
let user = app
.users()
@@ -14,7 +14,7 @@ pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name,
(user.name, password)
}
-pub async fn create(app: &App, created_at: &RequestedAt) -> Login {
/// Creates a fresh user and returns the resulting login; delegates to the login fixture.
pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Login {
    super::login::create(app, created_at).await
}
diff --git a/src/test/mod.rs b/src/test/mod.rs
index ebbbfef..f798b9c 100644
--- a/src/test/mod.rs
+++ b/src/test/mod.rs
@@ -1,2 +1,3 @@
pub mod fixtures;
pub mod verify;
+pub mod webpush;
diff --git a/src/test/webpush.rs b/src/test/webpush.rs
new file mode 100644
index 0000000..c86d03f
--- /dev/null
+++ b/src/test/webpush.rs
@@ -0,0 +1,37 @@
+use std::{
+ mem,
+ sync::{Arc, Mutex},
+};
+
+use web_push::{WebPushClient, WebPushError, WebPushMessage};
+
/// A fake `WebPushClient` for tests: records sent messages instead of delivering them.
///
/// Cloning shares the underlying record (`Arc`), so any clone observes messages sent via another.
#[derive(Clone)]
pub struct Client {
    sent: Arc<Mutex<Vec<WebPushMessage>>>,
}
+
+impl Client {
+ pub fn new() -> Self {
+ Self {
+ sent: Arc::default(),
+ }
+ }
+
+ // Clears the list of sent messages (for all clones of this Client) when called, because we
+ // can't clone `WebPushMessage`s so we either need to move them or try to reconstruct them,
+ // either of which sucks but moving them sucks less.
+ pub fn sent(&self) -> Vec<WebPushMessage> {
+ let mut sent = self.sent.lock().unwrap();
+ mem::replace(&mut *sent, Vec::new())
+ }
+}
+
#[async_trait::async_trait]
impl WebPushClient for Client {
    /// Records the message and reports success; performs no network I/O.
    async fn send(&self, message: WebPushMessage) -> Result<(), WebPushError> {
        let mut sent = self.sent.lock().unwrap();
        sent.push(message);

        Ok(())
    }
}
diff --git a/src/token/app.rs b/src/token/app.rs
index 332473d..4a08877 100644
--- a/src/token/app.rs
+++ b/src/token/app.rs
@@ -10,7 +10,7 @@ use super::{
extract::Identity,
repo::{self, Provider as _},
};
-use crate::{clock::DateTime, db::NotFound as _, name};
+use crate::{clock::DateTime, db::NotFound as _, name, push::repo::Provider as _};
pub struct Tokens {
db: SqlitePool,
@@ -112,6 +112,7 @@ impl Tokens {
pub async fn logout(&self, token: &Token) -> Result<(), ValidateError> {
let mut tx = self.db.begin().await?;
+ tx.push().unsubscribe_token(token).await?;
tx.tokens().revoke(token).await?;
tx.commit().await?;
diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs
index 52a3987..33c33af 100644
--- a/src/token/repo/token.rs
+++ b/src/token/repo/token.rs
@@ -89,6 +89,23 @@ impl Tokens<'_> {
// Expire and delete all tokens that haven't been used more recently than
// `expire_at`.
pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ // This lives here, rather than in the `push` repository, to ensure that the criteria for
+ // stale tokens don't drift apart between the two queries. That would be a larger risk if
+ // the queries lived in very separate parts of the codebase.
+ sqlx::query!(
+ r#"
+ with stale_tokens as (
+ select id from token
+ where last_used_at < $1
+ )
+ delete from push_subscription
+ where token in stale_tokens
+ "#,
+ expire_at,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
let tokens = sqlx::query_scalar!(
r#"
delete
diff --git a/src/vapid/app.rs b/src/vapid/app.rs
new file mode 100644
index 0000000..9949aa5
--- /dev/null
+++ b/src/vapid/app.rs
@@ -0,0 +1,117 @@
+use chrono::TimeDelta;
+use sqlx::SqlitePool;
+
+use super::{History, repo, repo::Provider as _};
+use crate::{
+ clock::DateTime,
+ db::NotFound as _,
+ event::{Broadcaster, Sequence, repo::Provider},
+ push::repo::Provider as _,
+};
+
/// Application-level operations on the service's VAPID key: rotation, refresh, and broadcasting
/// key-change events to connected clients.
pub struct Vapid {
    db: SqlitePool,
    events: Broadcaster,
}
+
impl Vapid {
    pub const fn new(db: SqlitePool, events: Broadcaster) -> Self {
        Self { db, events }
    }

    /// Discards the current VAPID key material, and every push subscription minted against it,
    /// without generating a replacement key or emitting events.
    pub async fn rotate_key(&self) -> Result<(), sqlx::Error> {
        let mut tx = self.db.begin().await?;
        // This is called from a separate CLI utility (see `cli.rs`), and we _can't_ deliver events
        // to active clients from another process, so don't do anything that would require us to
        // send events, like generating a new key.
        //
        // Instead, the server's next `refresh_key` call will generate a key and notify clients
        // of the change. All we have to do is remove the existing key, so that the server can know
        // to do so.
        tx.vapid().clear().await?;
        // Delete outstanding subscriptions for the existing VAPID key, as well. They're
        // unserviceable once we lose the key. Clients can resubscribe when they process the next
        // key rotation event, which will be quite quickly once the running server notices that the
        // VAPID key has been removed.
        tx.push().clear().await?;
        tx.commit().await?;

        Ok(())
    }

    /// Ensures a usable VAPID key exists as of `ensure_at`: generates one if none is stored,
    /// replaces the stored one if it's more than 30 days old, and otherwise does nothing. Any
    /// change is recorded and committed before being broadcast to connected clients.
    pub async fn refresh_key(&self, ensure_at: &DateTime) -> Result<(), Error> {
        let mut tx = self.db.begin().await?;
        let key = tx.vapid().current().await.optional()?;
        if key.is_none() {
            let changed_at = tx.sequence().next(ensure_at).await?;
            let (key, secret) = History::begin(&changed_at);

            // Remove any leftover key material before storing the new signing key.
            tx.vapid().clear().await?;
            tx.vapid().store_signing_key(&secret).await?;

            let events = key.events().filter(Sequence::start_from(changed_at));
            tx.vapid().record_events(events.clone()).await?;

            tx.commit().await?;

            self.events.broadcast_from(events);
        } else if let Some(key) = key
            // Somewhat arbitrarily, rotate keys every 30 days.
            && key.older_than(ensure_at.to_owned() - TimeDelta::days(30))
        {
            // If you can think of a way to factor out this duplication, be my guest. I tried.
            // The only approach I could think of mirrors `crate::user::create::Create`, encoding
            // the process in a state machine made of types, and that's a very complex solution
            // to a problem that doesn't seem to merit it. -o
            let changed_at = tx.sequence().next(ensure_at).await?;
            let (key, secret) = key.rotate(&changed_at);

            // This will delete _all_ stored subscriptions. This is fine; they're all for the
            // current VAPID key, and we won't be able to use them anyways once the key is rotated.
            // We have no way to inform the push broker services of that, unfortunately.
            tx.push().clear().await?;
            tx.vapid().clear().await?;
            tx.vapid().store_signing_key(&secret).await?;

            // Refactoring constraint: this `events` iterator borrows `key`. Anything that moves
            // `key` has to give it back, but it can't give both `key` back and an event iterator
            // borrowing from `key` because Rust doesn't support types that borrow from other
            // parts of themselves.
            let events = key.events().filter(Sequence::start_from(changed_at));
            tx.vapid().record_events(events.clone()).await?;

            // Refactoring constraint: we _really_ want to commit the transaction before we send
            // out events, so that anything acting on those events is guaranteed to see the state
            // of the service at some point at or after the side effects of this. I'd also prefer
            // to keep the commit in the same method that the transaction is begun in, for clarity.
            tx.commit().await?;

            self.events.broadcast_from(events);
        }
        // else, the key exists and is not stale. Don't bother allocating a sequence number, and
        // in fact throw away the whole transaction.

        Ok(())
    }
}
+
/// Errors that VAPID key maintenance can produce, transparently wrapping the underlying causes.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum Error {
    Database(#[from] sqlx::Error),
    Ecdsa(#[from] p256::ecdsa::Error),
    Pkcs8(#[from] p256::pkcs8::Error),
    WebPush(#[from] web_push::WebPushError),
}
+
+impl From<repo::Error> for Error {
+ fn from(error: repo::Error) -> Self {
+ use repo::Error;
+ match error {
+ Error::Database(error) => error.into(),
+ Error::Ecdsa(error) => error.into(),
+ Error::Pkcs8(error) => error.into(),
+ Error::WebPush(error) => error.into(),
+ }
+ }
+}
diff --git a/src/vapid/event.rs b/src/vapid/event.rs
new file mode 100644
index 0000000..cf3be77
--- /dev/null
+++ b/src/vapid/event.rs
@@ -0,0 +1,37 @@
+use p256::ecdsa::VerifyingKey;
+
+use crate::event::{Instant, Sequenced};
+
/// Client-visible events describing the service's VAPID key.
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub enum Event {
    Changed(Changed),
}

impl Sequenced for Event {
    // Delegates to the wrapped event's instant.
    fn instant(&self) -> Instant {
        match self {
            Self::Changed(event) => event.instant(),
        }
    }
}
+
/// The VAPID key changed (or was first created). Carries the new public key, serialized as a
/// base64url string via `crate::vapid::ser::key`.
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct Changed {
    // When the change took effect, flattened into the serialized event.
    #[serde(flatten)]
    pub instant: Instant,
    // The new public (verifying) key.
    #[serde(with = "crate::vapid::ser::key")]
    pub key: VerifyingKey,
}

impl From<Changed> for Event {
    fn from(event: Changed) -> Self {
        Self::Changed(event)
    }
}

impl Sequenced for Changed {
    fn instant(&self) -> Instant {
        self.instant
    }
}
diff --git a/src/vapid/history.rs b/src/vapid/history.rs
new file mode 100644
index 0000000..42f062b
--- /dev/null
+++ b/src/vapid/history.rs
@@ -0,0 +1,55 @@
+use p256::ecdsa::{SigningKey, VerifyingKey};
+use rand::thread_rng;
+
+use super::event::{Changed, Event};
+use crate::{clock::DateTime, event::Instant};
+
/// The service's current VAPID public key and the instant it last changed.
#[derive(Debug)]
pub struct History {
    // Public (verifying) half of the VAPID key pair.
    pub key: VerifyingKey,
    // When this key came into effect.
    pub changed: Instant,
}

// Lifecycle interface
impl History {
    /// Generates a fresh random key pair, returning the public-facing history alongside the
    /// private signing key so the caller can persist it.
    pub fn begin(changed: &Instant) -> (Self, SigningKey) {
        let key = SigningKey::random(&mut thread_rng());
        (
            Self {
                key: VerifyingKey::from(&key),
                changed: *changed,
            },
            key,
        )
    }

    /// Consumes this history and generates a wholly new key pair to replace it.
    //
    // `self` _is_ unused here, clippy is right about that. This choice is deliberate, however - it
    // makes it harder to inadvertently reuse a rotated key via its history, and it makes the
    // lifecycle interface more obviously consistent between this and other History types.
    #[allow(clippy::unused_self)]
    pub fn rotate(self, changed: &Instant) -> (Self, SigningKey) {
        Self::begin(changed)
    }
}

// State interface
impl History {
    /// True if the key last changed strictly before `when`.
    pub fn older_than(&self, when: DateTime) -> bool {
        self.changed.at < when
    }
}

// Events interface
impl History {
    /// Yields the events describing this history: currently a single `Changed` event.
    pub fn events(&self) -> impl Iterator<Item = Event> + Clone {
        [self.changed()].into_iter()
    }

    // Builds the `Changed` event carrying the public key and its change instant.
    fn changed(&self) -> Event {
        Changed {
            key: self.key,
            instant: self.changed,
        }
        .into()
    }
}
diff --git a/src/vapid/middleware.rs b/src/vapid/middleware.rs
new file mode 100644
index 0000000..3129aa7
--- /dev/null
+++ b/src/vapid/middleware.rs
@@ -0,0 +1,17 @@
+use axum::{
+ extract::{Request, State},
+ middleware::Next,
+ response::Response,
+};
+
+use crate::{clock::RequestedAt, error::Internal, vapid::app::Vapid};
+
/// Request middleware that ensures the VAPID key is present and fresh (see `Vapid::refresh_key`)
/// before the wrapped handler runs.
pub async fn middleware(
    State(vapid): State<Vapid>,
    RequestedAt(now): RequestedAt,
    request: Request,
    next: Next,
) -> Result<Response, Internal> {
    vapid.refresh_key(&now).await?;
    Ok(next.run(request).await)
}
diff --git a/src/vapid/mod.rs b/src/vapid/mod.rs
new file mode 100644
index 0000000..364f602
--- /dev/null
+++ b/src/vapid/mod.rs
@@ -0,0 +1,10 @@
+pub mod app;
+pub mod event;
+mod history;
+mod middleware;
+pub mod repo;
+pub mod ser;
+
+pub use event::Event;
+pub use history::History;
+pub use middleware::middleware;
diff --git a/src/vapid/repo.rs b/src/vapid/repo.rs
new file mode 100644
index 0000000..9db61e1
--- /dev/null
+++ b/src/vapid/repo.rs
@@ -0,0 +1,161 @@
+use std::io::Cursor;
+
+use p256::{
+ ecdsa::SigningKey,
+ pkcs8::{DecodePrivateKey as _, EncodePrivateKey as _, LineEnding},
+};
+use sqlx::{Sqlite, SqliteConnection, Transaction};
+use web_push::{PartialVapidSignatureBuilder, VapidSignatureBuilder};
+
+use super::{
+ History,
+ event::{Changed, Event},
+};
+use crate::{
+ clock::DateTime,
+ db::NotFound,
+ event::{Instant, Sequence},
+};
+
/// Access to the VAPID key repository from a database connection.
pub trait Provider {
    fn vapid(&mut self) -> Vapid<'_>;
}

impl Provider for Transaction<'_, Sqlite> {
    fn vapid(&mut self) -> Vapid<'_> {
        Vapid(self)
    }
}

/// Repository of VAPID key material and key-change history, borrowing a database connection.
pub struct Vapid<'a>(&'a mut SqliteConnection);
+
impl Vapid<'_> {
    /// Records each event in order, within the borrowed transaction.
    pub async fn record_events(
        &mut self,
        events: impl IntoIterator<Item = Event>,
    ) -> Result<(), sqlx::Error> {
        for event in events {
            self.record_event(&event).await?;
        }
        Ok(())
    }

    /// Dispatches a single event to its variant-specific recorder.
    pub async fn record_event(&mut self, event: &Event) -> Result<(), sqlx::Error> {
        match event {
            Event::Changed(changed) => self.record_changed(changed).await,
        }
    }

    // Persists a key-change instant. Only the instant is stored here; the key material itself
    // lives in `vapid_signing_key` (see `store_signing_key`).
    async fn record_changed(&mut self, changed: &Changed) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
            insert into vapid_key (changed_at, changed_sequence)
            values ($1, $2)
            "#,
            changed.instant.at,
            changed.instant.sequence,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Deletes all stored key-change history and signing-key material.
    pub async fn clear(&mut self) -> Result<(), sqlx::Error> {
        sqlx::query!(
            r#"
            delete from vapid_key
            "#
        )
        .execute(&mut *self.0)
        .await?;

        sqlx::query!(
            r#"
            delete from vapid_signing_key
            "#
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Stores the private signing key, PEM-encoded (PKCS#8, CRLF line endings).
    pub async fn store_signing_key(&mut self, key: &SigningKey) -> Result<(), Error> {
        let key = key.to_pkcs8_pem(LineEnding::CRLF)?;
        let key = key.as_str();
        sqlx::query!(
            r#"
            insert into vapid_signing_key (key)
            values ($1)
            "#,
            key,
        )
        .execute(&mut *self.0)
        .await?;

        Ok(())
    }

    /// Loads the current key history: the public verifying key plus when it last changed.
    ///
    /// The query cross-joins the two tables with no join condition; this appears to assume each
    /// holds at most one row (callers `clear()` before each insert — TODO confirm). `fetch_one`
    /// yields `RowNotFound` when no key exists. The double `?` unwraps first the sqlx error, then
    /// the key-decoding error produced inside the `map` closure.
    pub async fn current(&mut self) -> Result<History, Error> {
        let key = sqlx::query!(
            r#"
            select
                key.changed_at as "changed_at: DateTime",
                key.changed_sequence as "changed_sequence: Sequence",
                signing.key
            from vapid_key as key
            join vapid_signing_key as signing
            "#
        )
        .map(|row| {
            let key = SigningKey::from_pkcs8_pem(&row.key)?;
            let key = key.verifying_key().to_owned();

            let changed = Instant::new(row.changed_at, row.changed_sequence);

            Ok::<_, Error>(History { key, changed })
        })
        .fetch_one(&mut *self.0)
        .await??;

        Ok(key)
    }

    /// Builds a web-push VAPID signature builder from the stored PEM signing key.
    pub async fn signer(&mut self) -> Result<PartialVapidSignatureBuilder, Error> {
        let key = sqlx::query_scalar!(
            r#"
            select key
            from vapid_signing_key
            "#
        )
        .fetch_one(&mut *self.0)
        .await?;
        let key = Cursor::new(&key);
        let signer = VapidSignatureBuilder::from_pem_no_sub(key)?;

        Ok(signer)
    }
}
+
/// Errors the VAPID repository can produce: key decoding/encoding failures as well as database
/// errors, transparently wrapping the underlying causes.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum Error {
    Ecdsa(#[from] p256::ecdsa::Error),
    Pkcs8(#[from] p256::pkcs8::Error),
    WebPush(#[from] web_push::WebPushError),
    Database(#[from] sqlx::Error),
}
+
+impl<T> NotFound for Result<T, Error> {
+ type Ok = T;
+ type Error = Error;
+
+ fn optional(self) -> Result<Option<T>, Error> {
+ match self {
+ Ok(value) => Ok(Some(value)),
+ Err(Error::Database(sqlx::Error::RowNotFound)) => Ok(None),
+ Err(other) => Err(other),
+ }
+ }
+}
diff --git a/src/vapid/ser.rs b/src/vapid/ser.rs
new file mode 100644
index 0000000..02c77e1
--- /dev/null
+++ b/src/vapid/ser.rs
@@ -0,0 +1,63 @@
+pub mod key {
+ use std::fmt;
+
+ use base64::{Engine as _, engine::general_purpose::URL_SAFE};
+ use p256::ecdsa::VerifyingKey;
+ use serde::{Deserializer, Serialize as _, de};
+
+ // This serialization - to a URL-safe base-64-encoded string and back - is based on my best
+ // understanding of RFC 8292 and the corresponding browser APIs. Particularly, it's based on
+ // section 3.2:
+ //
+ // > The "k" parameter includes an ECDSA public key [FIPS186] in uncompressed form [X9.62] that
+ // > is encoded using base64url encoding [RFC7515].
+ //
+ // <https://datatracker.ietf.org/doc/html/rfc8292#section-3.2>
+ //
+ // I believe this is also supported by MDN's explanation:
+ //
+ // > `applicationServerKey`
+ // >
+ // > A Base64-encoded string or ArrayBuffer containing an ECDSA P-256 public key that the push
+ // > server will use to authenticate your application server. If specified, all messages from
+ // > your application server must use the VAPID authentication scheme, and include a JWT signed
+ // > with the corresponding private key. This key IS NOT the same ECDH key that you use to
+ // > encrypt the data. For more information, see "Using VAPID with WebPush".
+ //
+ // <https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe#applicationserverkey>
+
+ pub fn serialize<S>(key: &VerifyingKey, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ let key = key.to_sec1_bytes();
+ let key = URL_SAFE.encode(key);
+ key.serialize(serializer)
+ }
+
+ pub fn deserialize<'de, D>(deserializer: D) -> Result<VerifyingKey, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserializer.deserialize_str(Visitor)
+ }
+
+ struct Visitor;
+ impl de::Visitor<'_> for Visitor {
+ type Value = VerifyingKey;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a string containing a VAPID key")
+ }
+
+ fn visit_str<E>(self, key: &str) -> Result<Self::Value, E>
+ where
+ E: de::Error,
+ {
+ let key = URL_SAFE.decode(key).map_err(E::custom)?;
+ let key = VerifyingKey::from_sec1_bytes(&key).map_err(E::custom)?;
+
+ Ok(key)
+ }
+ }
+}