summaryrefslogtreecommitdiff
path: root/src/vapid/repo.rs
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2025-11-08 16:28:10 -0500
committerOwen Jacobson <owen@grimoire.ca>2025-11-08 16:28:10 -0500
commitfc6914831743f6d683c59adb367479defe6f8b3a (patch)
tree5b997adac55f47b52f30022013b8ec3b2c10bcc5 /src/vapid/repo.rs
parent0ef69c7d256380e660edc45ace7f1d6151226340 (diff)
parent6bab5b4405c9adafb2ce76540595a62eea80acc0 (diff)
Integrate the prototype push notification support.
We're going to move forwards with this for now, as low-utility as it is, so that we can more easily iterate on it in a real-world environment (hi.grimoire.ca).
Diffstat (limited to 'src/vapid/repo.rs')
-rw-r--r--src/vapid/repo.rs161
1 files changed, 161 insertions, 0 deletions
diff --git a/src/vapid/repo.rs b/src/vapid/repo.rs
new file mode 100644
index 0000000..9db61e1
--- /dev/null
+++ b/src/vapid/repo.rs
@@ -0,0 +1,161 @@
+use std::io::Cursor;
+
+use p256::{
+ ecdsa::SigningKey,
+ pkcs8::{DecodePrivateKey as _, EncodePrivateKey as _, LineEnding},
+};
+use sqlx::{Sqlite, SqliteConnection, Transaction};
+use web_push::{PartialVapidSignatureBuilder, VapidSignatureBuilder};
+
+use super::{
+ History,
+ event::{Changed, Event},
+};
+use crate::{
+ clock::DateTime,
+ db::NotFound,
+ event::{Instant, Sequence},
+};
+
+pub trait Provider {
+ fn vapid(&mut self) -> Vapid<'_>;
+}
+
+impl Provider for Transaction<'_, Sqlite> {
+ fn vapid(&mut self) -> Vapid<'_> {
+ Vapid(self)
+ }
+}
+
+pub struct Vapid<'a>(&'a mut SqliteConnection);
+
+impl Vapid<'_> {
+ pub async fn record_events(
+ &mut self,
+ events: impl IntoIterator<Item = Event>,
+ ) -> Result<(), sqlx::Error> {
+ for event in events {
+ self.record_event(&event).await?;
+ }
+ Ok(())
+ }
+
+ pub async fn record_event(&mut self, event: &Event) -> Result<(), sqlx::Error> {
+ match event {
+ Event::Changed(changed) => self.record_changed(changed).await,
+ }
+ }
+
+ async fn record_changed(&mut self, changed: &Changed) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ insert into vapid_key (changed_at, changed_sequence)
+ values ($1, $2)
+ "#,
+ changed.instant.at,
+ changed.instant.sequence,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn clear(&mut self) -> Result<(), sqlx::Error> {
+ sqlx::query!(
+ r#"
+ delete from vapid_key
+ "#
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ sqlx::query!(
+ r#"
+ delete from vapid_signing_key
+ "#
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn store_signing_key(&mut self, key: &SigningKey) -> Result<(), Error> {
+ let key = key.to_pkcs8_pem(LineEnding::CRLF)?;
+ let key = key.as_str();
+ sqlx::query!(
+ r#"
+ insert into vapid_signing_key (key)
+ values ($1)
+ "#,
+ key,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
+ pub async fn current(&mut self) -> Result<History, Error> {
+ let key = sqlx::query!(
+ r#"
+ select
+ key.changed_at as "changed_at: DateTime",
+ key.changed_sequence as "changed_sequence: Sequence",
+ signing.key
+ from vapid_key as key
+ join vapid_signing_key as signing
+ "#
+ )
+ .map(|row| {
+ let key = SigningKey::from_pkcs8_pem(&row.key)?;
+ let key = key.verifying_key().to_owned();
+
+ let changed = Instant::new(row.changed_at, row.changed_sequence);
+
+ Ok::<_, Error>(History { key, changed })
+ })
+ .fetch_one(&mut *self.0)
+ .await??;
+
+ Ok(key)
+ }
+
+ pub async fn signer(&mut self) -> Result<PartialVapidSignatureBuilder, Error> {
+ let key = sqlx::query_scalar!(
+ r#"
+ select key
+ from vapid_signing_key
+ "#
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+ let key = Cursor::new(&key);
+ let signer = VapidSignatureBuilder::from_pem_no_sub(key)?;
+
+ Ok(signer)
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum Error {
+ Ecdsa(#[from] p256::ecdsa::Error),
+ Pkcs8(#[from] p256::pkcs8::Error),
+ WebPush(#[from] web_push::WebPushError),
+ Database(#[from] sqlx::Error),
+}
+
+impl<T> NotFound for Result<T, Error> {
+ type Ok = T;
+ type Error = Error;
+
+ fn optional(self) -> Result<Option<T>, Error> {
+ match self {
+ Ok(value) => Ok(Some(value)),
+ Err(Error::Database(sqlx::Error::RowNotFound)) => Ok(None),
+ Err(other) => Err(other),
+ }
+ }
+}