summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/app.rs11
-rw-r--r--src/lib.rs1
-rw-r--r--src/login/app.rs2
-rw-r--r--src/push/app.rs76
-rw-r--r--src/push/handlers/mod.rs3
-rw-r--r--src/push/handlers/subscribe/mod.rs95
-rw-r--r--src/push/handlers/subscribe/test.rs236
-rw-r--r--src/push/mod.rs3
-rw-r--r--src/push/repo.rs114
-rw-r--r--src/routes.rs3
-rw-r--r--src/token/app.rs3
-rw-r--r--src/token/repo/token.rs17
-rw-r--r--src/vapid/app.rs5
-rw-r--r--src/vapid/ser.rs30
14 files changed, 596 insertions, 3 deletions
diff --git a/src/app.rs b/src/app.rs
index 2bfabbe..e24331b 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -10,6 +10,7 @@ use crate::{
invite::app::Invites,
login::app::Logins,
message::app::Messages,
+ push::app::Push,
setup::app::Setup,
token::{self, app::Tokens},
vapid::app::Vapid,
@@ -59,6 +60,10 @@ impl App {
Messages::new(self.db.clone(), self.events.clone())
}
+ pub fn push(&self) -> Push {
+ Push::new(self.db.clone())
+ }
+
pub fn setup(&self) -> Setup {
Setup::new(self.db.clone(), self.events.clone())
}
@@ -107,6 +112,12 @@ impl FromRef<App> for Messages {
}
}
+impl FromRef<App> for Push {
+ fn from_ref(app: &App) -> Self {
+ app.push()
+ }
+}
+
impl FromRef<App> for Setup {
fn from_ref(app: &App) -> Self {
app.setup()
diff --git a/src/lib.rs b/src/lib.rs
index 6b2a83c..38e6bc5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,6 +20,7 @@ mod message;
mod name;
mod normalize;
mod password;
+mod push;
mod routes;
mod setup;
#[cfg(test)]
diff --git a/src/login/app.rs b/src/login/app.rs
index a2f9636..8cc8cd0 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -6,6 +6,7 @@ use crate::{
login::{self, Login, repo::Provider as _},
name::{self, Name},
password::Password,
+ push::repo::Provider as _,
token::{Broadcaster, Event as TokenEvent, Secret, Token, repo::Provider as _},
};
@@ -76,6 +77,7 @@ impl Logins {
let mut tx = self.db.begin().await?;
tx.logins().set_password(&login, &to_hash).await?;
+ tx.push().unsubscribe_login(&login).await?;
let revoked = tx.tokens().revoke_all(&login).await?;
tx.tokens().create(&token, &secret).await?;
tx.commit().await?;
diff --git a/src/push/app.rs b/src/push/app.rs
new file mode 100644
index 0000000..358a8cc
--- /dev/null
+++ b/src/push/app.rs
@@ -0,0 +1,76 @@
+use p256::ecdsa::VerifyingKey;
+use sqlx::SqlitePool;
+use web_push::SubscriptionInfo;
+
+use super::repo::Provider as _;
+use crate::{token::extract::Identity, vapid, vapid::repo::Provider as _};
+
+pub struct Push {
+ db: SqlitePool,
+}
+
+impl Push {
+ pub const fn new(db: SqlitePool) -> Self {
+ Self { db }
+ }
+
+ pub async fn subscribe(
+ &self,
+ subscriber: &Identity,
+ subscription: &SubscriptionInfo,
+ vapid: &VerifyingKey,
+ ) -> Result<(), SubscribeError> {
+ let mut tx = self.db.begin().await?;
+
+ let current = tx.vapid().current().await?;
+ if vapid != &current.key {
+ return Err(SubscribeError::StaleVapidKey(current.key));
+ }
+
+ match tx.push().create(&subscriber.token, subscription).await {
+ Ok(()) => (),
+ Err(err) => {
+ if let Some(err) = err.as_database_error()
+ && err.is_unique_violation()
+ {
+ let current = tx
+ .push()
+ .by_endpoint(&subscriber.login, &subscription.endpoint)
+ .await?;
+ // If we already have a subscription for this endpoint, with _different_
+ // parameters, then this is a client error. They shouldn't reuse endpoint URLs,
+ // per the various RFCs.
+ //
+ // However, if we have a subscription for this endpoint with the same parameters
+ // then we accept it and silently do nothing. This may happen if, for example,
+ // the subscribe request is retried due to a network interruption where it's
+ // not clear whether the original request succeeded.
+ if &current != subscription {
+ return Err(SubscribeError::Duplicate);
+ }
+ } else {
+ return Err(SubscribeError::Database(err));
+ }
+ }
+ }
+
+ tx.commit().await?;
+
+ Ok(())
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SubscribeError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Vapid(#[from] vapid::repo::Error),
+ #[error("subscription created with stale VAPID key")]
+ StaleVapidKey(VerifyingKey),
+ #[error("subscription already exists for endpoint")]
+ // The endpoint URL is not included in the error, as it is a bearer credential in its own right
+ // and we want to limit its proliferation. The only intended recipient of this message is the
+ // client, which already knows the endpoint anyways and doesn't need us to tell them.
+ Duplicate,
+}
diff --git a/src/push/handlers/mod.rs b/src/push/handlers/mod.rs
new file mode 100644
index 0000000..86eeea0
--- /dev/null
+++ b/src/push/handlers/mod.rs
@@ -0,0 +1,3 @@
+mod subscribe;
+
+pub use subscribe::handler as subscribe;
diff --git a/src/push/handlers/subscribe/mod.rs b/src/push/handlers/subscribe/mod.rs
new file mode 100644
index 0000000..d142df6
--- /dev/null
+++ b/src/push/handlers/subscribe/mod.rs
@@ -0,0 +1,95 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+ response::{IntoResponse, Response},
+};
+use p256::ecdsa::VerifyingKey;
+use web_push::SubscriptionInfo;
+
+use crate::{
+ error::Internal,
+ push::{app, app::Push},
+ token::extract::Identity,
+};
+
+#[cfg(test)]
+mod test;
+
+#[derive(Clone, serde::Deserialize)]
+pub struct Request {
+ subscription: Subscription,
+ #[serde(with = "crate::vapid::ser::key")]
+ vapid: VerifyingKey,
+}
+
+// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>.
+#[derive(Clone, serde::Deserialize)]
+pub struct Subscription {
+ endpoint: String,
+ keys: Keys,
+}
+
+// This structure is described in <https://w3c.github.io/push-api/#dom-pushsubscription-tojson>.
+#[derive(Clone, serde::Deserialize)]
+pub struct Keys {
+ p256dh: String,
+ auth: String,
+}
+
+pub async fn handler(
+ State(push): State<Push>,
+ identity: Identity,
+ Json(request): Json<Request>,
+) -> Result<StatusCode, Error> {
+ let Request {
+ subscription,
+ vapid,
+ } = request;
+
+ push.subscribe(&identity, &subscription.into(), &vapid)
+ .await?;
+
+ Ok(StatusCode::CREATED)
+}
+
+impl From<Subscription> for SubscriptionInfo {
+ fn from(request: Subscription) -> Self {
+ let Subscription {
+ endpoint,
+ keys: Keys { p256dh, auth },
+ } = request;
+ let info = SubscriptionInfo::new(endpoint, p256dh, auth);
+ info
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub struct Error(#[from] app::SubscribeError);
+
+impl IntoResponse for Error {
+ fn into_response(self) -> Response {
+ let Self(err) = self;
+
+ match err {
+ app::SubscribeError::StaleVapidKey(key) => {
+ let body = StaleVapidKey {
+ message: err.to_string(),
+ key,
+ };
+ (StatusCode::BAD_REQUEST, Json(body)).into_response()
+ }
+ app::SubscribeError::Duplicate => {
+ (StatusCode::CONFLICT, err.to_string()).into_response()
+ }
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
+
+#[derive(serde::Serialize)]
+struct StaleVapidKey {
+ message: String,
+ #[serde(with = "crate::vapid::ser::key")]
+ key: VerifyingKey,
+}
diff --git a/src/push/handlers/subscribe/test.rs b/src/push/handlers/subscribe/test.rs
new file mode 100644
index 0000000..b72624d
--- /dev/null
+++ b/src/push/handlers/subscribe/test.rs
@@ -0,0 +1,236 @@
+use axum::{
+ extract::{Json, State},
+ http::StatusCode,
+};
+
+use crate::{
+ push::app::SubscribeError,
+ test::{fixtures, fixtures::event},
+};
+
+#[tokio::test]
+async fn accepts_new_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Check that the response looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+}
+
+#[tokio::test]
+async fn accepts_repeat_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber.clone(), Json(request.clone()))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Check that the response looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+
+ // Repeat the request
+
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect("test request will succeed twice on a fresh app");
+
+ // Check that the second response also looks as expected.
+
+ assert_eq!(StatusCode::CREATED, response);
+}
+
+#[tokio::test]
+async fn rejects_duplicate_subscription() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with that key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ super::handler(State(app.push()), subscriber.clone(), Json(request))
+ .await
+ .expect("test request will succeed on a fresh app");
+
+ // Repeat the request with different keys
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("different-test-p256dh-value"),
+ auth: String::from("different-test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect_err("request with duplicate endpoint should fail");
+
+ // Make sure we got the error we expected.
+
+ assert!(matches!(response, super::Error(SubscribeError::Duplicate)));
+}
+
+#[tokio::test]
+async fn rejects_stale_vapid_key() {
+ let app = fixtures::scratch_app().await;
+ let subscriber = fixtures::identity::create(&app, &fixtures::now()).await;
+
+ // Issue a VAPID key.
+
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what that VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Change the VAPID key.
+
+ app.vapid()
+ .rotate_key()
+ .await
+ .expect("key rotation always succeeds");
+ app.vapid()
+ .refresh_key(&fixtures::now())
+ .await
+ .expect("refreshing the VAPID key always succeeds");
+
+ // Find out what the new VAPID key is.
+
+ let boot = app.boot().snapshot().await.expect("boot always succeeds");
+ let fresh_vapid = boot
+ .events
+ .into_iter()
+ .filter_map(event::vapid)
+ .filter_map(event::vapid::changed)
+ .next_back()
+ .expect("the application will have a vapid key after a refresh");
+
+ // Create a dummy subscription with the original key.
+
+ let request = super::Request {
+ subscription: super::Subscription {
+ endpoint: String::from("https://push.example.com/endpoint"),
+ keys: super::Keys {
+ p256dh: String::from("test-p256dh-value"),
+ auth: String::from("test-auth-value"),
+ },
+ },
+ vapid: vapid.key,
+ };
+ let response = super::handler(State(app.push()), subscriber, Json(request))
+ .await
+ .expect_err("test request has a stale vapid key");
+
+ // Check that the response looks as expected.
+
+ assert!(matches!(
+ response,
+ super::Error(SubscribeError::StaleVapidKey(key)) if key == fresh_vapid.key
+ ));
+}
diff --git a/src/push/mod.rs b/src/push/mod.rs
new file mode 100644
index 0000000..1394ea4
--- /dev/null
+++ b/src/push/mod.rs
@@ -0,0 +1,3 @@
+pub mod app;
+pub mod handlers;
+pub mod repo;
diff --git a/src/push/repo.rs b/src/push/repo.rs
new file mode 100644
index 0000000..6c18c6e
--- /dev/null
+++ b/src/push/repo.rs
@@ -0,0 +1,114 @@
+use sqlx::{Sqlite, SqliteConnection, Transaction};
+use web_push::SubscriptionInfo;
+
+use crate::{login::Login, token::Token};
+
+pub trait Provider {
+ fn push(&mut self) -> Push<'_>;
+}
+
+impl Provider for Transaction<'_, Sqlite> {
+ fn push(&mut self) -> Push<'_> {
+ Push(self)
+ }
+}
+
+pub struct Push<'t>(&'t mut SqliteConnection);
+
+impl Push<'_> {
+ pub async fn create(
+ &mut self,
+ token: &Token,
+ subscription: &SubscriptionInfo,
+ ) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ insert into push_subscription (token, endpoint, p256dh, auth)
+ values ($1, $2, $3, $4)
+ "#,
+ token.id,
+ subscription.endpoint,
+ subscription.keys.p256dh,
+ subscription.keys.auth,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn by_endpoint(
+ &mut self,
+ subscriber: &Login,
+ endpoint: &str,
+ ) -> Result<SubscriptionInfo, sqlx::Error> {
+ let row = sqlx::query!(
+ r#"
+ select
+ subscription.endpoint,
+ subscription.p256dh,
+ subscription.auth
+ from push_subscription as subscription
+ join token on subscription.token = token.id
+ join login as subscriber on token.login = subscriber.id
+ where subscriber.id = $1
+ and subscription.endpoint = $2
+ "#,
+ subscriber.id,
+ endpoint,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ let info = SubscriptionInfo::new(row.endpoint, row.p256dh, row.auth);
+
+ Ok(info)
+ }
+
+ pub async fn unsubscribe_token(&mut self, token: &Token) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ delete from push_subscription
+ where token = $1
+ "#,
+ token.id,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn unsubscribe_login(&mut self, login: &Login) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ with tokens as (
+ select id from token
+ where login = $1
+ )
+ delete from push_subscription
+ where token in tokens
+ "#,
+ login.id,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ // Unsubscribe logic for token expiry lives in the `tokens` repository, for maintenance reasons.
+
+ pub async fn clear(&mut self) -> Result<(), sqlx::Error> {
+ // We assume that _all_ stored subscriptions are for a VAPID key we're about to delete.
+ sqlx::query!(
+ r#"
+ delete from push_subscription
+ "#,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+}
diff --git a/src/routes.rs b/src/routes.rs
index 2979abe..00d9d3e 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -5,7 +5,7 @@ use axum::{
};
use crate::{
- app::App, boot, conversation, event, expire, invite, login, message, setup, ui, vapid,
+ app::App, boot, conversation, event, expire, invite, login, message, push, setup, ui, vapid,
};
pub fn routes(app: &App) -> Router<App> {
@@ -46,6 +46,7 @@ pub fn routes(app: &App) -> Router<App> {
.route("/api/invite/{invite}", get(invite::handlers::get))
.route("/api/invite/{invite}", post(invite::handlers::accept))
.route("/api/messages/{message}", delete(message::handlers::delete))
+ .route("/api/push/subscribe", post(push::handlers::subscribe))
.route("/api/password", post(login::handlers::change_password))
// Run expiry whenever someone accesses the API. This was previously a blanket middleware
// affecting the whole service, but loading the client makes a several requests before the
diff --git a/src/token/app.rs b/src/token/app.rs
index 332473d..4a08877 100644
--- a/src/token/app.rs
+++ b/src/token/app.rs
@@ -10,7 +10,7 @@ use super::{
extract::Identity,
repo::{self, Provider as _},
};
-use crate::{clock::DateTime, db::NotFound as _, name};
+use crate::{clock::DateTime, db::NotFound as _, name, push::repo::Provider as _};
pub struct Tokens {
db: SqlitePool,
@@ -112,6 +112,7 @@ impl Tokens {
pub async fn logout(&self, token: &Token) -> Result<(), ValidateError> {
let mut tx = self.db.begin().await?;
+ tx.push().unsubscribe_token(token).await?;
tx.tokens().revoke(token).await?;
tx.commit().await?;
diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs
index 52a3987..33c33af 100644
--- a/src/token/repo/token.rs
+++ b/src/token/repo/token.rs
@@ -89,6 +89,23 @@ impl Tokens<'_> {
// Expire and delete all tokens that haven't been used more recently than
// `expire_at`.
pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ // This lives here, rather than in the `push` repository, to ensure that the criteria for
+ // stale tokens don't drift apart between the two queries. That would be a larger risk if
+ // the queries lived in very separate parts of the codebase.
+ sqlx::query!(
+ r#"
+ with stale_tokens as (
+ select id from token
+ where last_used_at < $1
+ )
+ delete from push_subscription
+ where token in stale_tokens
+ "#,
+ expire_at,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
let tokens = sqlx::query_scalar!(
r#"
delete
diff --git a/src/vapid/app.rs b/src/vapid/app.rs
index 61523d5..7d872ed 100644
--- a/src/vapid/app.rs
+++ b/src/vapid/app.rs
@@ -6,6 +6,7 @@ use crate::{
clock::DateTime,
db::NotFound as _,
event::{Broadcaster, Sequence, repo::Provider},
+ push::repo::Provider as _,
};
pub struct Vapid {
@@ -60,6 +61,10 @@ impl Vapid {
let changed_at = tx.sequence().next(ensure_at).await?;
let (key, secret) = key.rotate(&changed_at);
+ // This will delete _all_ stored subscriptions. This is fine; they're all for the
+ // current VAPID key, and we won't be able to use them anyways once the key is rotated.
+ // We have no way to inform the push broker services of that, unfortunately.
+ tx.push().clear().await?;
tx.vapid().clear().await?;
tx.vapid().store_signing_key(&secret).await?;
diff --git a/src/vapid/ser.rs b/src/vapid/ser.rs
index f5372c8..02c77e1 100644
--- a/src/vapid/ser.rs
+++ b/src/vapid/ser.rs
@@ -1,7 +1,9 @@
pub mod key {
+ use std::fmt;
+
use base64::{Engine as _, engine::general_purpose::URL_SAFE};
use p256::ecdsa::VerifyingKey;
- use serde::Serialize as _;
+ use serde::{Deserializer, Serialize as _, de};
// This serialization - to a URL-safe base-64-encoded string and back - is based on my best
// understanding of RFC 8292 and the corresponding browser APIs. Particularly, it's based on
@@ -32,4 +34,30 @@ pub mod key {
let key = URL_SAFE.encode(key);
key.serialize(serializer)
}
+
+ pub fn deserialize<'de, D>(deserializer: D) -> Result<VerifyingKey, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserializer.deserialize_str(Visitor)
+ }
+
+ struct Visitor;
+ impl de::Visitor<'_> for Visitor {
+ type Value = VerifyingKey;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a string containing a VAPID key")
+ }
+
+ fn visit_str<E>(self, key: &str) -> Result<Self::Value, E>
+ where
+ E: de::Error,
+ {
+ let key = URL_SAFE.decode(key).map_err(E::custom)?;
+ let key = VerifyingKey::from_sec1_bytes(&key).map_err(E::custom)?;
+
+ Ok(key)
+ }
+ }
}