summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/app.rs28
-rw-r--r--src/broadcast.rs4
-rw-r--r--src/channel/app.rs129
-rw-r--r--src/channel/event.rs48
-rw-r--r--src/channel/history.rs42
-rw-r--r--src/channel/id.rs38
-rw-r--r--src/channel/mod.rs7
-rw-r--r--src/channel/repo.rs202
-rw-r--r--src/channel/routes.rs122
-rw-r--r--src/channel/routes/test/list.rs7
-rw-r--r--src/channel/routes/test/on_create.rs13
-rw-r--r--src/channel/routes/test/on_send.rs23
-rw-r--r--src/channel/snapshot.rs38
-rw-r--r--src/cli.rs15
-rw-r--r--src/db.rs (renamed from src/repo/pool.rs)24
-rw-r--r--src/error.rs8
-rw-r--r--src/event/app.rs72
-rw-r--r--src/event/broadcaster.rs3
-rw-r--r--src/event/extract.rs (renamed from src/events/extract.rs)0
-rw-r--r--src/event/mod.rs75
-rw-r--r--src/event/repo.rs50
-rw-r--r--src/event/routes.rs92
-rw-r--r--src/event/routes/test.rs (renamed from src/events/routes/test.rs)172
-rw-r--r--src/event/sequence.rs90
-rw-r--r--src/events/app.rs163
-rw-r--r--src/events/broadcaster.rs3
-rw-r--r--src/events/mod.rs8
-rw-r--r--src/events/repo/message.rs198
-rw-r--r--src/events/repo/mod.rs1
-rw-r--r--src/events/routes.rs69
-rw-r--r--src/events/types.rs170
-rw-r--r--src/expire.rs4
-rw-r--r--src/lib.rs7
-rw-r--r--src/login/app.rs140
-rw-r--r--src/login/broadcaster.rs3
-rw-r--r--src/login/extract.rs181
-rw-r--r--src/login/id.rs24
-rw-r--r--src/login/mod.rs23
-rw-r--r--src/login/password.rs (renamed from src/password.rs)0
-rw-r--r--src/login/repo.rs50
-rw-r--r--src/login/repo/mod.rs1
-rw-r--r--src/login/routes.rs28
-rw-r--r--src/login/routes/test/boot.rs7
-rw-r--r--src/login/routes/test/login.rs13
-rw-r--r--src/login/routes/test/logout.rs7
-rw-r--r--src/message/app.rs115
-rw-r--r--src/message/event.rs71
-rw-r--r--src/message/history.rs41
-rw-r--r--src/message/id.rs (renamed from src/repo/message.rs)6
-rw-r--r--src/message/mod.rs9
-rw-r--r--src/message/repo.rs247
-rw-r--r--src/message/routes.rs46
-rw-r--r--src/message/snapshot.rs76
-rw-r--r--src/repo/channel.rs179
-rw-r--r--src/repo/error.rs23
-rw-r--r--src/repo/login/extract.rs15
-rw-r--r--src/repo/login/mod.rs4
-rw-r--r--src/repo/login/store.rs86
-rw-r--r--src/repo/mod.rs6
-rw-r--r--src/test/fixtures/channel.rs2
-rw-r--r--src/test/fixtures/event.rs11
-rw-r--r--src/test/fixtures/filter.rs14
-rw-r--r--src/test/fixtures/identity.rs13
-rw-r--r--src/test/fixtures/login.rs3
-rw-r--r--src/test/fixtures/message.rs18
-rw-r--r--src/test/fixtures/mod.rs5
-rw-r--r--src/token/app.rs170
-rw-r--r--src/token/broadcaster.rs4
-rw-r--r--src/token/event.rs (renamed from src/login/types.rs)2
-rw-r--r--src/token/extract/identity.rs75
-rw-r--r--src/token/extract/identity_token.rs94
-rw-r--r--src/token/extract/mod.rs4
-rw-r--r--src/token/id.rs27
-rw-r--r--src/token/mod.rs9
-rw-r--r--src/token/repo/auth.rs (renamed from src/login/repo/auth.rs)5
-rw-r--r--src/token/repo/mod.rs4
-rw-r--r--src/token/repo/token.rs (renamed from src/repo/token.rs)54
-rw-r--r--src/token/secret.rs27
78 files changed, 2384 insertions, 1483 deletions
diff --git a/src/app.rs b/src/app.rs
index c13f52f..186e5f8 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -2,35 +2,45 @@ use sqlx::sqlite::SqlitePool;
use crate::{
channel::app::Channels,
- events::{app::Events, broadcaster::Broadcaster as EventBroadcaster},
- login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster},
+ event::{app::Events, broadcaster::Broadcaster as EventBroadcaster},
+ login::app::Logins,
+ message::app::Messages,
+ token::{app::Tokens, broadcaster::Broadcaster as TokenBroadcaster},
};
#[derive(Clone)]
pub struct App {
db: SqlitePool,
events: EventBroadcaster,
- logins: LoginBroadcaster,
+ tokens: TokenBroadcaster,
}
impl App {
pub fn from(db: SqlitePool) -> Self {
let events = EventBroadcaster::default();
- let logins = LoginBroadcaster::default();
- Self { db, events, logins }
+ let tokens = TokenBroadcaster::default();
+ Self { db, events, tokens }
}
}
impl App {
- pub const fn logins(&self) -> Logins {
- Logins::new(&self.db, &self.logins)
+ pub const fn channels(&self) -> Channels {
+ Channels::new(&self.db, &self.events)
}
pub const fn events(&self) -> Events {
Events::new(&self.db, &self.events)
}
- pub const fn channels(&self) -> Channels {
- Channels::new(&self.db, &self.events)
+ pub const fn logins(&self) -> Logins {
+ Logins::new(&self.db)
+ }
+
+ pub const fn messages(&self) -> Messages {
+ Messages::new(&self.db, &self.events)
+ }
+
+ pub const fn tokens(&self) -> Tokens {
+ Tokens::new(&self.db, &self.tokens)
}
}
diff --git a/src/broadcast.rs b/src/broadcast.rs
index 083a301..bedc263 100644
--- a/src/broadcast.rs
+++ b/src/broadcast.rs
@@ -32,7 +32,7 @@ where
{
// panic: if ``message.channel.id`` has not been previously registered,
// and was not part of the initial set of channels.
- pub fn broadcast(&self, message: &M) {
+ pub fn broadcast(&self, message: impl Into<M>) {
let tx = self.sender();
// Per the Tokio docs, the returned error is only used to indicate that
@@ -42,7 +42,7 @@ where
//
// The successful return value, which includes the number of active
// receivers, also isn't that interesting to us.
- let _ = tx.send(message.clone());
+ let _ = tx.send(message.into());
}
// panic: if ``channel`` has not been previously registered, and was not
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 70cda47..bb331ec 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -1,10 +1,13 @@
use chrono::TimeDelta;
+use itertools::Itertools;
use sqlx::sqlite::SqlitePool;
+use super::{repo::Provider as _, Channel, Id};
use crate::{
clock::DateTime,
- events::{broadcaster::Broadcaster, repo::message::Provider as _, types::ChannelEvent},
- repo::channel::{Channel, Provider as _},
+ db::NotFound,
+ event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence},
+ message::{repo::Provider as _, Message},
};
pub struct Channels<'a> {
@@ -19,27 +22,108 @@ impl<'a> Channels<'a> {
pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> {
let mut tx = self.db.begin().await?;
+ let created = tx.sequence().next(created_at).await?;
let channel = tx
.channels()
- .create(name, created_at)
+ .create(name, &created)
.await
.map_err(|err| CreateError::from_duplicate_name(err, name))?;
tx.commit().await?;
self.events
- .broadcast(&ChannelEvent::created(channel.clone()));
+ .broadcast(channel.events().map(Event::from).collect::<Vec<_>>());
- Ok(channel)
+ Ok(channel.snapshot())
}
- pub async fn all(&self) -> Result<Vec<Channel>, InternalError> {
+ pub async fn all(&self, resume_point: Option<Sequence>) -> Result<Vec<Channel>, InternalError> {
let mut tx = self.db.begin().await?;
- let channels = tx.channels().all().await?;
+ let channels = tx.channels().all(resume_point).await?;
tx.commit().await?;
+ let channels = channels
+ .into_iter()
+ .filter_map(|channel| {
+ channel
+ .events()
+ .filter(Sequence::up_to(resume_point))
+ .collect()
+ })
+ .collect();
+
Ok(channels)
}
+ pub async fn messages(
+ &self,
+ channel: &Id,
+ resume_point: Option<Sequence>,
+ ) -> Result<Vec<Message>, Error> {
+ let mut tx = self.db.begin().await?;
+ let channel = tx
+ .channels()
+ .by_id(channel)
+ .await
+ .not_found(|| Error::NotFound(channel.clone()))?
+ .snapshot();
+
+ let messages = tx
+ .messages()
+ .in_channel(&channel, resume_point)
+ .await?
+ .into_iter()
+ .filter_map(|message| {
+ message
+ .events()
+ .filter(Sequence::up_to(resume_point))
+ .collect()
+ })
+ .collect();
+
+ Ok(messages)
+ }
+
+ pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), Error> {
+ let mut tx = self.db.begin().await?;
+
+ let channel = tx
+ .channels()
+ .by_id(channel)
+ .await
+ .not_found(|| Error::NotFound(channel.clone()))?
+ .snapshot();
+
+ let mut events = Vec::new();
+
+ let messages = tx.messages().in_channel(&channel, None).await?;
+ for message in messages {
+ let message = message.snapshot();
+ let deleted = tx.sequence().next(deleted_at).await?;
+ let message = tx.messages().delete(&message.id, &deleted).await?;
+ events.extend(
+ message
+ .events()
+ .filter(Sequence::start_from(deleted.sequence))
+ .map(Event::from),
+ );
+ }
+
+ let deleted = tx.sequence().next(deleted_at).await?;
+ let channel = tx.channels().delete(&channel.id, &deleted).await?;
+ events.extend(
+ channel
+ .events()
+ .filter(Sequence::start_from(deleted.sequence))
+ .map(Event::from),
+ );
+
+ tx.commit().await?;
+
+ self.events.broadcast(events);
+
+ Ok(())
+ }
+
pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
// Somewhat arbitrarily, expire after 90 days.
let expire_at = relative_to.to_owned() - TimeDelta::days(90);
@@ -49,19 +133,24 @@ impl<'a> Channels<'a> {
let mut events = Vec::with_capacity(expired.len());
for channel in expired {
- let sequence = tx.message_events().assign_sequence(&channel).await?;
- let event = tx
- .channels()
- .delete_expired(&channel, sequence, relative_to)
- .await?;
- events.push(event);
+ let deleted = tx.sequence().next(relative_to).await?;
+ let channel = tx.channels().delete(&channel, &deleted).await?;
+ events.push(
+ channel
+ .events()
+ .filter(Sequence::start_from(deleted.sequence)),
+ );
}
tx.commit().await?;
- for event in events {
- self.events.broadcast(&event);
- }
+ self.events.broadcast(
+ events
+ .into_iter()
+ .kmerge_by(Sequence::merge)
+ .map(Event::from)
+ .collect::<Vec<_>>(),
+ );
Ok(())
}
@@ -75,6 +164,14 @@ pub enum CreateError {
DatabaseError(#[from] sqlx::Error),
}
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("channel {0} not found")]
+ NotFound(Id),
+ #[error(transparent)]
+ DatabaseError(#[from] sqlx::Error),
+}
+
impl CreateError {
fn from_duplicate_name(error: sqlx::Error, name: &str) -> Self {
if let Some(error) = error.as_database_error() {
diff --git a/src/channel/event.rs b/src/channel/event.rs
new file mode 100644
index 0000000..9c54174
--- /dev/null
+++ b/src/channel/event.rs
@@ -0,0 +1,48 @@
+use super::Channel;
+use crate::{
+ channel,
+ event::{Instant, Sequenced},
+};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Event {
+ #[serde(flatten)]
+ pub instant: Instant,
+ #[serde(flatten)]
+ pub kind: Kind,
+}
+
+impl Sequenced for Event {
+ fn instant(&self) -> Instant {
+ self.instant
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Kind {
+ Created(Created),
+ Deleted(Deleted),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Created {
+ pub channel: Channel,
+}
+
+impl From<Created> for Kind {
+ fn from(event: Created) -> Self {
+ Self::Created(event)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Deleted {
+ pub channel: channel::Id,
+}
+
+impl From<Deleted> for Kind {
+ fn from(event: Deleted) -> Self {
+ Self::Deleted(event)
+ }
+}
diff --git a/src/channel/history.rs b/src/channel/history.rs
new file mode 100644
index 0000000..3cc7d9d
--- /dev/null
+++ b/src/channel/history.rs
@@ -0,0 +1,42 @@
+use super::{
+ event::{Created, Deleted, Event},
+ Channel,
+};
+use crate::event::Instant;
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct History {
+ pub channel: Channel,
+ pub created: Instant,
+ pub deleted: Option<Instant>,
+}
+
+impl History {
+ fn created(&self) -> Event {
+ Event {
+ instant: self.created,
+ kind: Created {
+ channel: self.channel.clone(),
+ }
+ .into(),
+ }
+ }
+
+ fn deleted(&self) -> Option<Event> {
+ self.deleted.map(|instant| Event {
+ instant,
+ kind: Deleted {
+ channel: self.channel.id.clone(),
+ }
+ .into(),
+ })
+ }
+
+ pub fn events(&self) -> impl Iterator<Item = Event> {
+ [self.created()].into_iter().chain(self.deleted())
+ }
+
+ pub fn snapshot(&self) -> Channel {
+ self.channel.clone()
+ }
+}
diff --git a/src/channel/id.rs b/src/channel/id.rs
new file mode 100644
index 0000000..22a2700
--- /dev/null
+++ b/src/channel/id.rs
@@ -0,0 +1,38 @@
+use std::fmt;
+
+use crate::id::Id as BaseId;
+
+// Stable identifier for a [Channel]. Prefixed with `C`.
+#[derive(
+ Clone,
+ Debug,
+ Eq,
+ Hash,
+ Ord,
+ PartialEq,
+ PartialOrd,
+ sqlx::Type,
+ serde::Deserialize,
+ serde::Serialize,
+)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("C")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/channel/mod.rs b/src/channel/mod.rs
index 9f79dbb..eb8200b 100644
--- a/src/channel/mod.rs
+++ b/src/channel/mod.rs
@@ -1,4 +1,9 @@
pub mod app;
+pub mod event;
+mod history;
+mod id;
+pub mod repo;
mod routes;
+mod snapshot;
-pub use self::routes::router;
+pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel};
diff --git a/src/channel/repo.rs b/src/channel/repo.rs
new file mode 100644
index 0000000..2b48436
--- /dev/null
+++ b/src/channel/repo.rs
@@ -0,0 +1,202 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use crate::{
+ channel::{Channel, History, Id},
+ clock::DateTime,
+ event::{Instant, Sequence},
+};
+
+pub trait Provider {
+ fn channels(&mut self) -> Channels;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn channels(&mut self) -> Channels {
+ Channels(self)
+ }
+}
+
+pub struct Channels<'t>(&'t mut SqliteConnection);
+
+impl<'c> Channels<'c> {
+ pub async fn create(&mut self, name: &str, created: &Instant) -> Result<History, sqlx::Error> {
+ let id = Id::generate();
+ let channel = sqlx::query!(
+ r#"
+ insert
+ into channel (id, name, created_at, created_sequence)
+ values ($1, $2, $3, $4)
+ returning
+ id as "id: Id",
+ name,
+ created_at as "created_at: DateTime",
+ created_sequence as "created_sequence: Sequence"
+ "#,
+ id,
+ name,
+ created.at,
+ created.sequence,
+ )
+ .map(|row| History {
+ channel: Channel {
+ id: row.id,
+ name: row.name,
+ },
+ created: Instant {
+ at: row.created_at,
+ sequence: row.created_sequence,
+ },
+ deleted: None,
+ })
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(channel)
+ }
+
+ pub async fn by_id(&mut self, channel: &Id) -> Result<History, sqlx::Error> {
+ let channel = sqlx::query!(
+ r#"
+ select
+ id as "id: Id",
+ name,
+ created_at as "created_at: DateTime",
+ created_sequence as "created_sequence: Sequence"
+ from channel
+ where id = $1
+ "#,
+ channel,
+ )
+ .map(|row| History {
+ channel: Channel {
+ id: row.id,
+ name: row.name,
+ },
+ created: Instant {
+ at: row.created_at,
+ sequence: row.created_sequence,
+ },
+ deleted: None,
+ })
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(channel)
+ }
+
+ pub async fn all(&mut self, resume_at: Option<Sequence>) -> Result<Vec<History>, sqlx::Error> {
+ let channels = sqlx::query!(
+ r#"
+ select
+ id as "id: Id",
+ name,
+ created_at as "created_at: DateTime",
+ created_sequence as "created_sequence: Sequence"
+ from channel
+ where coalesce(created_sequence <= $1, true)
+ order by channel.name
+ "#,
+ resume_at,
+ )
+ .map(|row| History {
+ channel: Channel {
+ id: row.id,
+ name: row.name,
+ },
+ created: Instant {
+ at: row.created_at,
+ sequence: row.created_sequence,
+ },
+ deleted: None,
+ })
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(channels)
+ }
+
+ pub async fn replay(
+ &mut self,
+ resume_at: Option<Sequence>,
+ ) -> Result<Vec<History>, sqlx::Error> {
+ let channels = sqlx::query!(
+ r#"
+ select
+ id as "id: Id",
+ name,
+ created_at as "created_at: DateTime",
+ created_sequence as "created_sequence: Sequence"
+ from channel
+ where coalesce(created_sequence > $1, true)
+ "#,
+ resume_at,
+ )
+ .map(|row| History {
+ channel: Channel {
+ id: row.id,
+ name: row.name,
+ },
+ created: Instant {
+ at: row.created_at,
+ sequence: row.created_sequence,
+ },
+ deleted: None,
+ })
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(channels)
+ }
+
+ pub async fn delete(
+ &mut self,
+ channel: &Id,
+ deleted: &Instant,
+ ) -> Result<History, sqlx::Error> {
+ let channel = sqlx::query!(
+ r#"
+ delete from channel
+ where id = $1
+ returning
+ id as "id: Id",
+ name,
+ created_at as "created_at: DateTime",
+ created_sequence as "created_sequence: Sequence"
+ "#,
+ channel,
+ )
+ .map(|row| History {
+ channel: Channel {
+ id: row.id,
+ name: row.name,
+ },
+ created: Instant {
+ at: row.created_at,
+ sequence: row.created_sequence,
+ },
+ deleted: Some(*deleted),
+ })
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(channel)
+ }
+
+ pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ let channels = sqlx::query_scalar!(
+ r#"
+ select
+ channel.id as "id: Id"
+ from channel
+ left join message
+ where created_at < $1
+ and message.id is null
+ "#,
+ expired_at,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(channels)
+ }
+}
diff --git a/src/channel/routes.rs b/src/channel/routes.rs
index 1f8db5a..23c0602 100644
--- a/src/channel/routes.rs
+++ b/src/channel/routes.rs
@@ -2,20 +2,19 @@ use axum::{
extract::{Json, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
- routing::{get, post},
+ routing::{delete, get, post},
Router,
};
+use axum_extra::extract::Query;
-use super::app;
+use super::{app, Channel, Id};
use crate::{
app::App,
clock::RequestedAt,
error::Internal,
- events::app::EventsError,
- repo::{
- channel::{self, Channel},
- login::Login,
- },
+ event::{Instant, Sequence},
+ login::Login,
+ message::{self, app::SendError},
};
#[cfg(test)]
@@ -26,10 +25,21 @@ pub fn router() -> Router<App> {
.route("/api/channels", get(list))
.route("/api/channels", post(on_create))
.route("/api/channels/:channel", post(on_send))
+ .route("/api/channels/:channel", delete(on_delete))
+ .route("/api/channels/:channel/messages", get(messages))
}
-async fn list(State(app): State<App>, _: Login) -> Result<Channels, Internal> {
- let channels = app.channels().all().await?;
+#[derive(Default, serde::Deserialize)]
+struct ResumeQuery {
+ resume_point: Option<Sequence>,
+}
+
+async fn list(
+ State(app): State<App>,
+ _: Login,
+ Query(query): Query<ResumeQuery>,
+) -> Result<Channels, Internal> {
+ let channels = app.channels().all(query.resume_point).await?;
let response = Channels(channels);
Ok(response)
@@ -86,31 +96,107 @@ struct SendRequest {
async fn on_send(
State(app): State<App>,
- Path(channel): Path<channel::Id>,
+ Path(channel): Path<Id>,
RequestedAt(sent_at): RequestedAt,
login: Login,
Json(request): Json<SendRequest>,
+) -> Result<StatusCode, SendErrorResponse> {
+ app.messages()
+ .send(&channel, &login, &sent_at, &request.message)
+ .await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+struct SendErrorResponse(#[from] SendError);
+
+impl IntoResponse for SendErrorResponse {
+ fn into_response(self) -> Response {
+ let Self(error) = self;
+ match error {
+ not_found @ SendError::ChannelNotFound(_) => {
+ (StatusCode::NOT_FOUND, not_found.to_string()).into_response()
+ }
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
+
+async fn on_delete(
+ State(app): State<App>,
+ Path(channel): Path<Id>,
+ RequestedAt(deleted_at): RequestedAt,
+ _: Login,
) -> Result<StatusCode, ErrorResponse> {
- app.events()
- .send(&login, &channel, &request.message, &sent_at)
- .await
- // Could impl `From` here, but it's more code and this is used once.
- .map_err(ErrorResponse)?;
+ app.channels().delete(&channel, &deleted_at).await?;
Ok(StatusCode::ACCEPTED)
}
-#[derive(Debug)]
-struct ErrorResponse(EventsError);
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+struct ErrorResponse(#[from] app::Error);
impl IntoResponse for ErrorResponse {
fn into_response(self) -> Response {
let Self(error) = self;
match error {
- not_found @ EventsError::ChannelNotFound(_) => {
+ not_found @ app::Error::NotFound(_) => {
(StatusCode::NOT_FOUND, not_found.to_string()).into_response()
}
other => Internal::from(other).into_response(),
}
}
}
+
+async fn messages(
+ State(app): State<App>,
+ Path(channel): Path<Id>,
+ _: Login,
+ Query(query): Query<ResumeQuery>,
+) -> Result<Messages, ErrorResponse> {
+ let messages = app
+ .channels()
+ .messages(&channel, query.resume_point)
+ .await?;
+ let response = Messages(
+ messages
+ .into_iter()
+ .map(|message| MessageView {
+ sent: message.sent,
+ sender: message.sender,
+ message: MessageInner {
+ id: message.id,
+ body: message.body,
+ },
+ })
+ .collect(),
+ );
+
+ Ok(response)
+}
+
+struct Messages(Vec<MessageView>);
+
+#[derive(serde::Serialize)]
+struct MessageView {
+ #[serde(flatten)]
+ sent: Instant,
+ sender: Login,
+ message: MessageInner,
+}
+
+#[derive(serde::Serialize)]
+struct MessageInner {
+ id: message::Id,
+ body: String,
+}
+
+impl IntoResponse for Messages {
+ fn into_response(self) -> Response {
+ let Self(messages) = self;
+ Json(messages).into_response()
+ }
+}
diff --git a/src/channel/routes/test/list.rs b/src/channel/routes/test/list.rs
index bc94024..f15a53c 100644
--- a/src/channel/routes/test/list.rs
+++ b/src/channel/routes/test/list.rs
@@ -1,4 +1,5 @@
use axum::extract::State;
+use axum_extra::extract::Query;
use crate::{channel::routes, test::fixtures};
@@ -11,7 +12,7 @@ async fn empty_list() {
// Call the endpoint
- let routes::Channels(channels) = routes::list(State(app), viewer)
+ let routes::Channels(channels) = routes::list(State(app), viewer, Query::default())
.await
.expect("always succeeds");
@@ -30,7 +31,7 @@ async fn one_channel() {
// Call the endpoint
- let routes::Channels(channels) = routes::list(State(app), viewer)
+ let routes::Channels(channels) = routes::list(State(app), viewer, Query::default())
.await
.expect("always succeeds");
@@ -52,7 +53,7 @@ async fn multiple_channels() {
// Call the endpoint
- let routes::Channels(response_channels) = routes::list(State(app), viewer)
+ let routes::Channels(response_channels) = routes::list(State(app), viewer, Query::default())
.await
.expect("always succeeds");
diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs
index e2610a5..5733c9e 100644
--- a/src/channel/routes/test/on_create.rs
+++ b/src/channel/routes/test/on_create.rs
@@ -3,7 +3,7 @@ use futures::stream::StreamExt as _;
use crate::{
channel::{app, routes},
- events::types,
+ event,
test::fixtures::{self, future::Immediately as _},
};
@@ -33,26 +33,25 @@ async fn new_channel() {
// Verify the semantics
- let channels = app.channels().all().await.expect("always succeeds");
+ let channels = app.channels().all(None).await.expect("always succeeds");
assert!(channels.contains(&response_channel));
let mut events = app
.events()
- .subscribe(types::ResumePoint::default())
+ .subscribe(None)
.await
.expect("subscribing never fails")
.filter(fixtures::filter::created());
- let types::ResumableEvent(_, event) = events
+ let event = events
.next()
.immediately()
.await
.expect("creation event published");
- assert_eq!(types::Sequence::default(), event.sequence);
assert!(matches!(
- event.data,
- types::ChannelEventData::Created(event)
+ event.kind,
+ event::Kind::ChannelCreated(event)
if event.channel == response_channel
));
}
diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs
index 233518b..3297093 100644
--- a/src/channel/routes/test/on_send.rs
+++ b/src/channel/routes/test/on_send.rs
@@ -2,9 +2,10 @@ use axum::extract::{Json, Path, State};
use futures::stream::StreamExt;
use crate::{
+ channel,
channel::routes,
- events::{app, types},
- repo::channel,
+ event,
+ message::app::SendError,
test::fixtures::{self, future::Immediately as _},
};
@@ -43,7 +44,7 @@ async fn messages_in_order() {
let events = app
.events()
- .subscribe(types::ResumePoint::default())
+ .subscribe(None)
.await
.expect("subscribing to a valid channel")
.filter(fixtures::filter::messages())
@@ -51,13 +52,13 @@ async fn messages_in_order() {
let events = events.collect::<Vec<_>>().immediately().await;
- for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) {
- assert_eq!(*sent_at, event.at);
+ for ((sent_at, message), event) in requests.into_iter().zip(events) {
+ assert_eq!(*sent_at, event.instant.at);
assert!(matches!(
- event.data,
- types::ChannelEventData::Message(event_message)
- if event_message.sender == sender
- && event_message.message.body == message
+ event.kind,
+ event::Kind::MessageSent(event)
+ if event.message.sender == sender
+ && event.message.body == message
));
}
}
@@ -76,7 +77,7 @@ async fn nonexistent_channel() {
let request = routes::SendRequest {
message: fixtures::message::propose(),
};
- let routes::ErrorResponse(error) = routes::on_send(
+ let routes::SendErrorResponse(error) = routes::on_send(
State(app),
Path(channel.clone()),
sent_at,
@@ -90,6 +91,6 @@ async fn nonexistent_channel() {
assert!(matches!(
error,
- app::EventsError::ChannelNotFound(error_channel) if channel == error_channel
+ SendError::ChannelNotFound(error_channel) if channel == error_channel
));
}
diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs
new file mode 100644
index 0000000..6462f25
--- /dev/null
+++ b/src/channel/snapshot.rs
@@ -0,0 +1,38 @@
+use super::{
+ event::{Created, Event, Kind},
+ Id,
+};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Channel {
+ pub id: Id,
+ pub name: String,
+}
+
+impl Channel {
+ fn apply(state: Option<Self>, event: Event) -> Option<Self> {
+ match (state, event.kind) {
+ (None, Kind::Created(event)) => Some(event.into()),
+ (Some(channel), Kind::Deleted(event)) if channel.id == event.channel => None,
+ (state, event) => panic!("invalid channel event {event:#?} for state {state:#?}"),
+ }
+ }
+}
+
+impl FromIterator<Event> for Option<Channel> {
+ fn from_iter<I: IntoIterator<Item = Event>>(events: I) -> Self {
+ events.into_iter().fold(None, Channel::apply)
+ }
+}
+
+impl From<&Created> for Channel {
+ fn from(event: &Created) -> Self {
+ event.channel.clone()
+ }
+}
+
+impl From<Created> for Channel {
+ fn from(event: Created) -> Self {
+ event.channel
+ }
+}
diff --git a/src/cli.rs b/src/cli.rs
index 132baf8..2d9f512 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -10,7 +10,7 @@ use clap::Parser;
use sqlx::sqlite::SqlitePool;
use tokio::net;
-use crate::{app::App, channel, clock, events, expire, login, repo::pool};
+use crate::{app::App, channel, clock, db, event, expire, login, message};
/// Command-line entry point for running the `hi` server.
///
@@ -100,14 +100,19 @@ impl Args {
}
async fn pool(&self) -> sqlx::Result<SqlitePool> {
- pool::prepare(&self.database_url).await
+ db::prepare(&self.database_url).await
}
}
fn routers() -> Router<App> {
- [channel::router(), events::router(), login::router()]
- .into_iter()
- .fold(Router::default(), Router::merge)
+ [
+ channel::router(),
+ event::router(),
+ login::router(),
+ message::router(),
+ ]
+ .into_iter()
+ .fold(Router::default(), Router::merge)
}
fn started_msg(listener: &net::TcpListener) -> io::Result<String> {
diff --git a/src/repo/pool.rs b/src/db.rs
index b4aa6fc..93a1169 100644
--- a/src/repo/pool.rs
+++ b/src/db.rs
@@ -16,3 +16,27 @@ async fn create(database_url: &str) -> sqlx::Result<SqlitePool> {
let pool = SqlitePoolOptions::new().connect_with(options).await?;
Ok(pool)
}
+
+pub trait NotFound {
+ type Ok;
+ fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E>
+ where
+ E: From<sqlx::Error>,
+ F: FnOnce() -> E;
+}
+
+impl<T> NotFound for Result<T, sqlx::Error> {
+ type Ok = T;
+
+ fn not_found<E, F>(self, map: F) -> Result<T, E>
+ where
+ E: From<sqlx::Error>,
+ F: FnOnce() -> E,
+ {
+ match self {
+ Err(sqlx::Error::RowNotFound) => Err(map()),
+ Err(other) => Err(other.into()),
+ Ok(value) => Ok(value),
+ }
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index 6e797b4..8792a1d 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -61,3 +61,11 @@ impl fmt::Display for Id {
self.0.fmt(f)
}
}
+
+pub struct Unauthorized;
+
+impl IntoResponse for Unauthorized {
+ fn into_response(self) -> Response {
+ (StatusCode::UNAUTHORIZED, "unauthorized").into_response()
+ }
+}
diff --git a/src/event/app.rs b/src/event/app.rs
new file mode 100644
index 0000000..d664ec7
--- /dev/null
+++ b/src/event/app.rs
@@ -0,0 +1,72 @@
+use futures::{
+ future,
+ stream::{self, StreamExt as _},
+ Stream,
+};
+use itertools::Itertools as _;
+use sqlx::sqlite::SqlitePool;
+
+use super::{broadcaster::Broadcaster, Event, Sequence, Sequenced};
+use crate::{
+ channel::{self, repo::Provider as _},
+ message::{self, repo::Provider as _},
+};
+
+pub struct Events<'a> {
+ db: &'a SqlitePool,
+ events: &'a Broadcaster,
+}
+
+impl<'a> Events<'a> {
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
+ }
+
+ pub async fn subscribe(
+ &self,
+ resume_at: Option<Sequence>,
+ ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, sqlx::Error> {
+ // Subscribe before retrieving, to catch messages broadcast while we're
+ // querying the DB. We'll prune out duplicates later.
+ let live_messages = self.events.subscribe();
+
+ let mut tx = self.db.begin().await?;
+
+ let channels = tx.channels().replay(resume_at).await?;
+ let channel_events = channels
+ .iter()
+ .map(channel::History::events)
+ .kmerge_by(Sequence::merge)
+ .filter(Sequence::after(resume_at))
+ .map(Event::from);
+
+ let messages = tx.messages().replay(resume_at).await?;
+ let message_events = messages
+ .iter()
+ .map(message::History::events)
+ .kmerge_by(Sequence::merge)
+ .filter(Sequence::after(resume_at))
+ .map(Event::from);
+
+ let replay_events = channel_events
+ .merge_by(message_events, Sequence::merge)
+ .collect::<Vec<_>>();
+ let resume_live_at = replay_events.last().map(Sequenced::sequence);
+
+ let replay = stream::iter(replay_events);
+
+ let live_messages = live_messages
+ // Filtering on the broadcast resume point filters out messages
+ // before resume_at, and filters out messages duplicated from
+ // `replay_events`.
+ .flat_map(stream::iter)
+ .filter(Self::resume(resume_live_at));
+
+ Ok(replay.chain(live_messages))
+ }
+
+ fn resume(resume_at: Option<Sequence>) -> impl for<'m> FnMut(&'m Event) -> future::Ready<bool> {
+ let filter = Sequence::after(resume_at);
+ move |event| future::ready(filter(event))
+ }
+}
diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs
new file mode 100644
index 0000000..3c4efac
--- /dev/null
+++ b/src/event/broadcaster.rs
@@ -0,0 +1,3 @@
+use crate::broadcast;
+
+pub type Broadcaster = broadcast::Broadcaster<Vec<super::Event>>;
diff --git a/src/events/extract.rs b/src/event/extract.rs
index e3021e2..e3021e2 100644
--- a/src/events/extract.rs
+++ b/src/event/extract.rs
diff --git a/src/event/mod.rs b/src/event/mod.rs
new file mode 100644
index 0000000..1349fe6
--- /dev/null
+++ b/src/event/mod.rs
@@ -0,0 +1,75 @@
+use crate::{channel, message};
+
+pub mod app;
+pub mod broadcaster;
+mod extract;
+pub mod repo;
+mod routes;
+mod sequence;
+
+pub use self::{
+ routes::router,
+ sequence::{Instant, Sequence, Sequenced},
+};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Event {
+ #[serde(flatten)]
+ pub instant: Instant,
+ #[serde(flatten)]
+ pub kind: Kind,
+}
+
+impl Sequenced for Event {
+ fn instant(&self) -> Instant {
+ self.instant
+ }
+}
+
+impl From<channel::Event> for Event {
+ fn from(event: channel::Event) -> Self {
+ Self {
+ instant: event.instant,
+ kind: event.kind.into(),
+ }
+ }
+}
+
+impl From<message::Event> for Event {
+ fn from(event: message::Event) -> Self {
+ Self {
+ instant: event.instant(),
+ kind: event.kind.into(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Kind {
+ #[serde(rename = "created")]
+ ChannelCreated(channel::event::Created),
+ #[serde(rename = "message")]
+ MessageSent(message::event::Sent),
+ MessageDeleted(message::event::Deleted),
+ #[serde(rename = "deleted")]
+ ChannelDeleted(channel::event::Deleted),
+}
+
+impl From<channel::event::Kind> for Kind {
+ fn from(kind: channel::event::Kind) -> Self {
+ match kind {
+ channel::event::Kind::Created(created) => Self::ChannelCreated(created),
+ channel::event::Kind::Deleted(deleted) => Self::ChannelDeleted(deleted),
+ }
+ }
+}
+
+impl From<message::event::Kind> for Kind {
+ fn from(kind: message::event::Kind) -> Self {
+ match kind {
+ message::event::Kind::Sent(created) => Self::MessageSent(created),
+ message::event::Kind::Deleted(deleted) => Self::MessageDeleted(deleted),
+ }
+ }
+}
diff --git a/src/event/repo.rs b/src/event/repo.rs
new file mode 100644
index 0000000..40d6a53
--- /dev/null
+++ b/src/event/repo.rs
@@ -0,0 +1,50 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use crate::{
+ clock::DateTime,
+ event::{Instant, Sequence},
+};
+
+pub trait Provider {
+ fn sequence(&mut self) -> Sequences;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn sequence(&mut self) -> Sequences {
+ Sequences(self)
+ }
+}
+
+pub struct Sequences<'t>(&'t mut SqliteConnection);
+
+impl<'c> Sequences<'c> {
+ pub async fn next(&mut self, at: &DateTime) -> Result<Instant, sqlx::Error> {
+ let next = sqlx::query_scalar!(
+ r#"
+ update event_sequence
+ set last_value = last_value + 1
+ returning last_value as "next_value: Sequence"
+ "#,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(Instant {
+ at: *at,
+ sequence: next,
+ })
+ }
+
+ pub async fn current(&mut self) -> Result<Sequence, sqlx::Error> {
+ let next = sqlx::query_scalar!(
+ r#"
+ select last_value as "last_value: Sequence"
+ from event_sequence
+ "#,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(next)
+ }
+}
diff --git a/src/event/routes.rs b/src/event/routes.rs
new file mode 100644
index 0000000..5b9c7e3
--- /dev/null
+++ b/src/event/routes.rs
@@ -0,0 +1,92 @@
+use axum::{
+ extract::State,
+ response::{
+ sse::{self, Sse},
+ IntoResponse, Response,
+ },
+ routing::get,
+ Router,
+};
+use axum_extra::extract::Query;
+use futures::stream::{Stream, StreamExt as _};
+
+use super::{extract::LastEventId, Event};
+use crate::{
+ app::App,
+ error::{Internal, Unauthorized},
+ event::{Sequence, Sequenced as _},
+ token::{app::ValidateError, extract::Identity},
+};
+
+#[cfg(test)]
+mod test;
+
+pub fn router() -> Router<App> {
+ Router::new().route("/api/events", get(events))
+}
+
+#[derive(Default, serde::Deserialize)]
+struct EventsQuery {
+ resume_point: Option<Sequence>,
+}
+
+async fn events(
+ State(app): State<App>,
+ identity: Identity,
+ last_event_id: Option<LastEventId<Sequence>>,
+ Query(query): Query<EventsQuery>,
+) -> Result<Events<impl Stream<Item = Event> + std::fmt::Debug>, EventsError> {
+ let resume_at = last_event_id
+ .map(LastEventId::into_inner)
+ .or(query.resume_point);
+
+ let stream = app.events().subscribe(resume_at).await?;
+ let stream = app.tokens().limit_stream(identity.token, stream).await?;
+
+ Ok(Events(stream))
+}
+
+#[derive(Debug)]
+struct Events<S>(S);
+
+impl<S> IntoResponse for Events<S>
+where
+ S: Stream<Item = Event> + Send + 'static,
+{
+ fn into_response(self) -> Response {
+ let Self(stream) = self;
+ let stream = stream.map(sse::Event::try_from);
+ Sse::new(stream)
+ .keep_alive(sse::KeepAlive::default())
+ .into_response()
+ }
+}
+
+impl TryFrom<Event> for sse::Event {
+ type Error = serde_json::Error;
+
+ fn try_from(event: Event) -> Result<Self, Self::Error> {
+ let id = serde_json::to_string(&event.sequence())?;
+ let data = serde_json::to_string_pretty(&event)?;
+
+ let event = Self::default().id(id).data(data);
+
+ Ok(event)
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum EventsError {
+ DatabaseError(#[from] sqlx::Error),
+ ValidateError(#[from] ValidateError),
+}
+
+impl IntoResponse for EventsError {
+ fn into_response(self) -> Response {
+ match self {
+ Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(),
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
diff --git a/src/events/routes/test.rs b/src/event/routes/test.rs
index 820192d..ba9953e 100644
--- a/src/events/routes/test.rs
+++ b/src/event/routes/test.rs
@@ -1,11 +1,12 @@
use axum::extract::State;
+use axum_extra::extract::Query;
use futures::{
future,
stream::{self, StreamExt as _},
};
use crate::{
- events::{routes, types},
+ event::{routes, Sequenced as _},
test::fixtures::{self, future::Immediately as _},
};
@@ -16,26 +17,26 @@ async fn includes_historical_message() {
let app = fixtures::scratch_app().await;
let sender = fixtures::login::create(&app).await;
let channel = fixtures::channel::create(&app, &fixtures::now()).await;
- let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+ let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
// Call the endpoint
let subscriber_creds = fixtures::login::create_with_password(&app).await;
let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
- let routes::Events(events) = routes::events(State(app), subscriber, None)
+ let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default())
.await
.expect("subscribe never fails");
// Verify the structure of the response.
- let types::ResumableEvent(_, event) = events
+ let event = events
.filter(fixtures::filter::messages())
.next()
.immediately()
.await
.expect("delivered stored message");
- assert_eq!(message, event);
+ assert!(fixtures::event::message_sent(&event, &message));
}
#[tokio::test]
@@ -49,23 +50,24 @@ async fn includes_live_message() {
let subscriber_creds = fixtures::login::create_with_password(&app).await;
let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
- let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
- .await
- .expect("subscribe never fails");
+ let routes::Events(events) =
+ routes::events(State(app.clone()), subscriber, None, Query::default())
+ .await
+ .expect("subscribe never fails");
// Verify the semantics
let sender = fixtures::login::create(&app).await;
- let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+ let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
- let types::ResumableEvent(_, event) = events
+ let event = events
.filter(fixtures::filter::messages())
.next()
.immediately()
.await
.expect("delivered live message");
- assert_eq!(message, event);
+ assert!(fixtures::event::message_sent(&event, &message));
}
#[tokio::test]
@@ -85,7 +87,7 @@ async fn includes_multiple_channels() {
let app = app.clone();
let sender = sender.clone();
let channel = channel.clone();
- async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await }
+ async move { fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await }
})
.collect::<Vec<_>>()
.await;
@@ -94,7 +96,7 @@ async fn includes_multiple_channels() {
let subscriber_creds = fixtures::login::create_with_password(&app).await;
let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
- let routes::Events(events) = routes::events(State(app), subscriber, None)
+ let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default())
.await
.expect("subscribe never fails");
@@ -110,7 +112,7 @@ async fn includes_multiple_channels() {
for message in &messages {
assert!(events
.iter()
- .any(|types::ResumableEvent(_, event)| { event == message }));
+ .any(|event| fixtures::event::message_sent(event, message)));
}
}
@@ -123,33 +125,38 @@ async fn sequential_messages() {
let sender = fixtures::login::create(&app).await;
let messages = vec![
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
];
// Call the endpoint
let subscriber_creds = fixtures::login::create_with_password(&app).await;
let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
- let routes::Events(events) = routes::events(State(app), subscriber, None)
+ let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default())
.await
.expect("subscribe never fails");
// Verify the structure of the response.
- let mut events =
- events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)));
+ let mut events = events.filter(|event| {
+ future::ready(
+ messages
+ .iter()
+ .any(|message| fixtures::event::message_sent(event, message)),
+ )
+ });
// Verify delivery in order
for message in &messages {
- let types::ResumableEvent(_, event) = events
+ let event = events
.next()
.immediately()
.await
.expect("undelivered messages remaining");
- assert_eq!(message, &event);
+ assert!(fixtures::event::message_sent(&event, message));
}
}
@@ -161,11 +168,11 @@ async fn resumes_from() {
let channel = fixtures::channel::create(&app, &fixtures::now()).await;
let sender = fixtures::login::create(&app).await;
- let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+ let initial_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
let later_messages = vec![
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
];
// Call the endpoint
@@ -175,26 +182,36 @@ async fn resumes_from() {
let resume_at = {
// First subscription
- let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None)
- .await
- .expect("subscribe never fails");
+ let routes::Events(events) = routes::events(
+ State(app.clone()),
+ subscriber.clone(),
+ None,
+ Query::default(),
+ )
+ .await
+ .expect("subscribe never fails");
- let types::ResumableEvent(last_event_id, event) = events
+ let event = events
.filter(fixtures::filter::messages())
.next()
.immediately()
.await
.expect("delivered events");
- assert_eq!(initial_message, event);
+ assert!(fixtures::event::message_sent(&event, &initial_message));
- last_event_id
+ event.sequence()
};
// Resume after disconnect
- let routes::Events(resumed) = routes::events(State(app), subscriber, Some(resume_at.into()))
- .await
- .expect("subscribe never fails");
+ let routes::Events(resumed) = routes::events(
+ State(app),
+ subscriber,
+ Some(resume_at.into()),
+ Query::default(),
+ )
+ .await
+ .expect("subscribe never fails");
// Verify the structure of the response.
@@ -207,7 +224,7 @@ async fn resumes_from() {
for message in &later_messages {
assert!(events
.iter()
- .any(|types::ResumableEvent(_, event)| event == message));
+ .any(|event| fixtures::event::message_sent(event, message)));
}
}
@@ -242,14 +259,19 @@ async fn serial_resume() {
let resume_at = {
let initial_messages = [
- fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await,
];
// First subscription
- let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None)
- .await
- .expect("subscribe never fails");
+ let routes::Events(events) = routes::events(
+ State(app.clone()),
+ subscriber.clone(),
+ None,
+ Query::default(),
+ )
+ .await
+ .expect("subscribe never fails");
let events = events
.filter(fixtures::filter::messages())
@@ -261,12 +283,12 @@ async fn serial_resume() {
for message in &initial_messages {
assert!(events
.iter()
- .any(|types::ResumableEvent(_, event)| event == message));
+ .any(|event| fixtures::event::message_sent(event, message)));
}
- let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let event = events.last().expect("this vec is non-empty");
- id.to_owned()
+ event.sequence()
};
// Resume after disconnect
@@ -275,8 +297,8 @@ async fn serial_resume() {
// Note that channel_b does not appear here. The buggy behaviour
// would be masked if channel_b happened to send a new message
// into the resumed event stream.
- fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
];
// Second subscription
@@ -284,6 +306,7 @@ async fn serial_resume() {
State(app.clone()),
subscriber.clone(),
Some(resume_at.into()),
+ Query::default(),
)
.await
.expect("subscribe never fails");
@@ -298,12 +321,12 @@ async fn serial_resume() {
for message in &resume_messages {
assert!(events
.iter()
- .any(|types::ResumableEvent(_, event)| event == message));
+ .any(|event| fixtures::event::message_sent(event, message)));
}
- let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty");
+ let event = events.last().expect("this vec is non-empty");
- id.to_owned()
+ event.sequence()
};
// Resume after disconnect a second time
@@ -312,8 +335,8 @@ async fn serial_resume() {
// problem. The resume point should before both of these messages, but
// after _all_ prior messages.
let final_messages = [
- fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await,
];
// Third subscription
@@ -321,6 +344,7 @@ async fn serial_resume() {
State(app.clone()),
subscriber.clone(),
Some(resume_at.into()),
+ Query::default(),
)
.await
.expect("subscribe never fails");
@@ -337,7 +361,7 @@ async fn serial_resume() {
for message in &final_messages {
assert!(events
.iter()
- .any(|types::ResumableEvent(_, event)| event == message));
+ .any(|event| fixtures::event::message_sent(event, message)));
}
};
}
@@ -356,26 +380,31 @@ async fn terminates_on_token_expiry() {
let subscriber =
fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await;
- let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
- .await
- .expect("subscribe never fails");
+ let routes::Events(events) =
+ routes::events(State(app.clone()), subscriber, None, Query::default())
+ .await
+ .expect("subscribe never fails");
// Verify the resulting stream's behaviour
- app.logins()
+ app.tokens()
.expire(&fixtures::now())
.await
.expect("expiring tokens succeeds");
// These should not be delivered.
let messages = [
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
];
assert!(events
- .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)))
+ .filter(|event| future::ready(
+ messages
+ .iter()
+ .any(|message| fixtures::event::message_sent(event, message))
+ ))
.next()
.immediately()
.await
@@ -398,26 +427,35 @@ async fn terminates_on_logout() {
let subscriber =
fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await;
- let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None)
- .await
- .expect("subscribe never fails");
+ let routes::Events(events) = routes::events(
+ State(app.clone()),
+ subscriber.clone(),
+ None,
+ Query::default(),
+ )
+ .await
+ .expect("subscribe never fails");
// Verify the resulting stream's behaviour
- app.logins()
+ app.tokens()
.logout(&subscriber.token)
.await
.expect("expiring tokens succeeds");
// These should not be delivered.
let messages = [
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
- fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
];
assert!(events
- .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)))
+ .filter(|event| future::ready(
+ messages
+ .iter()
+ .any(|message| fixtures::event::message_sent(event, message))
+ ))
.next()
.immediately()
.await
diff --git a/src/event/sequence.rs b/src/event/sequence.rs
new file mode 100644
index 0000000..fbe3711
--- /dev/null
+++ b/src/event/sequence.rs
@@ -0,0 +1,90 @@
+use std::fmt;
+
+use crate::clock::DateTime;
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Instant {
+ pub at: DateTime,
+ #[serde(skip)]
+ pub sequence: Sequence,
+}
+
+impl From<Instant> for Sequence {
+ fn from(instant: Instant) -> Self {
+ instant.sequence
+ }
+}
+
+#[derive(
+ Clone,
+ Copy,
+ Debug,
+ Eq,
+ Ord,
+ PartialEq,
+ PartialOrd,
+ serde::Deserialize,
+ serde::Serialize,
+ sqlx::Type,
+)]
+#[serde(transparent)]
+#[sqlx(transparent)]
+pub struct Sequence(i64);
+
+impl fmt::Display for Sequence {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let Self(value) = self;
+ value.fmt(f)
+ }
+}
+
+impl Sequence {
+ pub fn up_to<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+ where
+ E: Sequenced,
+ {
+ move |event| resume_point.map_or(true, |resume_point| event.sequence() <= resume_point)
+ }
+
+ pub fn after<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+ where
+ E: Sequenced,
+ {
+ move |event| resume_point < Some(event.sequence())
+ }
+
+ pub fn start_from<E>(resume_point: Self) -> impl for<'e> Fn(&'e E) -> bool
+ where
+ E: Sequenced,
+ {
+ move |event| resume_point <= event.sequence()
+ }
+
+ pub fn merge<E>(a: &E, b: &E) -> bool
+ where
+ E: Sequenced,
+ {
+ a.sequence() < b.sequence()
+ }
+}
+
+pub trait Sequenced {
+ fn instant(&self) -> Instant;
+
+ fn sequence(&self) -> Sequence {
+ self.instant().into()
+ }
+}
+
+impl<E> Sequenced for &E
+where
+ E: Sequenced,
+{
+ fn instant(&self) -> Instant {
+ (*self).instant()
+ }
+
+ fn sequence(&self) -> Sequence {
+ (*self).sequence()
+ }
+}
diff --git a/src/events/app.rs b/src/events/app.rs
deleted file mode 100644
index db7f430..0000000
--- a/src/events/app.rs
+++ /dev/null
@@ -1,163 +0,0 @@
-use std::collections::BTreeMap;
-
-use chrono::TimeDelta;
-use futures::{
- future,
- stream::{self, StreamExt as _},
- Stream,
-};
-use sqlx::sqlite::SqlitePool;
-
-use super::{
- broadcaster::Broadcaster,
- repo::message::Provider as _,
- types::{self, ChannelEvent, ResumePoint},
-};
-use crate::{
- clock::DateTime,
- repo::{
- channel::{self, Provider as _},
- error::NotFound as _,
- login::Login,
- },
-};
-
-pub struct Events<'a> {
- db: &'a SqlitePool,
- events: &'a Broadcaster,
-}
-
-impl<'a> Events<'a> {
- pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
- Self { db, events }
- }
-
- pub async fn send(
- &self,
- login: &Login,
- channel: &channel::Id,
- body: &str,
- sent_at: &DateTime,
- ) -> Result<types::ChannelEvent, EventsError> {
- let mut tx = self.db.begin().await?;
- let channel = tx
- .channels()
- .by_id(channel)
- .await
- .not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
- let event = tx
- .message_events()
- .create(login, &channel, body, sent_at)
- .await?;
- tx.commit().await?;
-
- self.events.broadcast(&event);
- Ok(event)
- }
-
- pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
- // Somewhat arbitrarily, expire after 90 days.
- let expire_at = relative_to.to_owned() - TimeDelta::days(90);
-
- let mut tx = self.db.begin().await?;
- let expired = tx.message_events().expired(&expire_at).await?;
-
- let mut events = Vec::with_capacity(expired.len());
- for (channel, message) in expired {
- let sequence = tx.message_events().assign_sequence(&channel).await?;
- let event = tx
- .message_events()
- .delete_expired(&channel, &message, sequence, relative_to)
- .await?;
- events.push(event);
- }
-
- tx.commit().await?;
-
- for event in events {
- self.events.broadcast(&event);
- }
-
- Ok(())
- }
-
- pub async fn subscribe(
- &self,
- resume_at: ResumePoint,
- ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> {
- let mut tx = self.db.begin().await?;
- let channels = tx.channels().all().await?;
-
- let created_events = {
- let resume_at = resume_at.clone();
- let channels = channels.clone();
- stream::iter(
- channels
- .into_iter()
- .map(ChannelEvent::created)
- .filter(move |event| resume_at.not_after(event)),
- )
- };
-
- // Subscribe before retrieving, to catch messages broadcast while we're
- // querying the DB. We'll prune out duplicates later.
- let live_messages = self.events.subscribe();
-
- let mut replays = BTreeMap::new();
- let mut resume_live_at = resume_at.clone();
- for channel in channels {
- let replay = tx
- .message_events()
- .replay(&channel, resume_at.get(&channel.id))
- .await?;
-
- if let Some(last) = replay.last() {
- resume_live_at.advance(last);
- }
-
- replays.insert(channel.id.clone(), replay);
- }
-
- let replay = stream::select_all(replays.into_values().map(stream::iter));
-
- // no skip_expired or resume transforms for stored_messages, as it's
- // constructed not to contain messages meeting either criterion.
- //
- // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call;
- // * resume is redundant with the resume_at argument to
- // `tx.broadcasts().replay(…)`.
- let live_messages = live_messages
- // Filtering on the broadcast resume point filters out messages
- // before resume_at, and filters out messages duplicated from
- // stored_messages.
- .filter(Self::resume(resume_live_at));
-
- Ok(created_events.chain(replay).chain(live_messages).scan(
- resume_at,
- |resume_point, event| {
- match event.data {
- types::ChannelEventData::Deleted(_) => resume_point.forget(&event),
- _ => resume_point.advance(&event),
- }
-
- let event = types::ResumableEvent(resume_point.clone(), event);
-
- future::ready(Some(event))
- },
- ))
- }
-
- fn resume(
- resume_at: ResumePoint,
- ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> {
- move |event| future::ready(resume_at.not_after(event))
- }
-}
-
-#[derive(Debug, thiserror::Error)]
-pub enum EventsError {
- #[error("channel {0} not found")]
- ChannelNotFound(channel::Id),
- #[error(transparent)]
- DatabaseError(#[from] sqlx::Error),
-}
diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs
deleted file mode 100644
index 6b664cb..0000000
--- a/src/events/broadcaster.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-use crate::{broadcast, events::types};
-
-pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>;
diff --git a/src/events/mod.rs b/src/events/mod.rs
deleted file mode 100644
index 711ae64..0000000
--- a/src/events/mod.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-pub mod app;
-pub mod broadcaster;
-mod extract;
-pub mod repo;
-mod routes;
-pub mod types;
-
-pub use self::routes::router;
diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs
deleted file mode 100644
index f8bae2b..0000000
--- a/src/events/repo/message.rs
+++ /dev/null
@@ -1,198 +0,0 @@
-use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
-
-use crate::{
- clock::DateTime,
- events::types::{self, Sequence},
- repo::{
- channel::{self, Channel},
- login::{self, Login},
- message::{self, Message},
- },
-};
-
-pub trait Provider {
- fn message_events(&mut self) -> Events;
-}
-
-impl<'c> Provider for Transaction<'c, Sqlite> {
- fn message_events(&mut self) -> Events {
- Events(self)
- }
-}
-
-pub struct Events<'t>(&'t mut SqliteConnection);
-
-impl<'c> Events<'c> {
- pub async fn create(
- &mut self,
- sender: &Login,
- channel: &Channel,
- body: &str,
- sent_at: &DateTime,
- ) -> Result<types::ChannelEvent, sqlx::Error> {
- let sequence = self.assign_sequence(channel).await?;
-
- let id = message::Id::generate();
-
- let message = sqlx::query!(
- r#"
- insert into message
- (id, channel, sequence, sender, body, sent_at)
- values ($1, $2, $3, $4, $5, $6)
- returning
- id as "id: message::Id",
- sequence as "sequence: Sequence",
- sender as "sender: login::Id",
- body,
- sent_at as "sent_at: DateTime"
- "#,
- id,
- channel.id,
- sequence,
- sender.id,
- body,
- sent_at,
- )
- .map(|row| types::ChannelEvent {
- sequence: row.sequence,
- at: row.sent_at,
- data: types::MessageEvent {
- channel: channel.clone(),
- sender: sender.clone(),
- message: Message {
- id: row.id,
- body: row.body,
- },
- }
- .into(),
- })
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(message)
- }
-
- pub async fn assign_sequence(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> {
- let next = sqlx::query_scalar!(
- r#"
- update channel
- set last_sequence = last_sequence + 1
- where id = $1
- returning last_sequence as "next_sequence: Sequence"
- "#,
- channel.id,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(next)
- }
-
- pub async fn delete_expired(
- &mut self,
- channel: &Channel,
- message: &message::Id,
- sequence: Sequence,
- deleted_at: &DateTime,
- ) -> Result<types::ChannelEvent, sqlx::Error> {
- sqlx::query_scalar!(
- r#"
- delete from message
- where id = $1
- returning 1 as "row: i64"
- "#,
- message,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(types::ChannelEvent {
- sequence,
- at: *deleted_at,
- data: types::MessageDeletedEvent {
- channel: channel.clone(),
- message: message.clone(),
- }
- .into(),
- })
- }
-
- pub async fn expired(
- &mut self,
- expire_at: &DateTime,
- ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> {
- let messages = sqlx::query!(
- r#"
- select
- channel.id as "channel_id: channel::Id",
- channel.name as "channel_name",
- channel.created_at as "channel_created_at: DateTime",
- message.id as "message: message::Id"
- from message
- join channel on message.channel = channel.id
- join login as sender on message.sender = sender.id
- where sent_at < $1
- "#,
- expire_at,
- )
- .map(|row| {
- (
- Channel {
- id: row.channel_id,
- name: row.channel_name,
- created_at: row.channel_created_at,
- },
- row.message,
- )
- })
- .fetch_all(&mut *self.0)
- .await?;
-
- Ok(messages)
- }
-
- pub async fn replay(
- &mut self,
- channel: &Channel,
- resume_at: Option<Sequence>,
- ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> {
- let events = sqlx::query!(
- r#"
- select
- message.id as "id: message::Id",
- sequence as "sequence: Sequence",
- login.id as "sender_id: login::Id",
- login.name as sender_name,
- message.body,
- message.sent_at as "sent_at: DateTime"
- from message
- join login on message.sender = login.id
- where channel = $1
- and coalesce(sequence > $2, true)
- order by sequence asc
- "#,
- channel.id,
- resume_at,
- )
- .map(|row| types::ChannelEvent {
- sequence: row.sequence,
- at: row.sent_at,
- data: types::MessageEvent {
- channel: channel.clone(),
- sender: login::Login {
- id: row.sender_id,
- name: row.sender_name,
- },
- message: Message {
- id: row.id,
- body: row.body,
- },
- }
- .into(),
- })
- .fetch_all(&mut *self.0)
- .await?;
-
- Ok(events)
- }
-}
diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs
deleted file mode 100644
index e216a50..0000000
--- a/src/events/repo/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod message;
diff --git a/src/events/routes.rs b/src/events/routes.rs
deleted file mode 100644
index ec9dae2..0000000
--- a/src/events/routes.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-use axum::{
- extract::State,
- response::{
- sse::{self, Sse},
- IntoResponse, Response,
- },
- routing::get,
- Router,
-};
-use futures::stream::{Stream, StreamExt as _};
-
-use super::{
- extract::LastEventId,
- types::{self, ResumePoint},
-};
-use crate::{app::App, error::Internal, login::extract::Identity};
-
-#[cfg(test)]
-mod test;
-
-pub fn router() -> Router<App> {
- Router::new().route("/api/events", get(events))
-}
-
-async fn events(
- State(app): State<App>,
- identity: Identity,
- last_event_id: Option<LastEventId<ResumePoint>>,
-) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
- let resume_at = last_event_id
- .map(LastEventId::into_inner)
- .unwrap_or_default();
-
- let stream = app.events().subscribe(resume_at).await?;
- let stream = app.logins().limit_stream(identity.token, stream);
-
- Ok(Events(stream))
-}
-
-#[derive(Debug)]
-struct Events<S>(S);
-
-impl<S> IntoResponse for Events<S>
-where
- S: Stream<Item = types::ResumableEvent> + Send + 'static,
-{
- fn into_response(self) -> Response {
- let Self(stream) = self;
- let stream = stream.map(sse::Event::try_from);
- Sse::new(stream)
- .keep_alive(sse::KeepAlive::default())
- .into_response()
- }
-}
-
-impl TryFrom<types::ResumableEvent> for sse::Event {
- type Error = serde_json::Error;
-
- fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> {
- let types::ResumableEvent(resume_at, data) = value;
-
- let id = serde_json::to_string(&resume_at)?;
- let data = serde_json::to_string_pretty(&data)?;
-
- let event = Self::default().id(id).data(data);
-
- Ok(event)
- }
-}
diff --git a/src/events/types.rs b/src/events/types.rs
deleted file mode 100644
index d954512..0000000
--- a/src/events/types.rs
+++ /dev/null
@@ -1,170 +0,0 @@
-use std::collections::BTreeMap;
-
-use crate::{
- clock::DateTime,
- repo::{
- channel::{self, Channel},
- login::Login,
- message,
- },
-};
-
-#[derive(
- Debug,
- Default,
- Eq,
- Ord,
- PartialEq,
- PartialOrd,
- Clone,
- Copy,
- serde::Serialize,
- serde::Deserialize,
- sqlx::Type,
-)]
-#[serde(transparent)]
-#[sqlx(transparent)]
-pub struct Sequence(i64);
-
-impl Sequence {
- pub fn next(self) -> Self {
- let Self(current) = self;
- Self(current + 1)
- }
-}
-
-// For the purposes of event replay, a resume point is a vector of resume
-// elements. A resume element associates a channel (by ID) with the latest event
-// seen in that channel so far. Replaying the event stream can restart at a
-// predictable point - hence the name. These values can be serialized and sent
-// to the client as JSON dicts, then rehydrated to recover the resume point at a
-// later time.
-//
-// Using a sorted map ensures that there is a canonical representation for
-// each resume point.
-#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)]
-#[serde(transparent)]
-pub struct ResumePoint(BTreeMap<channel::Id, Sequence>);
-
-impl ResumePoint {
- pub fn advance<'e>(&mut self, event: impl Into<ResumeElement<'e>>) {
- let Self(elements) = self;
- let ResumeElement(channel, sequence) = event.into();
- elements.insert(channel.clone(), sequence);
- }
-
- pub fn forget<'e>(&mut self, event: impl Into<ResumeElement<'e>>) {
- let Self(elements) = self;
- let ResumeElement(channel, _) = event.into();
- elements.remove(channel);
- }
-
- pub fn get(&self, channel: &channel::Id) -> Option<Sequence> {
- let Self(elements) = self;
- elements.get(channel).copied()
- }
-
- pub fn not_after<'e>(&self, event: impl Into<ResumeElement<'e>>) -> bool {
- let Self(elements) = self;
- let ResumeElement(channel, sequence) = event.into();
-
- elements
- .get(channel)
- .map_or(true, |resume_at| resume_at < &sequence)
- }
-}
-
-pub struct ResumeElement<'i>(&'i channel::Id, Sequence);
-
-#[derive(Clone, Debug)]
-pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent);
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct ChannelEvent {
- #[serde(skip)]
- pub sequence: Sequence,
- pub at: DateTime,
- #[serde(flatten)]
- pub data: ChannelEventData,
-}
-
-impl ChannelEvent {
- pub fn created(channel: Channel) -> Self {
- Self {
- at: channel.created_at,
- sequence: Sequence::default(),
- data: CreatedEvent { channel }.into(),
- }
- }
-
- pub fn channel_id(&self) -> &channel::Id {
- match &self.data {
- ChannelEventData::Created(event) => &event.channel.id,
- ChannelEventData::Message(event) => &event.channel.id,
- ChannelEventData::MessageDeleted(event) => &event.channel.id,
- ChannelEventData::Deleted(event) => &event.channel,
- }
- }
-}
-
-impl<'c> From<&'c ChannelEvent> for ResumeElement<'c> {
- fn from(event: &'c ChannelEvent) -> Self {
- Self(event.channel_id(), event.sequence)
- }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ChannelEventData {
- Created(CreatedEvent),
- Message(MessageEvent),
- MessageDeleted(MessageDeletedEvent),
- Deleted(DeletedEvent),
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct CreatedEvent {
- pub channel: Channel,
-}
-
-impl From<CreatedEvent> for ChannelEventData {
- fn from(event: CreatedEvent) -> Self {
- Self::Created(event)
- }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct MessageEvent {
- pub channel: Channel,
- pub sender: Login,
- pub message: message::Message,
-}
-
-impl From<MessageEvent> for ChannelEventData {
- fn from(event: MessageEvent) -> Self {
- Self::Message(event)
- }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct MessageDeletedEvent {
- pub channel: Channel,
- pub message: message::Id,
-}
-
-impl From<MessageDeletedEvent> for ChannelEventData {
- fn from(event: MessageDeletedEvent) -> Self {
- Self::MessageDeleted(event)
- }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct DeletedEvent {
- pub channel: channel::Id,
-}
-
-impl From<DeletedEvent> for ChannelEventData {
- fn from(event: DeletedEvent) -> Self {
- Self::Deleted(event)
- }
-}
diff --git a/src/expire.rs b/src/expire.rs
index 16006d1..e50bcb4 100644
--- a/src/expire.rs
+++ b/src/expire.rs
@@ -13,8 +13,8 @@ pub async fn middleware(
req: Request,
next: Next,
) -> Result<Response, Internal> {
- app.logins().expire(&expired_at).await?;
- app.events().expire(&expired_at).await?;
+ app.tokens().expire(&expired_at).await?;
+ app.messages().expire(&expired_at).await?;
app.channels().expire(&expired_at).await?;
Ok(next.run(req).await)
}
diff --git a/src/lib.rs b/src/lib.rs
index 271118b..8ec13da 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -7,12 +7,13 @@ mod broadcast;
mod channel;
pub mod cli;
mod clock;
+mod db;
mod error;
-mod events;
+mod event;
mod expire;
mod id;
mod login;
-mod password;
-mod repo;
+mod message;
#[cfg(test)]
mod test;
+mod token;
diff --git a/src/login/app.rs b/src/login/app.rs
index 182c62c..15adb31 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -1,57 +1,25 @@
-use chrono::TimeDelta;
-use futures::{
- future,
- stream::{self, StreamExt as _},
- Stream,
-};
use sqlx::sqlite::SqlitePool;
-use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types};
-use crate::{
- clock::DateTime,
- password::Password,
- repo::{
- error::NotFound as _,
- login::{Login, Provider as _},
- token::{self, Provider as _},
- },
-};
+use crate::event::{repo::Provider as _, Sequence};
+
+#[cfg(test)]
+use super::{repo::Provider as _, Login, Password};
pub struct Logins<'a> {
db: &'a SqlitePool,
- logins: &'a Broadcaster,
}
impl<'a> Logins<'a> {
- pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self {
- Self { db, logins }
+ pub const fn new(db: &'a SqlitePool) -> Self {
+ Self { db }
}
- pub async fn login(
- &self,
- name: &str,
- password: &Password,
- login_at: &DateTime,
- ) -> Result<IdentitySecret, LoginError> {
+ pub async fn boot_point(&self) -> Result<Sequence, sqlx::Error> {
let mut tx = self.db.begin().await?;
-
- let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? {
- if stored_hash.verify(password)? {
- // Password verified; use the login.
- login
- } else {
- // Password NOT verified.
- return Err(LoginError::Rejected);
- }
- } else {
- let password_hash = password.hash()?;
- tx.logins().create(name, &password_hash).await?
- };
-
- let token = tx.tokens().issue(&login, login_at).await?;
+ let sequence = tx.sequence().current().await?;
tx.commit().await?;
- Ok(token)
+ Ok(sequence)
}
#[cfg(test)]
@@ -64,82 +32,6 @@ impl<'a> Logins<'a> {
Ok(login)
}
-
- pub async fn validate(
- &self,
- secret: &IdentitySecret,
- used_at: &DateTime,
- ) -> Result<(token::Id, Login), ValidateError> {
- let mut tx = self.db.begin().await?;
- let login = tx
- .tokens()
- .validate(secret, used_at)
- .await
- .not_found(|| ValidateError::InvalidToken)?;
- tx.commit().await?;
-
- Ok(login)
- }
-
- pub fn limit_stream<E>(
- &self,
- token: token::Id,
- events: impl Stream<Item = E> + std::fmt::Debug,
- ) -> impl Stream<Item = E> + std::fmt::Debug
- where
- E: std::fmt::Debug,
- {
- let token_events = self
- .logins
- .subscribe()
- .filter(move |event| future::ready(event.token == token))
- .map(|_| GuardedEvent::TokenRevoked);
-
- let events = events.map(|event| GuardedEvent::Event(event));
-
- stream::select(token_events, events).scan((), |(), event| {
- future::ready(match event {
- GuardedEvent::Event(event) => Some(event),
- GuardedEvent::TokenRevoked => None,
- })
- })
- }
-
- pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
- // Somewhat arbitrarily, expire after 7 days.
- let expire_at = relative_to.to_owned() - TimeDelta::days(7);
-
- let mut tx = self.db.begin().await?;
- let tokens = tx.tokens().expire(&expire_at).await?;
- tx.commit().await?;
-
- for event in tokens.into_iter().map(types::TokenRevoked::from) {
- self.logins.broadcast(&event);
- }
-
- Ok(())
- }
-
- pub async fn logout(&self, token: &token::Id) -> Result<(), ValidateError> {
- let mut tx = self.db.begin().await?;
- tx.tokens().revoke(token).await?;
- tx.commit().await?;
-
- self.logins
- .broadcast(&types::TokenRevoked::from(token.clone()));
-
- Ok(())
- }
-}
-
-#[derive(Debug, thiserror::Error)]
-pub enum LoginError {
- #[error("invalid login")]
- Rejected,
- #[error(transparent)]
- DatabaseError(#[from] sqlx::Error),
- #[error(transparent)]
- PasswordHashError(#[from] password_hash::Error),
}
#[cfg(test)]
@@ -149,17 +41,3 @@ pub enum CreateError {
DatabaseError(#[from] sqlx::Error),
PasswordHashError(#[from] password_hash::Error),
}
-
-#[derive(Debug, thiserror::Error)]
-pub enum ValidateError {
- #[error("invalid token")]
- InvalidToken,
- #[error(transparent)]
- DatabaseError(#[from] sqlx::Error),
-}
-
-#[derive(Debug)]
-enum GuardedEvent<E> {
- TokenRevoked,
- Event(E),
-}
diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs
deleted file mode 100644
index 8e1fb3a..0000000
--- a/src/login/broadcaster.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-use crate::{broadcast, login::types};
-
-pub type Broadcaster = broadcast::Broadcaster<types::TokenRevoked>;
diff --git a/src/login/extract.rs b/src/login/extract.rs
index b585565..c2d97f2 100644
--- a/src/login/extract.rs
+++ b/src/login/extract.rs
@@ -1,182 +1,15 @@
-use std::fmt;
+use axum::{extract::FromRequestParts, http::request::Parts};
-use axum::{
- extract::{FromRequestParts, State},
- http::{request::Parts, StatusCode},
- response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
-};
-use axum_extra::extract::cookie::{Cookie, CookieJar};
-
-use crate::{
- app::App,
- clock::RequestedAt,
- error::Internal,
- login::app::ValidateError,
- repo::{login::Login, token},
-};
-
-// The usage pattern here - receive the extractor as an argument, return it in
-// the response - is heavily modelled after CookieJar's own intended usage.
-#[derive(Clone)]
-pub struct IdentityToken {
- cookies: CookieJar,
-}
-
-impl fmt::Debug for IdentityToken {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("IdentityToken")
- .field(
- "identity",
- &self.cookies.get(IDENTITY_COOKIE).map(|_| "********"),
- )
- .finish()
- }
-}
-
-impl IdentityToken {
- // Creates a new, unpopulated identity token store.
- #[cfg(test)]
- pub fn new() -> Self {
- Self {
- cookies: CookieJar::new(),
- }
- }
-
- // Get the identity secret sent in the request, if any. If the identity
- // was not sent, or if it has previously been [clear]ed, then this will
- // return [None]. If the identity has previously been [set], then this
- // will return that secret, regardless of what the request originally
- // included.
- pub fn secret(&self) -> Option<IdentitySecret> {
- self.cookies
- .get(IDENTITY_COOKIE)
- .map(Cookie::value)
- .map(IdentitySecret::from)
- }
-
- // Positively set the identity secret, and ensure that it will be sent
- // back to the client when this extractor is included in a response.
- pub fn set(self, secret: impl Into<IdentitySecret>) -> Self {
- let IdentitySecret(secret) = secret.into();
- let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret))
- .http_only(true)
- .path("/api/")
- .permanent()
- .build();
-
- Self {
- cookies: self.cookies.add(identity_cookie),
- }
- }
-
- // Remove the identity secret and ensure that it will be cleared when this
- // extractor is included in a response.
- pub fn clear(self) -> Self {
- Self {
- cookies: self.cookies.remove(IDENTITY_COOKIE),
- }
- }
-}
-
-const IDENTITY_COOKIE: &str = "identity";
-
-#[async_trait::async_trait]
-impl<S> FromRequestParts<S> for IdentityToken
-where
- S: Send + Sync,
-{
- type Rejection = <CookieJar as FromRequestParts<S>>::Rejection;
-
- async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
- let cookies = CookieJar::from_request_parts(parts, state).await?;
- Ok(Self { cookies })
- }
-}
-
-impl IntoResponseParts for IdentityToken {
- type Error = <CookieJar as IntoResponseParts>::Error;
-
- fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
- let Self { cookies } = self;
- cookies.into_response_parts(res)
- }
-}
-
-#[derive(sqlx::Type)]
-#[sqlx(transparent)]
-pub struct IdentitySecret(String);
-
-impl fmt::Debug for IdentitySecret {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_tuple("IdentityToken").field(&"********").finish()
- }
-}
-
-impl<S> From<S> for IdentitySecret
-where
- S: Into<String>,
-{
- fn from(value: S) -> Self {
- Self(value.into())
- }
-}
-
-#[derive(Clone, Debug)]
-pub struct Identity {
- pub token: token::Id,
- pub login: Login,
-}
+use super::Login;
+use crate::{app::App, token::extract::Identity};
#[async_trait::async_trait]
-impl FromRequestParts<App> for Identity {
- type Rejection = LoginError<Internal>;
+impl FromRequestParts<App> for Login {
+ type Rejection = <Identity as FromRequestParts<App>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
- // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
- // stable), the following can be replaced:
- //
- // ```
- // let Ok(identity_token) = IdentityToken::from_request_parts(
- // parts,
- // state,
- // ).await;
- // ```
- let identity_token = IdentityToken::from_request_parts(parts, state).await?;
- let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
-
- let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
-
- let app = State::<App>::from_request_parts(parts, state).await?;
- match app.logins().validate(&secret, &used_at).await {
- Ok((token, login)) => Ok(Identity { token, login }),
- Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
- Err(other) => Err(other.into()),
- }
- }
-}
-
-pub enum LoginError<E> {
- Failure(E),
- Unauthorized,
-}
-
-impl<E> IntoResponse for LoginError<E>
-where
- E: IntoResponse,
-{
- fn into_response(self) -> Response {
- match self {
- Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
- Self::Failure(e) => e.into_response(),
- }
- }
-}
+ let identity = Identity::from_request_parts(parts, state).await?;
-impl<E> From<E> for LoginError<Internal>
-where
- E: Into<Internal>,
-{
- fn from(err: E) -> Self {
- Self::Failure(err.into())
+ Ok(identity.login)
}
}
diff --git a/src/login/id.rs b/src/login/id.rs
new file mode 100644
index 0000000..c46d697
--- /dev/null
+++ b/src/login/id.rs
@@ -0,0 +1,24 @@
+use crate::id::Id as BaseId;
+
+// Stable identifier for a [Login]. Prefixed with `L`.
+#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)]
+#[sqlx(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("L")
+ }
+}
+
+impl std::fmt::Display for Id {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/login/mod.rs b/src/login/mod.rs
index 6ae82ac..65e3ada 100644
--- a/src/login/mod.rs
+++ b/src/login/mod.rs
@@ -1,8 +1,21 @@
-pub use self::routes::router;
-
pub mod app;
-pub mod broadcaster;
pub mod extract;
-mod repo;
+mod id;
+pub mod password;
+pub mod repo;
mod routes;
-pub mod types;
+
+pub use self::{id::Id, password::Password, routes::router};
+
+// This also implements FromRequestParts (see `./extract.rs`). As a result, it
+// can be used as an extractor for endpoints that want to require login, or for
+// endpoints that need to behave differently depending on whether the client is
+// or is not logged in.
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Login {
+ pub id: Id,
+ pub name: String,
+ // The omission of the hashed password is deliberate, to minimize the
+ // chance that it ends up tangled up in debug output or in some other chunk
+ // of logic elsewhere.
+}
diff --git a/src/password.rs b/src/login/password.rs
index da3930f..da3930f 100644
--- a/src/password.rs
+++ b/src/login/password.rs
diff --git a/src/login/repo.rs b/src/login/repo.rs
new file mode 100644
index 0000000..d1a02c4
--- /dev/null
+++ b/src/login/repo.rs
@@ -0,0 +1,50 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use crate::login::{password::StoredHash, Id, Login};
+
+pub trait Provider {
+ fn logins(&mut self) -> Logins;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn logins(&mut self) -> Logins {
+ Logins(self)
+ }
+}
+
+pub struct Logins<'t>(&'t mut SqliteConnection);
+
+impl<'c> Logins<'c> {
+ pub async fn create(
+ &mut self,
+ name: &str,
+ password_hash: &StoredHash,
+ ) -> Result<Login, sqlx::Error> {
+ let id = Id::generate();
+
+ let login = sqlx::query_as!(
+ Login,
+ r#"
+ insert or fail
+ into login (id, name, password_hash)
+ values ($1, $2, $3)
+ returning
+ id as "id: Id",
+ name
+ "#,
+ id,
+ name,
+ password_hash,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(login)
+ }
+}
+
+impl<'t> From<&'t mut SqliteConnection> for Logins<'t> {
+ fn from(tx: &'t mut SqliteConnection) -> Self {
+ Self(tx)
+ }
+}
diff --git a/src/login/repo/mod.rs b/src/login/repo/mod.rs
deleted file mode 100644
index 0e4a05d..0000000
--- a/src/login/repo/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod auth;
diff --git a/src/login/routes.rs b/src/login/routes.rs
index 8d9e938..0874cc3 100644
--- a/src/login/routes.rs
+++ b/src/login/routes.rs
@@ -7,11 +7,13 @@ use axum::{
};
use crate::{
- app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login,
+ app::App,
+ clock::RequestedAt,
+ error::{Internal, Unauthorized},
+ login::{Login, Password},
+ token::{app, extract::IdentityToken},
};
-use super::{app, extract::IdentityToken};
-
#[cfg(test)]
mod test;
@@ -22,13 +24,18 @@ pub fn router() -> Router<App> {
.route("/api/auth/logout", post(on_logout))
}
-async fn boot(login: Login) -> Boot {
- Boot { login }
+async fn boot(State(app): State<App>, login: Login) -> Result<Boot, Internal> {
+ let resume_point = app.logins().boot_point().await?;
+ Ok(Boot {
+ login,
+ resume_point: resume_point.to_string(),
+ })
}
#[derive(serde::Serialize)]
struct Boot {
login: Login,
+ resume_point: String,
}
impl IntoResponse for Boot {
@@ -50,7 +57,7 @@ async fn on_login(
Json(request): Json<LoginRequest>,
) -> Result<(IdentityToken, StatusCode), LoginError> {
let token = app
- .logins()
+ .tokens()
.login(&request.name, &request.password, &now)
.await
.map_err(LoginError)?;
@@ -66,6 +73,7 @@ impl IntoResponse for LoginError {
let Self(error) = self;
match error {
app::LoginError::Rejected => {
+ // not error::Unauthorized due to differing messaging
(StatusCode::UNAUTHORIZED, "invalid name or password").into_response()
}
other => Internal::from(other).into_response(),
@@ -85,8 +93,8 @@ async fn on_logout(
Json(LogoutRequest {}): Json<LogoutRequest>,
) -> Result<(IdentityToken, StatusCode), LogoutError> {
if let Some(secret) = identity.secret() {
- let (token, _) = app.logins().validate(&secret, &now).await?;
- app.logins().logout(&token).await?;
+ let (token, _) = app.tokens().validate(&secret, &now).await?;
+ app.tokens().logout(&token).await?;
}
let identity = identity.clear();
@@ -103,9 +111,7 @@ enum LogoutError {
impl IntoResponse for LogoutError {
fn into_response(self) -> Response {
match self {
- error @ Self::ValidateError(app::ValidateError::InvalidToken) => {
- (StatusCode::UNAUTHORIZED, error.to_string()).into_response()
- }
+ Self::ValidateError(app::ValidateError::InvalidToken) => Unauthorized.into_response(),
other => Internal::from(other).into_response(),
}
}
diff --git a/src/login/routes/test/boot.rs b/src/login/routes/test/boot.rs
index dee554f..9655354 100644
--- a/src/login/routes/test/boot.rs
+++ b/src/login/routes/test/boot.rs
@@ -1,9 +1,14 @@
+use axum::extract::State;
+
use crate::{login::routes, test::fixtures};
#[tokio::test]
async fn returns_identity() {
+ let app = fixtures::scratch_app().await;
let login = fixtures::login::fictitious();
- let response = routes::boot(login.clone()).await;
+ let response = routes::boot(State(app), login.clone())
+ .await
+ .expect("boot always succeeds");
assert_eq!(login, response.login);
}
diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs
index 81653ff..3c82738 100644
--- a/src/login/routes/test/login.rs
+++ b/src/login/routes/test/login.rs
@@ -3,10 +3,7 @@ use axum::{
http::StatusCode,
};
-use crate::{
- login::{app, routes},
- test::fixtures,
-};
+use crate::{login::routes, test::fixtures, token::app};
#[tokio::test]
async fn new_identity() {
@@ -37,7 +34,7 @@ async fn new_identity() {
let validated_at = fixtures::now();
let (_, validated) = app
- .logins()
+ .tokens()
.validate(&secret, &validated_at)
.await
.expect("identity secret is valid");
@@ -74,7 +71,7 @@ async fn existing_identity() {
let validated_at = fixtures::now();
let (_, validated_login) = app
- .logins()
+ .tokens()
.validate(&secret, &validated_at)
.await
.expect("identity secret is valid");
@@ -127,14 +124,14 @@ async fn token_expires() {
// Verify the semantics
let expired_at = fixtures::now();
- app.logins()
+ app.tokens()
.expire(&expired_at)
.await
.expect("expiring tokens never fails");
let verified_at = fixtures::now();
let error = app
- .logins()
+ .tokens()
.validate(&secret, &verified_at)
.await
.expect_err("validating an expired token");
diff --git a/src/login/routes/test/logout.rs b/src/login/routes/test/logout.rs
index 20b0d55..42b2534 100644
--- a/src/login/routes/test/logout.rs
+++ b/src/login/routes/test/logout.rs
@@ -3,10 +3,7 @@ use axum::{
http::StatusCode,
};
-use crate::{
- login::{app, routes},
- test::fixtures,
-};
+use crate::{login::routes, test::fixtures, token::app};
#[tokio::test]
async fn successful() {
@@ -37,7 +34,7 @@ async fn successful() {
// Verify the semantics
let error = app
- .logins()
+ .tokens()
.validate(&secret, &now)
.await
.expect_err("secret is invalid");
diff --git a/src/message/app.rs b/src/message/app.rs
new file mode 100644
index 0000000..33ea8ad
--- /dev/null
+++ b/src/message/app.rs
@@ -0,0 +1,115 @@
+use chrono::TimeDelta;
+use itertools::Itertools;
+use sqlx::sqlite::SqlitePool;
+
+use super::{repo::Provider as _, Id, Message};
+use crate::{
+ channel::{self, repo::Provider as _},
+ clock::DateTime,
+ db::NotFound as _,
+ event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence},
+ login::Login,
+};
+
+pub struct Messages<'a> {
+ db: &'a SqlitePool,
+ events: &'a Broadcaster,
+}
+
+impl<'a> Messages<'a> {
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
+ }
+
+ pub async fn send(
+ &self,
+ channel: &channel::Id,
+ sender: &Login,
+ sent_at: &DateTime,
+ body: &str,
+ ) -> Result<Message, SendError> {
+ let mut tx = self.db.begin().await?;
+ let channel = tx
+ .channels()
+ .by_id(channel)
+ .await
+ .not_found(|| SendError::ChannelNotFound(channel.clone()))?;
+ let sent = tx.sequence().next(sent_at).await?;
+ let message = tx
+ .messages()
+ .create(&channel.snapshot(), sender, &sent, body)
+ .await?;
+ tx.commit().await?;
+
+ self.events
+ .broadcast(message.events().map(Event::from).collect::<Vec<_>>());
+
+ Ok(message.snapshot())
+ }
+
+ pub async fn delete(&self, message: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> {
+ let mut tx = self.db.begin().await?;
+ let deleted = tx.sequence().next(deleted_at).await?;
+ let message = tx.messages().delete(message, &deleted).await?;
+ tx.commit().await?;
+
+ self.events.broadcast(
+ message
+ .events()
+ .filter(Sequence::start_from(deleted.sequence))
+ .map(Event::from)
+ .collect::<Vec<_>>(),
+ );
+
+ Ok(())
+ }
+
+ pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
+ // Somewhat arbitrarily, expire after 90 days.
+ let expire_at = relative_to.to_owned() - TimeDelta::days(90);
+
+ let mut tx = self.db.begin().await?;
+
+ let expired = tx.messages().expired(&expire_at).await?;
+ let mut events = Vec::with_capacity(expired.len());
+ for message in expired {
+ let deleted = tx.sequence().next(relative_to).await?;
+ let message = tx.messages().delete(&message, &deleted).await?;
+ events.push(
+ message
+ .events()
+ .filter(Sequence::start_from(deleted.sequence)),
+ );
+ }
+
+ tx.commit().await?;
+
+ self.events.broadcast(
+ events
+ .into_iter()
+ .kmerge_by(Sequence::merge)
+ .map(Event::from)
+ .collect::<Vec<_>>(),
+ );
+
+ Ok(())
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum SendError {
+ #[error("channel {0} not found")]
+ ChannelNotFound(channel::Id),
+ #[error(transparent)]
+ DatabaseError(#[from] sqlx::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum DeleteError {
+ #[error("channel {0} not found")]
+ ChannelNotFound(channel::Id),
+ #[error("message {0} not found")]
+ NotFound(Id),
+ #[error(transparent)]
+ DatabaseError(#[from] sqlx::Error),
+}
diff --git a/src/message/event.rs b/src/message/event.rs
new file mode 100644
index 0000000..66db9b0
--- /dev/null
+++ b/src/message/event.rs
@@ -0,0 +1,71 @@
+use super::{snapshot::Message, Id};
+use crate::{
+ channel::Channel,
+ event::{Instant, Sequenced},
+};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Event {
+ #[serde(flatten)]
+ pub kind: Kind,
+}
+
+impl Sequenced for Event {
+ fn instant(&self) -> Instant {
+ self.kind.instant()
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Kind {
+ Sent(Sent),
+ Deleted(Deleted),
+}
+
+impl Sequenced for Kind {
+ fn instant(&self) -> Instant {
+ match self {
+ Self::Sent(sent) => sent.instant(),
+ Self::Deleted(deleted) => deleted.instant(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Sent {
+ #[serde(flatten)]
+ pub message: Message,
+}
+
+impl Sequenced for Sent {
+ fn instant(&self) -> Instant {
+ self.message.sent
+ }
+}
+
+impl From<Sent> for Kind {
+ fn from(event: Sent) -> Self {
+ Self::Sent(event)
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Deleted {
+ #[serde(flatten)]
+ pub instant: Instant,
+ pub channel: Channel,
+ pub message: Id,
+}
+
+impl Sequenced for Deleted {
+ fn instant(&self) -> Instant {
+ self.instant
+ }
+}
+
+impl From<Deleted> for Kind {
+ fn from(event: Deleted) -> Self {
+ Self::Deleted(event)
+ }
+}
diff --git a/src/message/history.rs b/src/message/history.rs
new file mode 100644
index 0000000..89fc6b1
--- /dev/null
+++ b/src/message/history.rs
@@ -0,0 +1,41 @@
+use super::{
+ event::{Deleted, Event, Sent},
+ Message,
+};
+use crate::event::Instant;
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct History {
+ pub message: Message,
+ pub deleted: Option<Instant>,
+}
+
+impl History {
+ fn sent(&self) -> Event {
+ Event {
+ kind: Sent {
+ message: self.message.clone(),
+ }
+ .into(),
+ }
+ }
+
+ fn deleted(&self) -> Option<Event> {
+ self.deleted.map(|instant| Event {
+ kind: Deleted {
+ instant,
+ channel: self.message.channel.clone(),
+ message: self.message.id.clone(),
+ }
+ .into(),
+ })
+ }
+
+ pub fn events(&self) -> impl Iterator<Item = Event> {
+ [self.sent()].into_iter().chain(self.deleted())
+ }
+
+ pub fn snapshot(&self) -> Message {
+ self.message.clone()
+ }
+}
diff --git a/src/repo/message.rs b/src/message/id.rs
index a1f73d5..385b103 100644
--- a/src/repo/message.rs
+++ b/src/message/id.rs
@@ -25,9 +25,3 @@ impl fmt::Display for Id {
self.0.fmt(f)
}
}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Message {
- pub id: Id,
- pub body: String,
-}
diff --git a/src/message/mod.rs b/src/message/mod.rs
new file mode 100644
index 0000000..a8f51ab
--- /dev/null
+++ b/src/message/mod.rs
@@ -0,0 +1,9 @@
+pub mod app;
+pub mod event;
+mod history;
+mod id;
+pub mod repo;
+mod routes;
+mod snapshot;
+
+pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Message};
diff --git a/src/message/repo.rs b/src/message/repo.rs
new file mode 100644
index 0000000..fc835c8
--- /dev/null
+++ b/src/message/repo.rs
@@ -0,0 +1,247 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use super::{snapshot::Message, History, Id};
+use crate::{
+ channel::{self, Channel},
+ clock::DateTime,
+ event::{Instant, Sequence},
+ login::{self, Login},
+};
+
+pub trait Provider {
+ fn messages(&mut self) -> Messages;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+ fn messages(&mut self) -> Messages {
+ Messages(self)
+ }
+}
+
+pub struct Messages<'t>(&'t mut SqliteConnection);
+
+impl<'c> Messages<'c> {
+ pub async fn create(
+ &mut self,
+ channel: &Channel,
+ sender: &Login,
+ sent: &Instant,
+ body: &str,
+ ) -> Result<History, sqlx::Error> {
+ let id = Id::generate();
+
+ let message = sqlx::query!(
+ r#"
+ insert into message
+ (id, channel, sender, sent_at, sent_sequence, body)
+ values ($1, $2, $3, $4, $5, $6)
+ returning
+ id as "id: Id",
+ body
+ "#,
+ id,
+ channel.id,
+ sender.id,
+ sent.at,
+ sent.sequence,
+ body,
+ )
+ .map(|row| History {
+ message: Message {
+ sent: *sent,
+ channel: channel.clone(),
+ sender: sender.clone(),
+ id: row.id,
+ body: row.body,
+ },
+ deleted: None,
+ })
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(message)
+ }
+
+ pub async fn in_channel(
+ &mut self,
+ channel: &Channel,
+ resume_at: Option<Sequence>,
+ ) -> Result<Vec<History>, sqlx::Error> {
+ let messages = sqlx::query!(
+ r#"
+ select
+ channel.id as "channel_id: channel::Id",
+ channel.name as "channel_name",
+ sender.id as "sender_id: login::Id",
+ sender.name as "sender_name",
+ message.id as "id: Id",
+ message.body,
+ sent_at as "sent_at: DateTime",
+ sent_sequence as "sent_sequence: Sequence"
+ from message
+ join channel on message.channel = channel.id
+ join login as sender on message.sender = sender.id
+ where channel.id = $1
+ and coalesce(message.sent_sequence <= $2, true)
+ order by message.sent_sequence
+ "#,
+ channel.id,
+ resume_at,
+ )
+ .map(|row| History {
+ message: Message {
+ sent: Instant {
+ at: row.sent_at,
+ sequence: row.sent_sequence,
+ },
+ channel: Channel {
+ id: row.channel_id,
+ name: row.channel_name,
+ },
+ sender: Login {
+ id: row.sender_id,
+ name: row.sender_name,
+ },
+ id: row.id,
+ body: row.body,
+ },
+ deleted: None,
+ })
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(messages)
+ }
+
+ async fn by_id(&mut self, message: &Id) -> Result<History, sqlx::Error> {
+ let message = sqlx::query!(
+ r#"
+ select
+ channel.id as "channel_id: channel::Id",
+ channel.name as "channel_name",
+ sender.id as "sender_id: login::Id",
+ sender.name as "sender_name",
+ message.id as "id: Id",
+ message.body,
+ sent_at as "sent_at: DateTime",
+ sent_sequence as "sent_sequence: Sequence"
+ from message
+ join channel on message.channel = channel.id
+ join login as sender on message.sender = sender.id
+ where message.id = $1
+ "#,
+ message,
+ )
+ .map(|row| History {
+ message: Message {
+ sent: Instant {
+ at: row.sent_at,
+ sequence: row.sent_sequence,
+ },
+ channel: Channel {
+ id: row.channel_id,
+ name: row.channel_name,
+ },
+ sender: Login {
+ id: row.sender_id,
+ name: row.sender_name,
+ },
+ id: row.id,
+ body: row.body,
+ },
+ deleted: None,
+ })
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(message)
+ }
+
+ pub async fn delete(
+ &mut self,
+ message: &Id,
+ deleted: &Instant,
+ ) -> Result<History, sqlx::Error> {
+ let history = self.by_id(message).await?;
+
+ sqlx::query_scalar!(
+ r#"
+ delete from message
+ where
+ id = $1
+ returning 1 as "deleted: i64"
+ "#,
+ history.message.id,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(History {
+ deleted: Some(*deleted),
+ ..history
+ })
+ }
+
+ pub async fn expired(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ let messages = sqlx::query_scalar!(
+ r#"
+ select
+ id as "message: Id"
+ from message
+ where sent_at < $1
+ "#,
+ expire_at,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(messages)
+ }
+
+ pub async fn replay(
+ &mut self,
+ resume_at: Option<Sequence>,
+ ) -> Result<Vec<History>, sqlx::Error> {
+ let messages = sqlx::query!(
+ r#"
+ select
+ channel.id as "channel_id: channel::Id",
+ channel.name as "channel_name",
+ sender.id as "sender_id: login::Id",
+ sender.name as "sender_name",
+ message.id as "id: Id",
+ message.body,
+ sent_at as "sent_at: DateTime",
+ sent_sequence as "sent_sequence: Sequence"
+ from message
+ join channel on message.channel = channel.id
+ join login as sender on message.sender = sender.id
+ where coalesce(message.sent_sequence > $1, true)
+ "#,
+ resume_at,
+ )
+ .map(|row| History {
+ message: Message {
+ sent: Instant {
+ at: row.sent_at,
+ sequence: row.sent_sequence,
+ },
+ channel: Channel {
+ id: row.channel_id,
+ name: row.channel_name,
+ },
+ sender: Login {
+ id: row.sender_id,
+ name: row.sender_name,
+ },
+ id: row.id,
+ body: row.body,
+ },
+ deleted: None,
+ })
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ Ok(messages)
+ }
+}
diff --git a/src/message/routes.rs b/src/message/routes.rs
new file mode 100644
index 0000000..29fe3d7
--- /dev/null
+++ b/src/message/routes.rs
@@ -0,0 +1,46 @@
+use axum::{
+ extract::{Path, State},
+ http::StatusCode,
+ response::{IntoResponse, Response},
+ routing::delete,
+ Router,
+};
+
+use crate::{
+ app::App,
+ clock::RequestedAt,
+ error::Internal,
+ login::Login,
+ message::{self, app::DeleteError},
+};
+
+pub fn router() -> Router<App> {
+ Router::new().route("/api/messages/:message", delete(on_delete))
+}
+
+async fn on_delete(
+ State(app): State<App>,
+ Path(message): Path<message::Id>,
+ RequestedAt(deleted_at): RequestedAt,
+ _: Login,
+) -> Result<StatusCode, ErrorResponse> {
+ app.messages().delete(&message, &deleted_at).await?;
+
+ Ok(StatusCode::ACCEPTED)
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+struct ErrorResponse(#[from] DeleteError);
+
+impl IntoResponse for ErrorResponse {
+ fn into_response(self) -> Response {
+ let Self(error) = self;
+ match error {
+ not_found @ (DeleteError::ChannelNotFound(_) | DeleteError::NotFound(_)) => {
+ (StatusCode::NOT_FOUND, not_found.to_string()).into_response()
+ }
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs
new file mode 100644
index 0000000..522c1aa
--- /dev/null
+++ b/src/message/snapshot.rs
@@ -0,0 +1,76 @@
+use super::{
+ event::{Event, Kind, Sent},
+ Id,
+};
+use crate::{channel::Channel, event::Instant, login::Login};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(into = "self::serialize::Message")]
+pub struct Message {
+ #[serde(skip)]
+ pub sent: Instant,
+ pub channel: Channel,
+ pub sender: Login,
+ pub id: Id,
+ pub body: String,
+}
+
+mod serialize {
+ use crate::{channel::Channel, login::Login, message::Id};
+
+ #[derive(serde::Serialize)]
+ pub struct Message {
+ channel: Channel,
+ sender: Login,
+ #[allow(clippy::struct_field_names)]
+ // Deliberately redundant with the module path; this produces a specific serialization.
+ message: MessageData,
+ }
+
+ #[derive(serde::Serialize)]
+ pub struct MessageData {
+ id: Id,
+ body: String,
+ }
+
+ impl From<super::Message> for Message {
+ fn from(message: super::Message) -> Self {
+ Self {
+ channel: message.channel,
+ sender: message.sender,
+ message: MessageData {
+ id: message.id,
+ body: message.body,
+ },
+ }
+ }
+ }
+}
+
+impl Message {
+ fn apply(state: Option<Self>, event: Event) -> Option<Self> {
+ match (state, event.kind) {
+ (None, Kind::Sent(event)) => Some(event.into()),
+ (Some(message), Kind::Deleted(event)) if message.id == event.message => None,
+ (state, event) => panic!("invalid message event {event:#?} for state {state:#?}"),
+ }
+ }
+}
+
+impl FromIterator<Event> for Option<Message> {
+ fn from_iter<I: IntoIterator<Item = Event>>(events: I) -> Self {
+ events.into_iter().fold(None, Message::apply)
+ }
+}
+
+impl From<&Sent> for Message {
+ fn from(event: &Sent) -> Self {
+ event.message.clone()
+ }
+}
+
+impl From<Sent> for Message {
+ fn from(event: Sent) -> Self {
+ event.message
+ }
+}
diff --git a/src/repo/channel.rs b/src/repo/channel.rs
deleted file mode 100644
index 3c7468f..0000000
--- a/src/repo/channel.rs
+++ /dev/null
@@ -1,179 +0,0 @@
-use std::fmt;
-
-use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
-
-use crate::{
- clock::DateTime,
- events::types::{self, Sequence},
- id::Id as BaseId,
-};
-
-pub trait Provider {
- fn channels(&mut self) -> Channels;
-}
-
-impl<'c> Provider for Transaction<'c, Sqlite> {
- fn channels(&mut self) -> Channels {
- Channels(self)
- }
-}
-
-pub struct Channels<'t>(&'t mut SqliteConnection);
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Channel {
- pub id: Id,
- pub name: String,
- #[serde(skip)]
- pub created_at: DateTime,
-}
-
-impl<'c> Channels<'c> {
- pub async fn create(
- &mut self,
- name: &str,
- created_at: &DateTime,
- ) -> Result<Channel, sqlx::Error> {
- let id = Id::generate();
- let sequence = Sequence::default();
-
- let channel = sqlx::query_as!(
- Channel,
- r#"
- insert
- into channel (id, name, created_at, last_sequence)
- values ($1, $2, $3, $4)
- returning
- id as "id: Id",
- name,
- created_at as "created_at: DateTime"
- "#,
- id,
- name,
- created_at,
- sequence,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(channel)
- }
-
- pub async fn by_id(&mut self, channel: &Id) -> Result<Channel, sqlx::Error> {
- let channel = sqlx::query_as!(
- Channel,
- r#"
- select
- id as "id: Id",
- name,
- created_at as "created_at: DateTime"
- from channel
- where id = $1
- "#,
- channel,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(channel)
- }
-
- pub async fn all(&mut self) -> Result<Vec<Channel>, sqlx::Error> {
- let channels = sqlx::query_as!(
- Channel,
- r#"
- select
- id as "id: Id",
- name,
- created_at as "created_at: DateTime"
- from channel
- order by channel.name
- "#,
- )
- .fetch_all(&mut *self.0)
- .await?;
-
- Ok(channels)
- }
-
- pub async fn delete_expired(
- &mut self,
- channel: &Channel,
- sequence: Sequence,
- deleted_at: &DateTime,
- ) -> Result<types::ChannelEvent, sqlx::Error> {
- let channel = channel.id.clone();
- sqlx::query_scalar!(
- r#"
- delete from channel
- where id = $1
- returning 1 as "row: i64"
- "#,
- channel,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(types::ChannelEvent {
- sequence,
- at: *deleted_at,
- data: types::DeletedEvent { channel }.into(),
- })
- }
-
- pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Channel>, sqlx::Error> {
- let channels = sqlx::query_as!(
- Channel,
- r#"
- select
- channel.id as "id: Id",
- channel.name,
- channel.created_at as "created_at: DateTime"
- from channel
- left join message
- where created_at < $1
- and message.id is null
- "#,
- expired_at,
- )
- .fetch_all(&mut *self.0)
- .await?;
-
- Ok(channels)
- }
-}
-
-// Stable identifier for a [Channel]. Prefixed with `C`.
-#[derive(
- Clone,
- Debug,
- Eq,
- Hash,
- Ord,
- PartialEq,
- PartialOrd,
- sqlx::Type,
- serde::Deserialize,
- serde::Serialize,
-)]
-#[sqlx(transparent)]
-#[serde(transparent)]
-pub struct Id(BaseId);
-
-impl From<BaseId> for Id {
- fn from(id: BaseId) -> Self {
- Self(id)
- }
-}
-
-impl Id {
- pub fn generate() -> Self {
- BaseId::generate("C")
- }
-}
-
-impl fmt::Display for Id {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.0.fmt(f)
- }
-}
diff --git a/src/repo/error.rs b/src/repo/error.rs
deleted file mode 100644
index a5961e2..0000000
--- a/src/repo/error.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-pub trait NotFound {
- type Ok;
- fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E>
- where
- E: From<sqlx::Error>,
- F: FnOnce() -> E;
-}
-
-impl<T> NotFound for Result<T, sqlx::Error> {
- type Ok = T;
-
- fn not_found<E, F>(self, map: F) -> Result<T, E>
- where
- E: From<sqlx::Error>,
- F: FnOnce() -> E,
- {
- match self {
- Err(sqlx::Error::RowNotFound) => Err(map()),
- Err(other) => Err(other.into()),
- Ok(value) => Ok(value),
- }
- }
-}
diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs
deleted file mode 100644
index ab61106..0000000
--- a/src/repo/login/extract.rs
+++ /dev/null
@@ -1,15 +0,0 @@
-use axum::{extract::FromRequestParts, http::request::Parts};
-
-use super::Login;
-use crate::{app::App, login::extract::Identity};
-
-#[async_trait::async_trait]
-impl FromRequestParts<App> for Login {
- type Rejection = <Identity as FromRequestParts<App>>::Rejection;
-
- async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
- let identity = Identity::from_request_parts(parts, state).await?;
-
- Ok(identity.login)
- }
-}
diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs
deleted file mode 100644
index a1b4c6f..0000000
--- a/src/repo/login/mod.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-mod extract;
-mod store;
-
-pub use self::store::{Id, Login, Provider};
diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs
deleted file mode 100644
index b485941..0000000
--- a/src/repo/login/store.rs
+++ /dev/null
@@ -1,86 +0,0 @@
-use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
-
-use crate::{id::Id as BaseId, password::StoredHash};
-
-pub trait Provider {
- fn logins(&mut self) -> Logins;
-}
-
-impl<'c> Provider for Transaction<'c, Sqlite> {
- fn logins(&mut self) -> Logins {
- Logins(self)
- }
-}
-
-pub struct Logins<'t>(&'t mut SqliteConnection);
-
-// This also implements FromRequestParts (see `./extract.rs`). As a result, it
-// can be used as an extractor for endpoints that want to require login, or for
-// endpoints that need to behave differently depending on whether the client is
-// or is not logged in.
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Login {
- pub id: Id,
- pub name: String,
- // The omission of the hashed password is deliberate, to minimize the
- // chance that it ends up tangled up in debug output or in some other chunk
- // of logic elsewhere.
-}
-
-impl<'c> Logins<'c> {
- pub async fn create(
- &mut self,
- name: &str,
- password_hash: &StoredHash,
- ) -> Result<Login, sqlx::Error> {
- let id = Id::generate();
-
- let login = sqlx::query_as!(
- Login,
- r#"
- insert or fail
- into login (id, name, password_hash)
- values ($1, $2, $3)
- returning
- id as "id: Id",
- name
- "#,
- id,
- name,
- password_hash,
- )
- .fetch_one(&mut *self.0)
- .await?;
-
- Ok(login)
- }
-}
-
-impl<'t> From<&'t mut SqliteConnection> for Logins<'t> {
- fn from(tx: &'t mut SqliteConnection) -> Self {
- Self(tx)
- }
-}
-
-// Stable identifier for a [Login]. Prefixed with `L`.
-#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)]
-#[sqlx(transparent)]
-pub struct Id(BaseId);
-
-impl From<BaseId> for Id {
- fn from(id: BaseId) -> Self {
- Self(id)
- }
-}
-
-impl Id {
- pub fn generate() -> Self {
- BaseId::generate("L")
- }
-}
-
-impl std::fmt::Display for Id {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- self.0.fmt(f)
- }
-}
diff --git a/src/repo/mod.rs b/src/repo/mod.rs
deleted file mode 100644
index cb9d7c8..0000000
--- a/src/repo/mod.rs
+++ /dev/null
@@ -1,6 +0,0 @@
-pub mod channel;
-pub mod error;
-pub mod login;
-pub mod message;
-pub mod pool;
-pub mod token;
diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs
index 8744470..b678717 100644
--- a/src/test/fixtures/channel.rs
+++ b/src/test/fixtures/channel.rs
@@ -4,7 +4,7 @@ use faker_rand::{
};
use rand;
-use crate::{app::App, clock::RequestedAt, repo::channel::Channel};
+use crate::{app::App, channel::Channel, clock::RequestedAt};
pub async fn create(app: &App, created_at: &RequestedAt) -> Channel {
let name = propose();
diff --git a/src/test/fixtures/event.rs b/src/test/fixtures/event.rs
new file mode 100644
index 0000000..09f0490
--- /dev/null
+++ b/src/test/fixtures/event.rs
@@ -0,0 +1,11 @@
+use crate::{
+ event::{Event, Kind},
+ message::Message,
+};
+
+pub fn message_sent(event: &Event, message: &Message) -> bool {
+ matches!(
+ &event.kind,
+ Kind::MessageSent(event) if message == &event.into()
+ )
+}
diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs
index fbebced..6e62aea 100644
--- a/src/test/fixtures/filter.rs
+++ b/src/test/fixtures/filter.rs
@@ -1,15 +1,11 @@
use futures::future;
-use crate::events::types;
+use crate::event::{Event, Kind};
-pub fn messages() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> {
- |types::ResumableEvent(_, event)| {
- future::ready(matches!(event.data, types::ChannelEventData::Message(_)))
- }
+pub fn messages() -> impl FnMut(&Event) -> future::Ready<bool> {
+ |event| future::ready(matches!(event.kind, Kind::MessageSent(_)))
}
-pub fn created() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> {
- |types::ResumableEvent(_, event)| {
- future::ready(matches!(event.data, types::ChannelEventData::Created(_)))
- }
+pub fn created() -> impl FnMut(&Event) -> future::Ready<bool> {
+ |event| future::ready(matches!(event.kind, Kind::ChannelCreated(_)))
}
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 633fb8a..56b4ffa 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -3,8 +3,11 @@ use uuid::Uuid;
use crate::{
app::App,
clock::RequestedAt,
- login::extract::{Identity, IdentitySecret, IdentityToken},
- password::Password,
+ login::Password,
+ token::{
+ extract::{Identity, IdentityToken},
+ Secret,
+ },
};
pub fn not_logged_in() -> IdentityToken {
@@ -14,7 +17,7 @@ pub fn not_logged_in() -> IdentityToken {
pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) -> IdentityToken {
let (name, password) = login;
let token = app
- .logins()
+ .tokens()
.login(name, password, now)
.await
.expect("should succeed given known-valid credentials");
@@ -25,7 +28,7 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt)
pub async fn from_token(app: &App, token: &IdentityToken, issued_at: &RequestedAt) -> Identity {
let secret = token.secret().expect("identity token has a secret");
let (token, login) = app
- .logins()
+ .tokens()
.validate(&secret, issued_at)
.await
.expect("always validates newly-issued secret");
@@ -38,7 +41,7 @@ pub async fn identity(app: &App, login: &(String, Password), issued_at: &Request
from_token(app, &secret, issued_at).await
}
-pub fn secret(identity: &IdentityToken) -> IdentitySecret {
+pub fn secret(identity: &IdentityToken) -> Secret {
identity.secret().expect("identity contained a secret")
}
diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs
index d6a321b..00c2789 100644
--- a/src/test/fixtures/login.rs
+++ b/src/test/fixtures/login.rs
@@ -3,8 +3,7 @@ use uuid::Uuid;
use crate::{
app::App,
- password::Password,
- repo::login::{self, Login},
+ login::{self, Login, Password},
};
pub async fn create_with_password(app: &App) -> (String, Password) {
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index bfca8cd..381b10b 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -1,22 +1,12 @@
use faker_rand::lorem::Paragraphs;
-use crate::{
- app::App,
- clock::RequestedAt,
- events::types,
- repo::{channel::Channel, login::Login},
-};
+use crate::{app::App, channel::Channel, clock::RequestedAt, login::Login, message::Message};
-pub async fn send(
- app: &App,
- login: &Login,
- channel: &Channel,
- sent_at: &RequestedAt,
-) -> types::ChannelEvent {
+pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &RequestedAt) -> Message {
let body = propose();
- app.events()
- .send(login, &channel.id, &body, sent_at)
+ app.messages()
+ .send(&channel.id, login, sent_at, &body)
.await
.expect("should succeed if the channel exists")
}
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index d1dd0c3..c5efa9b 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -1,8 +1,9 @@
use chrono::{TimeDelta, Utc};
-use crate::{app::App, clock::RequestedAt, repo::pool};
+use crate::{app::App, clock::RequestedAt, db};
pub mod channel;
+pub mod event;
pub mod filter;
pub mod future;
pub mod identity;
@@ -10,7 +11,7 @@ pub mod login;
pub mod message;
pub async fn scratch_app() -> App {
- let pool = pool::prepare("sqlite::memory:")
+ let pool = db::prepare("sqlite::memory:")
.await
.expect("setting up in-memory sqlite database");
App::from(pool)
diff --git a/src/token/app.rs b/src/token/app.rs
new file mode 100644
index 0000000..5c4fcd5
--- /dev/null
+++ b/src/token/app.rs
@@ -0,0 +1,170 @@
+use chrono::TimeDelta;
+use futures::{
+ future,
+ stream::{self, StreamExt as _},
+ Stream,
+};
+use sqlx::sqlite::SqlitePool;
+
+use super::{
+ broadcaster::Broadcaster, event, repo::auth::Provider as _, repo::Provider as _, Id, Secret,
+};
+use crate::{
+ clock::DateTime,
+ db::NotFound as _,
+ login::{repo::Provider as _, Login, Password},
+};
+
+pub struct Tokens<'a> {
+ db: &'a SqlitePool,
+ tokens: &'a Broadcaster,
+}
+
+impl<'a> Tokens<'a> {
+ pub const fn new(db: &'a SqlitePool, tokens: &'a Broadcaster) -> Self {
+ Self { db, tokens }
+ }
+ pub async fn login(
+ &self,
+ name: &str,
+ password: &Password,
+ login_at: &DateTime,
+ ) -> Result<Secret, LoginError> {
+ let mut tx = self.db.begin().await?;
+
+ let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? {
+ if stored_hash.verify(password)? {
+ // Password verified; use the login.
+ login
+ } else {
+ // Password NOT verified.
+ return Err(LoginError::Rejected);
+ }
+ } else {
+ let password_hash = password.hash()?;
+ tx.logins().create(name, &password_hash).await?
+ };
+
+ let token = tx.tokens().issue(&login, login_at).await?;
+ tx.commit().await?;
+
+ Ok(token)
+ }
+
+ pub async fn validate(
+ &self,
+ secret: &Secret,
+ used_at: &DateTime,
+ ) -> Result<(Id, Login), ValidateError> {
+ let mut tx = self.db.begin().await?;
+ let login = tx
+ .tokens()
+ .validate(secret, used_at)
+ .await
+ .not_found(|| ValidateError::InvalidToken)?;
+ tx.commit().await?;
+
+ Ok(login)
+ }
+
+ pub async fn limit_stream<E>(
+ &self,
+ token: Id,
+ events: impl Stream<Item = E> + std::fmt::Debug,
+ ) -> Result<impl Stream<Item = E> + std::fmt::Debug, ValidateError>
+ where
+ E: std::fmt::Debug,
+ {
+ // Subscribe, first.
+ let token_events = self.tokens.subscribe();
+
+ // Check that the token is valid at this point in time, second. If it is, then
+ // any future revocations will appear in the subscription. If not, bail now.
+ //
+ // It's possible, otherwise, to get to this point with a token that _was_ valid
+ // at the start of the request, but which was invalided _before_ the
+ // `subscribe()` call. In that case, the corresponding revocation event will
+ // simply be missed, since the `token_events` stream subscribed after the fact.
+ // This check cancels guarding the stream here.
+ //
+ // Yes, this is a weird niche edge case. Most things don't double-check, because
+ // they aren't expected to run long enough for the token's revocation to
+ // matter. Supervising a stream, on the other hand, will run for a
+ // _long_ time; if we miss the race here, we'll never actually carry out the
+ // supervision.
+ let mut tx = self.db.begin().await?;
+ tx.tokens()
+ .require(&token)
+ .await
+ .not_found(|| ValidateError::InvalidToken)?;
+ tx.commit().await?;
+
+ // Then construct the guarded stream. First, project both streams into
+ // `GuardedEvent`.
+ let token_events = token_events
+ .filter(move |event| future::ready(event.token == token))
+ .map(|_| GuardedEvent::TokenRevoked);
+ let events = events.map(|event| GuardedEvent::Event(event));
+
+ // Merge the two streams, then unproject them, stopping at
+ // `GuardedEvent::TokenRevoked`.
+ let stream = stream::select(token_events, events).scan((), |(), event| {
+ future::ready(match event {
+ GuardedEvent::Event(event) => Some(event),
+ GuardedEvent::TokenRevoked => None,
+ })
+ });
+
+ Ok(stream)
+ }
+
+ pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
+ // Somewhat arbitrarily, expire after 7 days.
+ let expire_at = relative_to.to_owned() - TimeDelta::days(7);
+
+ let mut tx = self.db.begin().await?;
+ let tokens = tx.tokens().expire(&expire_at).await?;
+ tx.commit().await?;
+
+ for event in tokens.into_iter().map(event::TokenRevoked::from) {
+ self.tokens.broadcast(event);
+ }
+
+ Ok(())
+ }
+
+ pub async fn logout(&self, token: &Id) -> Result<(), ValidateError> {
+ let mut tx = self.db.begin().await?;
+ tx.tokens().revoke(token).await?;
+ tx.commit().await?;
+
+ self.tokens
+ .broadcast(event::TokenRevoked::from(token.clone()));
+
+ Ok(())
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum LoginError {
+ #[error("invalid login")]
+ Rejected,
+ #[error(transparent)]
+ DatabaseError(#[from] sqlx::Error),
+ #[error(transparent)]
+ PasswordHashError(#[from] password_hash::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ValidateError {
+ #[error("invalid token")]
+ InvalidToken,
+ #[error(transparent)]
+ DatabaseError(#[from] sqlx::Error),
+}
+
+#[derive(Debug)]
+enum GuardedEvent<E> {
+ TokenRevoked,
+ Event(E),
+}
diff --git a/src/token/broadcaster.rs b/src/token/broadcaster.rs
new file mode 100644
index 0000000..8e2e006
--- /dev/null
+++ b/src/token/broadcaster.rs
@@ -0,0 +1,4 @@
+use super::event;
+use crate::broadcast;
+
+pub type Broadcaster = broadcast::Broadcaster<event::TokenRevoked>;
diff --git a/src/login/types.rs b/src/token/event.rs
index 7c7cbf9..d53d436 100644
--- a/src/login/types.rs
+++ b/src/token/event.rs
@@ -1,4 +1,4 @@
-use crate::repo::token;
+use crate::token;
#[derive(Clone, Debug)]
pub struct TokenRevoked {
diff --git a/src/token/extract/identity.rs b/src/token/extract/identity.rs
new file mode 100644
index 0000000..60ad220
--- /dev/null
+++ b/src/token/extract/identity.rs
@@ -0,0 +1,75 @@
+use axum::{
+ extract::{FromRequestParts, State},
+ http::request::Parts,
+ response::{IntoResponse, Response},
+};
+
+use super::IdentityToken;
+
+use crate::{
+ app::App,
+ clock::RequestedAt,
+ error::{Internal, Unauthorized},
+ login::Login,
+ token::{self, app::ValidateError},
+};
+
+#[derive(Clone, Debug)]
+pub struct Identity {
+ pub token: token::Id,
+ pub login: Login,
+}
+
+#[async_trait::async_trait]
+impl FromRequestParts<App> for Identity {
+ type Rejection = LoginError<Internal>;
+
+ async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
+ // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
+ // stable), the following can be replaced:
+ //
+ // ```
+ // let Ok(identity_token) = IdentityToken::from_request_parts(
+ // parts,
+ // state,
+ // ).await;
+ // ```
+ let identity_token = IdentityToken::from_request_parts(parts, state).await?;
+ let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+
+ let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
+
+ let app = State::<App>::from_request_parts(parts, state).await?;
+ match app.tokens().validate(&secret, &used_at).await {
+ Ok((token, login)) => Ok(Identity { token, login }),
+ Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
+ Err(other) => Err(other.into()),
+ }
+ }
+}
+
+pub enum LoginError<E> {
+ Failure(E),
+ Unauthorized,
+}
+
+impl<E> IntoResponse for LoginError<E>
+where
+ E: IntoResponse,
+{
+ fn into_response(self) -> Response {
+ match self {
+ Self::Unauthorized => Unauthorized.into_response(),
+ Self::Failure(e) => e.into_response(),
+ }
+ }
+}
+
+impl<E> From<E> for LoginError<Internal>
+where
+ E: Into<Internal>,
+{
+ fn from(err: E) -> Self {
+ Self::Failure(err.into())
+ }
+}
diff --git a/src/token/extract/identity_token.rs b/src/token/extract/identity_token.rs
new file mode 100644
index 0000000..0a47a43
--- /dev/null
+++ b/src/token/extract/identity_token.rs
@@ -0,0 +1,94 @@
+use std::fmt;
+
+use axum::{
+ extract::FromRequestParts,
+ http::request::Parts,
+ response::{IntoResponseParts, ResponseParts},
+};
+use axum_extra::extract::cookie::{Cookie, CookieJar};
+
+use crate::token::Secret;
+
+// The usage pattern here - receive the extractor as an argument, return it in
+// the response - is heavily modelled after CookieJar's own intended usage.
+#[derive(Clone)]
+pub struct IdentityToken {
+ cookies: CookieJar,
+}
+
+impl fmt::Debug for IdentityToken {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("IdentityToken")
+ .field("identity", &self.secret())
+ .finish()
+ }
+}
+
+impl IdentityToken {
+ // Creates a new, unpopulated identity token store.
+ #[cfg(test)]
+ pub fn new() -> Self {
+ Self {
+ cookies: CookieJar::new(),
+ }
+ }
+
+ // Get the identity secret sent in the request, if any. If the identity
+ // was not sent, or if it has previously been [clear]ed, then this will
+ // return [None]. If the identity has previously been [set], then this
+ // will return that secret, regardless of what the request originally
+ // included.
+ pub fn secret(&self) -> Option<Secret> {
+ self.cookies
+ .get(IDENTITY_COOKIE)
+ .map(Cookie::value)
+ .map(Secret::from)
+ }
+
+ // Positively set the identity secret, and ensure that it will be sent
+ // back to the client when this extractor is included in a response.
+ pub fn set(self, secret: impl Into<Secret>) -> Self {
+ let secret = secret.into().reveal();
+ let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret))
+ .http_only(true)
+ .path("/api/")
+ .permanent()
+ .build();
+
+ Self {
+ cookies: self.cookies.add(identity_cookie),
+ }
+ }
+
+ // Remove the identity secret and ensure that it will be cleared when this
+ // extractor is included in a response.
+ pub fn clear(self) -> Self {
+ Self {
+ cookies: self.cookies.remove(IDENTITY_COOKIE),
+ }
+ }
+}
+
+const IDENTITY_COOKIE: &str = "identity";
+
+#[async_trait::async_trait]
+impl<S> FromRequestParts<S> for IdentityToken
+where
+ S: Send + Sync,
+{
+ type Rejection = <CookieJar as FromRequestParts<S>>::Rejection;
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+ let cookies = CookieJar::from_request_parts(parts, state).await?;
+ Ok(Self { cookies })
+ }
+}
+
+impl IntoResponseParts for IdentityToken {
+ type Error = <CookieJar as IntoResponseParts>::Error;
+
+ fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
+ let Self { cookies } = self;
+ cookies.into_response_parts(res)
+ }
+}
diff --git a/src/token/extract/mod.rs b/src/token/extract/mod.rs
new file mode 100644
index 0000000..b4800ae
--- /dev/null
+++ b/src/token/extract/mod.rs
@@ -0,0 +1,4 @@
+mod identity;
+mod identity_token;
+
+pub use self::{identity::Identity, identity_token::IdentityToken};
diff --git a/src/token/id.rs b/src/token/id.rs
new file mode 100644
index 0000000..9ef063c
--- /dev/null
+++ b/src/token/id.rs
@@ -0,0 +1,27 @@
+use std::fmt;
+
+use crate::id::Id as BaseId;
+
+// Stable identifier for a token. Prefixed with `T`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("T")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/token/mod.rs b/src/token/mod.rs
new file mode 100644
index 0000000..d122611
--- /dev/null
+++ b/src/token/mod.rs
@@ -0,0 +1,9 @@
+pub mod app;
+pub mod broadcaster;
+mod event;
+pub mod extract;
+mod id;
+mod repo;
+mod secret;
+
+pub use self::{id::Id, secret::Secret};
diff --git a/src/login/repo/auth.rs b/src/token/repo/auth.rs
index 3033c8f..b299697 100644
--- a/src/login/repo/auth.rs
+++ b/src/token/repo/auth.rs
@@ -1,9 +1,6 @@
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
-use crate::{
- password::StoredHash,
- repo::login::{self, Login},
-};
+use crate::login::{self, password::StoredHash, Login};
pub trait Provider {
fn auth(&mut self) -> Auth;
diff --git a/src/token/repo/mod.rs b/src/token/repo/mod.rs
new file mode 100644
index 0000000..9169743
--- /dev/null
+++ b/src/token/repo/mod.rs
@@ -0,0 +1,4 @@
+pub mod auth;
+mod token;
+
+pub use self::token::Provider;
diff --git a/src/repo/token.rs b/src/token/repo/token.rs
index d96c094..5f64dac 100644
--- a/src/repo/token.rs
+++ b/src/token/repo/token.rs
@@ -1,10 +1,11 @@
-use std::fmt;
-
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use uuid::Uuid;
-use super::login::{self, Login};
-use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret};
+use crate::{
+ clock::DateTime,
+ login::{self, Login},
+ token::{Id, Secret},
+};
pub trait Provider {
fn tokens(&mut self) -> Tokens;
@@ -25,7 +26,7 @@ impl<'c> Tokens<'c> {
&mut self,
login: &Login,
issued_at: &DateTime,
- ) -> Result<IdentitySecret, sqlx::Error> {
+ ) -> Result<Secret, sqlx::Error> {
let id = Id::generate();
let secret = Uuid::new_v4().to_string();
@@ -34,7 +35,7 @@ impl<'c> Tokens<'c> {
insert
into token (id, secret, login, issued_at, last_used_at)
values ($1, $2, $3, $4, $4)
- returning secret as "secret!: IdentitySecret"
+ returning secret as "secret!: Secret"
"#,
id,
secret,
@@ -47,6 +48,21 @@ impl<'c> Tokens<'c> {
Ok(secret)
}
+ pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> {
+ sqlx::query_scalar!(
+ r#"
+ select id as "id: Id"
+ from token
+ where id = $1
+ "#,
+ token,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
// Revoke a token by its secret.
pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> {
sqlx::query_scalar!(
@@ -87,7 +103,7 @@ impl<'c> Tokens<'c> {
// timestamp will be set to `used_at`.
pub async fn validate(
&mut self,
- secret: &IdentitySecret,
+ secret: &Secret,
used_at: &DateTime,
) -> Result<(Id, Login), sqlx::Error> {
// I would use `update … returning` to do this in one query, but
@@ -133,27 +149,3 @@ impl<'c> Tokens<'c> {
Ok(login)
}
}
-
-// Stable identifier for a token. Prefixed with `T`.
-#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
-#[sqlx(transparent)]
-#[serde(transparent)]
-pub struct Id(BaseId);
-
-impl From<BaseId> for Id {
- fn from(id: BaseId) -> Self {
- Self(id)
- }
-}
-
-impl Id {
- pub fn generate() -> Self {
- BaseId::generate("T")
- }
-}
-
-impl fmt::Display for Id {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.0.fmt(f)
- }
-}
diff --git a/src/token/secret.rs b/src/token/secret.rs
new file mode 100644
index 0000000..28c93bb
--- /dev/null
+++ b/src/token/secret.rs
@@ -0,0 +1,27 @@
+use std::fmt;
+
+#[derive(sqlx::Type)]
+#[sqlx(transparent)]
+pub struct Secret(String);
+
+impl Secret {
+ pub fn reveal(self) -> String {
+ let Self(secret) = self;
+ secret
+ }
+}
+
+impl fmt::Debug for Secret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_tuple("IdentityToken").field(&"********").finish()
+ }
+}
+
+impl<S> From<S> for Secret
+where
+ S: Into<String>,
+{
+ fn from(value: S) -> Self {
+ Self(value.into())
+ }
+}