diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/app.rs | 61 | ||||
| -rw-r--r-- | src/boot/app.rs | 2 | ||||
| -rw-r--r-- | src/cli.rs | 10 | ||||
| -rw-r--r-- | src/event/app.rs | 2 | ||||
| -rw-r--r-- | src/event/handlers/stream/mod.rs | 4 | ||||
| -rw-r--r-- | src/expire.rs | 4 | ||||
| -rw-r--r-- | src/push/app.rs | 114 | ||||
| -rw-r--r-- | src/push/handlers/mod.rs | 2 | ||||
| -rw-r--r-- | src/push/handlers/ping/mod.rs | 23 | ||||
| -rw-r--r-- | src/push/handlers/ping/test.rs | 40 | ||||
| -rw-r--r-- | src/push/handlers/subscribe/mod.rs | 7 | ||||
| -rw-r--r-- | src/push/repo.rs | 35 | ||||
| -rw-r--r-- | src/routes.rs | 7 | ||||
| -rw-r--r-- | src/test/fixtures/identity.rs | 2 | ||||
| -rw-r--r-- | src/test/fixtures/login.rs | 2 | ||||
| -rw-r--r-- | src/test/fixtures/mod.rs | 6 | ||||
| -rw-r--r-- | src/test/fixtures/user.rs | 4 | ||||
| -rw-r--r-- | src/test/mod.rs | 1 | ||||
| -rw-r--r-- | src/test/webpush.rs | 37 | ||||
| -rw-r--r-- | src/vapid/app.rs | 2 | ||||
| -rw-r--r-- | src/vapid/repo.rs | 19 |
21 files changed, 336 insertions, 48 deletions
@@ -17,25 +17,27 @@ use crate::{ }; #[derive(Clone)] -pub struct App { +pub struct App<P> { db: SqlitePool, + webpush: P, events: event::Broadcaster, token_events: token::Broadcaster, } -impl App { - pub fn from(db: SqlitePool) -> Self { +impl<P> App<P> { + pub fn from(db: SqlitePool, webpush: P) -> Self { let events = event::Broadcaster::default(); let token_events = token::Broadcaster::default(); Self { db, + webpush, events, token_events, } } } -impl App { +impl<P> App<P> { pub fn boot(&self) -> Boot { Boot::new(self.db.clone()) } @@ -60,8 +62,11 @@ impl App { Messages::new(self.db.clone(), self.events.clone()) } - pub fn push(&self) -> Push { - Push::new(self.db.clone()) + pub fn push(&self) -> Push<P> + where + P: Clone, + { + Push::new(self.db.clone(), self.webpush.clone()) } pub fn setup(&self) -> Setup { @@ -80,58 +85,66 @@ impl App { pub fn vapid(&self) -> Vapid { Vapid::new(self.db.clone(), self.events.clone()) } + + #[cfg(test)] + pub fn webpush(&self) -> &P { + &self.webpush + } } -impl FromRef<App> for Boot { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Boot { + fn from_ref(app: &App<P>) -> Self { app.boot() } } -impl FromRef<App> for Conversations { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Conversations { + fn from_ref(app: &App<P>) -> Self { app.conversations() } } -impl FromRef<App> for Invites { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Invites { + fn from_ref(app: &App<P>) -> Self { app.invites() } } -impl FromRef<App> for Logins { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Logins { + fn from_ref(app: &App<P>) -> Self { app.logins() } } -impl FromRef<App> for Messages { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Messages { + fn from_ref(app: &App<P>) -> Self { app.messages() } } -impl FromRef<App> for Push { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Push<P> +where + P: Clone, +{ + fn from_ref(app: &App<P>) -> Self { app.push() } } -impl FromRef<App> for Setup { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Setup { + fn from_ref(app: &App<P>) -> Self { app.setup() } } -impl FromRef<App> for Tokens { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Tokens { + fn from_ref(app: &App<P>) -> Self { app.tokens() } } -impl FromRef<App> for Vapid { - fn from_ref(app: &App) -> Self { +impl<P> FromRef<App<P>> for Vapid { + fn from_ref(app: &App<P>) -> Self { app.vapid() } } diff --git a/src/boot/app.rs b/src/boot/app.rs index 88255b0..1ca8adb 100644 --- a/src/boot/app.rs +++ b/src/boot/app.rs @@ -79,6 +79,7 @@ pub enum Error { Name(#[from] name::Error), Ecdsa(#[from] p256::ecdsa::Error), Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), } impl From<user::repo::LoadError> for Error { @@ -108,6 +109,7 @@ impl From<vapid::repo::Error> for Error { Error::Database(error) => error.into(), Error::Ecdsa(error) => error.into(), Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), } } } @@ -13,6 +13,7 @@ use axum::{ use clap::{CommandFactory, Parser, Subcommand}; use sqlx::sqlite::SqlitePool; use tokio::net; +use web_push::{IsahcWebPushClient, WebPushClient}; use crate::{ app::App, @@ -97,7 +98,8 @@ impl Args { self.umask.set(); let pool = self.pool().await?; - let app = App::from(pool); + let webpush = IsahcWebPushClient::new()?; + let app = App::from(pool, webpush); match self.command { None => self.serve(app).await?, @@ -107,7 +109,10 @@ impl Args { Result::<_, Error>::Ok(()) } - async fn serve(self, app: App) -> Result<(), Error> { + async fn serve<P>(self, app: App<P>) -> Result<(), Error> + where + P: WebPushClient + Clone + Send + Sync + 'static, + { let app = routes::routes(&app) .route_layer(middleware::from_fn(clock::middleware)) .route_layer(middleware::map_response(Self::server_info())) @@ -161,4 +166,5 @@ enum Error { Database(#[from] db::Error), Sqlx(#[from] sqlx::Error), Umask(#[from] umask::Error), + Webpush(#[from] web_push::WebPushError), } diff --git a/src/event/app.rs b/src/event/app.rs index 1e471f1..e422de9 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -99,6 +99,7 @@ pub enum Error { Name(#[from] name::Error), Ecdsa(#[from] p256::ecdsa::Error), Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), } impl From<user::repo::LoadError> for Error { @@ -128,6 +129,7 @@ impl From<vapid::repo::Error> for Error { Error::Database(error) => error.into(), Error::Ecdsa(error) => error.into(), Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), } } } diff --git a/src/event/handlers/stream/mod.rs b/src/event/handlers/stream/mod.rs index 63bfff3..8b89c31 100644 --- a/src/event/handlers/stream/mod.rs +++ b/src/event/handlers/stream/mod.rs @@ -18,8 +18,8 @@ use crate::{ #[cfg(test)] mod test; -pub async fn handler( - State(app): State<App>, +pub async fn handler<P>( + State(app): State<App<P>>, identity: Identity, last_event_id: Option<LastEventId<Sequence>>, Query(query): Query<QueryParams>, diff --git a/src/expire.rs b/src/expire.rs index 4177a53..c3b0117 100644 --- a/src/expire.rs +++ b/src/expire.rs @@ -7,8 +7,8 @@ use axum::{ use crate::{app::App, clock::RequestedAt, error::Internal}; // Expires messages and conversations before each request. -pub async fn middleware( - State(app): State<App>, +pub async fn middleware<P>( + State(app): State<App<P>>, RequestedAt(expired_at): RequestedAt, req: Request, next: Next, diff --git a/src/push/app.rs b/src/push/app.rs index 358a8cc..56b9a02 100644 --- a/src/push/app.rs +++ b/src/push/app.rs @@ -1,17 +1,23 @@ +use futures::future::join_all; +use itertools::Itertools as _; use p256::ecdsa::VerifyingKey; use sqlx::SqlitePool; -use web_push::SubscriptionInfo; +use web_push::{ + ContentEncoding, PartialVapidSignatureBuilder, SubscriptionInfo, WebPushClient, WebPushError, + WebPushMessage, WebPushMessageBuilder, +}; use super::repo::Provider as _; -use crate::{token::extract::Identity, vapid, vapid::repo::Provider as _}; +use crate::{login::Login, token::extract::Identity, vapid, vapid::repo::Provider as _}; -pub struct Push { +pub struct Push<P> { db: SqlitePool, + webpush: P, } -impl Push { - pub const fn new(db: SqlitePool) -> Self { - Self { db } +impl<P> Push<P> { + pub const fn new(db: SqlitePool, webpush: P) -> Self { + Self { db, webpush } } pub async fn subscribe( @@ -60,6 +66,76 @@ impl Push { } } +impl<P> Push<P> +where + P: WebPushClient, +{ + fn prepare_ping( + signer: &PartialVapidSignatureBuilder, + subscription: &SubscriptionInfo, + ) -> Result<WebPushMessage, WebPushError> { + let signature = signer.clone().add_sub_info(subscription).build()?; + + let payload = "ping".as_bytes(); + + let mut message = WebPushMessageBuilder::new(subscription); + message.set_payload(ContentEncoding::Aes128Gcm, payload); + message.set_vapid_signature(signature); + let message = message.build()?; + + Ok(message) + } + + pub async fn ping(&self, recipient: &Login) -> Result<(), PushError> { + let mut tx = self.db.begin().await?; + + let signer = tx.vapid().signer().await?; + let subscriptions = tx.push().by_login(recipient).await?; + + let pings: Vec<_> = subscriptions + .into_iter() + .map(|sub| Self::prepare_ping(&signer, &sub).map(|message| (sub, message))) + .try_collect()?; + + let deliveries = pings + .into_iter() + .map(async |(sub, message)| (sub, self.webpush.send(message).await)); + + let failures: Vec<_> = join_all(deliveries) + .await + .into_iter() + .filter_map(|(sub, result)| result.err().map(|err| (sub, err))) + .collect(); + + if !failures.is_empty() { + for (sub, err) in &failures { + match err { + // I _think_ this is the complete set of permanent failures. See + // <https://docs.rs/web-push/latest/web_push/enum.WebPushError.html> for a complete + // list. + WebPushError::Unauthorized(_) + | WebPushError::InvalidUri + | WebPushError::EndpointNotValid(_) + | WebPushError::EndpointNotFound(_) + | WebPushError::InvalidCryptoKeys + | WebPushError::MissingCryptoKeys => { + tx.push().unsubscribe(sub).await?; + } + _ => (), + } + } + + return Err(PushError::Delivery( + failures.into_iter().map(|(_, err)| err).collect(), + )); + } + + tx.commit().await?; + + Ok(()) + } +} + #[derive(Debug, thiserror::Error)] pub enum SubscribeError { #[error(transparent)] @@ -74,3 +150,29 @@ pub enum SubscribeError { // client, which already knows the endpoint anyways and doesn't need us to tell them. Duplicate, } + +#[derive(Debug, thiserror::Error)] +pub enum PushError { + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Ecdsa(#[from] p256::ecdsa::Error), + #[error(transparent)] + Pkcs8(#[from] p256::pkcs8::Error), + #[error(transparent)] + WebPush(#[from] WebPushError), + #[error("push message delivery failures: {0:?}")] + Delivery(Vec<WebPushError>), +} + +impl From<vapid::repo::Error> for PushError { + fn from(error: vapid::repo::Error) -> Self { + use vapid::repo::Error; + match error { + Error::Database(error) => error.into(), + Error::Ecdsa(error) => error.into(), + Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), + } + } +} diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs index 86eeea0..bb58774 100644 --- a/src/push/handlers/mod.rs +++ b/src/push/handlers/mod.rs @@ -1,3 +1,5 @@ +mod ping; mod subscribe; +pub use ping::handler as ping; pub use subscribe::handler as subscribe; diff --git a/src/push/handlers/ping/mod.rs b/src/push/handlers/ping/mod.rs new file mode 100644 index 0000000..db828fa --- /dev/null +++ b/src/push/handlers/ping/mod.rs @@ -0,0 +1,23 @@ +use axum::{Json, extract::State, http::StatusCode}; +use web_push::WebPushClient; + +use crate::{error::Internal, push::app::Push, token::extract::Identity}; + +#[cfg(test)] +mod test; + +#[derive(serde::Deserialize)] +pub struct Request {} + +pub async fn handler<P>( + State(push): State<Push<P>>, + identity: Identity, + Json(_): Json<Request>, +) -> Result<StatusCode, Internal> +where + P: WebPushClient, +{ + push.ping(&identity.login).await?; + + Ok(StatusCode::ACCEPTED) +} diff --git a/src/push/handlers/ping/test.rs b/src/push/handlers/ping/test.rs new file mode 100644 index 0000000..5725131 --- /dev/null +++ b/src/push/handlers/ping/test.rs @@ -0,0 +1,40 @@ +use axum::{ + extract::{Json, State}, + http::StatusCode, +}; + +use crate::test::fixtures; + +#[tokio::test] +async fn ping_without_subscriptions() { + let app = fixtures::scratch_app().await; + + let recipient = fixtures::identity::create(&app, &fixtures::now()).await; + + app.vapid() + .refresh_key(&fixtures::now()) + .await + .expect("refreshing the VAPID key always succeeds"); + + let response = super::handler(State(app.push()), recipient, Json(super::Request {})) + .await + .expect("sending a ping with no subscriptions always succeeds"); + + assert_eq!(StatusCode::ACCEPTED, response); + + assert!(app.webpush().sent().is_empty()); +} + +// More complete testing requires that we figure out how to generate working p256 ECDH keys for +// testing _with_, as `web_push` will actually parse and use those keys even if push messages are +// ultimately never serialized or sent over HTTP. +// +// Tests that are missing: +// +// * Verify that subscribing and sending a ping causes a ping to be delivered to that subscription. +// * Verify that two subscriptions both get pings. +// * Verify that other users' subscriptions are not pinged. +// * Verify that a ping that causes a permanent error causes the subscription to be deleted. +// * Verify that a ping that causes a non-permanent error does not cause the subscription to be +// deleted. +// * Verify that a failure on one subscription doesn't affect delivery on other subscriptions. diff --git a/src/push/handlers/subscribe/mod.rs b/src/push/handlers/subscribe/mod.rs index d142df6..a1a5899 100644 --- a/src/push/handlers/subscribe/mod.rs +++ b/src/push/handlers/subscribe/mod.rs @@ -36,8 +36,8 @@ pub struct Keys { auth: String, } -pub async fn handler( - State(push): State<Push>, +pub async fn handler<P>( + State(push): State<Push<P>>, identity: Identity, Json(request): Json<Request>, ) -> Result<StatusCode, Error> { @@ -58,8 +58,7 @@ impl From<Subscription> for SubscriptionInfo { endpoint, keys: Keys { p256dh, auth }, } = request; - let info = SubscriptionInfo::new(endpoint, p256dh, auth); - info + SubscriptionInfo::new(endpoint, p256dh, auth) } } diff --git a/src/push/repo.rs b/src/push/repo.rs index 6c18c6e..4183489 100644 --- a/src/push/repo.rs +++ b/src/push/repo.rs @@ -37,6 +37,24 @@ impl Push<'_> { Ok(()) } + pub async fn by_login(&mut self, login: &Login) -> Result<Vec<SubscriptionInfo>, sqlx::Error> { + sqlx::query!( + r#" + select + subscription.endpoint, + subscription.p256dh, + subscription.auth + from push_subscription as subscription + join token on subscription.token = token.id + where token.login = $1 + "#, + login.id, + ) + .map(|row| SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth)) + .fetch_all(&mut *self.0) + .await + } + pub async fn by_endpoint( &mut self, subscriber: &Login, @@ -65,6 +83,23 @@ impl Push<'_> { Ok(info) } + pub async fn unsubscribe( + &mut self, + subscription: &SubscriptionInfo, + ) -> Result<(), sqlx::Error> { + sqlx::query!( + r#" + delete from push_subscription + where endpoint = $1 + "#, + subscription.endpoint, + ) + .execute(&mut *self.0) + .await?; + + Ok(()) + } + pub async fn unsubscribe_token(&mut self, token: &Token) -> Result<(), sqlx::Error> { sqlx::query!( r#" diff --git a/src/routes.rs b/src/routes.rs index 00d9d3e..1c07e78 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -3,12 +3,16 @@ use axum::{ response::Redirect, routing::{delete, get, post}, }; +use web_push::WebPushClient; use crate::{ app::App, boot, conversation, event, expire, invite, login, message, push, setup, ui, vapid, }; -pub fn routes(app: &App) -> Router<App> { +pub fn routes<P>(app: &App<P>) -> Router<App<P>> +where + P: WebPushClient + Clone + Send + Sync + 'static, +{ // UI routes that can be accessed before the administrator completes setup. let ui_bootstrap = Router::new() .route("/{*path}", get(ui::handlers::asset)) @@ -46,6 +50,7 @@ pub fn routes(app: &App) -> Router<App> { .route("/api/invite/{invite}", get(invite::handlers::get)) .route("/api/invite/{invite}", post(invite::handlers::accept)) .route("/api/messages/{message}", delete(message::handlers::delete)) + .route("/api/push/ping", post(push::handlers::ping)) .route("/api/push/subscribe", post(push::handlers::subscribe)) .route("/api/password", post(login::handlers::change_password)) // Run expiry whenever someone accesses the API. This was previously a blanket middleware diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 20929f9..adc3e73 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -14,7 +14,7 @@ use crate::{ }, }; -pub async fn create(app: &App, created_at: &RequestedAt) -> Identity { +pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Identity { let credentials = fixtures::user::create_with_password(app, created_at).await; logged_in(app, &credentials, created_at).await } diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index d9aca81..839a412 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -5,7 +5,7 @@ use crate::{ test::fixtures::user::{propose, propose_name}, }; -pub async fn create(app: &App, created_at: &DateTime) -> Login { +pub async fn create<P>(app: &App<P>, created_at: &DateTime) -> Login { let (name, password) = propose(); app.users() .create(&name, &password, created_at) diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index 3d69cfa..53bf31b 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -1,6 +1,6 @@ use chrono::{TimeDelta, Utc}; -use crate::{app::App, clock::RequestedAt, db}; +use crate::{app::App, clock::RequestedAt, db, test::webpush::Client}; pub mod boot; pub mod conversation; @@ -13,11 +13,11 @@ pub mod login; pub mod message; pub mod user; -pub async fn scratch_app() -> App { +pub async fn scratch_app() -> App<Client> { let pool = db::prepare("sqlite::memory:", "sqlite::memory:") .await .expect("setting up in-memory sqlite database"); - App::from(pool) + App::from(pool, Client::new()) } pub fn now() -> RequestedAt { diff --git a/src/test/fixtures/user.rs b/src/test/fixtures/user.rs index d4d8db4..3ad4436 100644 --- a/src/test/fixtures/user.rs +++ b/src/test/fixtures/user.rs @@ -3,7 +3,7 @@ use uuid::Uuid; use crate::{app::App, clock::RequestedAt, login::Login, name::Name, password::Password}; -pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name, Password) { +pub async fn create_with_password<P>(app: &App<P>, created_at: &RequestedAt) -> (Name, Password) { let (name, password) = propose(); let user = app .users() @@ -14,7 +14,7 @@ pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Name, (user.name, password) } -pub async fn create(app: &App, created_at: &RequestedAt) -> Login { +pub async fn create<P>(app: &App<P>, created_at: &RequestedAt) -> Login { super::login::create(app, created_at).await } diff --git a/src/test/mod.rs b/src/test/mod.rs index ebbbfef..f798b9c 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,2 +1,3 @@ pub mod fixtures; pub mod verify; +pub mod webpush; diff --git a/src/test/webpush.rs b/src/test/webpush.rs new file mode 100644 index 0000000..c86d03f --- /dev/null +++ b/src/test/webpush.rs @@ -0,0 +1,37 @@ +use std::{ + mem, + sync::{Arc, Mutex}, +}; + +use web_push::{WebPushClient, WebPushError, WebPushMessage}; + +#[derive(Clone)] +pub struct Client { + sent: Arc<Mutex<Vec<WebPushMessage>>>, +} + +impl Client { + pub fn new() -> Self { + Self { + sent: Arc::default(), + } + } + + // Clears the list of sent messages (for all clones of this Client) when called, because we + // can't clone `WebPushMessage`s so we either need to move them or try to reconstruct them, + // either of which sucks but moving them sucks less. + pub fn sent(&self) -> Vec<WebPushMessage> { + let mut sent = self.sent.lock().unwrap(); + mem::replace(&mut *sent, Vec::new()) + } +} + +#[async_trait::async_trait] +impl WebPushClient for Client { + async fn send(&self, message: WebPushMessage) -> Result<(), WebPushError> { + let mut sent = self.sent.lock().unwrap(); + sent.push(message); + + Ok(()) + } +} diff --git a/src/vapid/app.rs b/src/vapid/app.rs index ebd2446..9949aa5 100644 --- a/src/vapid/app.rs +++ b/src/vapid/app.rs @@ -101,6 +101,7 @@ pub enum Error { Database(#[from] sqlx::Error), Ecdsa(#[from] p256::ecdsa::Error), Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), } impl From<repo::Error> for Error { @@ -110,6 +111,7 @@ impl From<repo::Error> for Error { Error::Database(error) => error.into(), Error::Ecdsa(error) => error.into(), Error::Pkcs8(error) => error.into(), + Error::WebPush(error) => error.into(), } } } diff --git a/src/vapid/repo.rs b/src/vapid/repo.rs index 98b3bae..9db61e1 100644 --- a/src/vapid/repo.rs +++ b/src/vapid/repo.rs @@ -1,8 +1,11 @@ +use std::io::Cursor; + use p256::{ ecdsa::SigningKey, pkcs8::{DecodePrivateKey as _, EncodePrivateKey as _, LineEnding}, }; use sqlx::{Sqlite, SqliteConnection, Transaction}; +use web_push::{PartialVapidSignatureBuilder, VapidSignatureBuilder}; use super::{ History, @@ -118,6 +121,21 @@ impl Vapid<'_> { Ok(key) } + + pub async fn signer(&mut self) -> Result<PartialVapidSignatureBuilder, Error> { + let key = sqlx::query_scalar!( + r#" + select key + from vapid_signing_key + "# + ) + .fetch_one(&mut *self.0) + .await?; + let key = Cursor::new(&key); + let signer = VapidSignatureBuilder::from_pem_no_sub(key)?; + + Ok(signer) + } } #[derive(Debug, thiserror::Error)] @@ -125,6 +143,7 @@ impl Vapid<'_> { pub enum Error { Ecdsa(#[from] p256::ecdsa::Error), Pkcs8(#[from] p256::pkcs8::Error), + WebPush(#[from] web_push::WebPushError), Database(#[from] sqlx::Error), } |
