diff options
Diffstat (limited to 'src')
78 files changed, 2384 insertions, 1483 deletions
@@ -2,35 +2,45 @@ use sqlx::sqlite::SqlitePool; use crate::{ channel::app::Channels, - events::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, - login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster}, + event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, + login::app::Logins, + message::app::Messages, + token::{app::Tokens, broadcaster::Broadcaster as TokenBroadcaster}, }; #[derive(Clone)] pub struct App { db: SqlitePool, events: EventBroadcaster, - logins: LoginBroadcaster, + tokens: TokenBroadcaster, } impl App { pub fn from(db: SqlitePool) -> Self { let events = EventBroadcaster::default(); - let logins = LoginBroadcaster::default(); - Self { db, events, logins } + let tokens = TokenBroadcaster::default(); + Self { db, events, tokens } } } impl App { - pub const fn logins(&self) -> Logins { - Logins::new(&self.db, &self.logins) + pub const fn channels(&self) -> Channels { + Channels::new(&self.db, &self.events) } pub const fn events(&self) -> Events { Events::new(&self.db, &self.events) } - pub const fn channels(&self) -> Channels { - Channels::new(&self.db, &self.events) + pub const fn logins(&self) -> Logins { + Logins::new(&self.db) + } + + pub const fn messages(&self) -> Messages { + Messages::new(&self.db, &self.events) + } + + pub const fn tokens(&self) -> Tokens { + Tokens::new(&self.db, &self.tokens) } } diff --git a/src/broadcast.rs b/src/broadcast.rs index 083a301..bedc263 100644 --- a/src/broadcast.rs +++ b/src/broadcast.rs @@ -32,7 +32,7 @@ where { // panic: if ``message.channel.id`` has not been previously registered, // and was not part of the initial set of channels. - pub fn broadcast(&self, message: &M) { + pub fn broadcast(&self, message: impl Into<M>) { let tx = self.sender(); // Per the Tokio docs, the returned error is only used to indicate that @@ -42,7 +42,7 @@ where // // The successful return value, which includes the number of active // receivers, also isn't that interesting to us. - let _ = tx.send(message.clone()); + let _ = tx.send(message.into()); } // panic: if ``channel`` has not been previously registered, and was not diff --git a/src/channel/app.rs b/src/channel/app.rs index 70cda47..bb331ec 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,10 +1,13 @@ use chrono::TimeDelta; +use itertools::Itertools; use sqlx::sqlite::SqlitePool; +use super::{repo::Provider as _, Channel, Id}; use crate::{ clock::DateTime, - events::{broadcaster::Broadcaster, repo::message::Provider as _, types::ChannelEvent}, - repo::channel::{Channel, Provider as _}, + db::NotFound, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, + message::{repo::Provider as _, Message}, }; pub struct Channels<'a> { @@ -19,27 +22,108 @@ impl<'a> Channels<'a> { pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> { let mut tx = self.db.begin().await?; + let created = tx.sequence().next(created_at).await?; let channel = tx .channels() - .create(name, created_at) + .create(name, &created) .await .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; self.events - .broadcast(&ChannelEvent::created(channel.clone())); + .broadcast(channel.events().map(Event::from).collect::<Vec<_>>()); - Ok(channel) + Ok(channel.snapshot()) } - pub async fn all(&self) -> Result<Vec<Channel>, InternalError> { + pub async fn all(&self, resume_point: Option<Sequence>) -> Result<Vec<Channel>, InternalError> { let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; + let channels = tx.channels().all(resume_point).await?; tx.commit().await?; + let channels = channels + .into_iter() + .filter_map(|channel| { + channel + .events() + .filter(Sequence::up_to(resume_point)) + .collect() + }) + .collect(); + Ok(channels) } + pub async fn messages( + &self, + channel: &Id, + resume_point: Option<Sequence>, + ) -> Result<Vec<Message>, Error> { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| Error::NotFound(channel.clone()))? + .snapshot(); + + let messages = tx + .messages() + .in_channel(&channel, resume_point) + .await? + .into_iter() + .filter_map(|message| { + message + .events() + .filter(Sequence::up_to(resume_point)) + .collect() + }) + .collect(); + + Ok(messages) + } + + pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), Error> { + let mut tx = self.db.begin().await?; + + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| Error::NotFound(channel.clone()))? + .snapshot(); + + let mut events = Vec::new(); + + let messages = tx.messages().in_channel(&channel, None).await?; + for message in messages { + let message = message.snapshot(); + let deleted = tx.sequence().next(deleted_at).await?; + let message = tx.messages().delete(&message.id, &deleted).await?; + events.extend( + message + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from), + ); + } + + let deleted = tx.sequence().next(deleted_at).await?; + let channel = tx.channels().delete(&channel.id, &deleted).await?; + events.extend( + channel + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from), + ); + + tx.commit().await?; + + self.events.broadcast(events); + + Ok(()) + } + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { // Somewhat arbitrarily, expire after 90 days. let expire_at = relative_to.to_owned() - TimeDelta::days(90); @@ -49,19 +133,24 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; - let event = tx - .channels() - .delete_expired(&channel, sequence, relative_to) - .await?; - events.push(event); + let deleted = tx.sequence().next(relative_to).await?; + let channel = tx.channels().delete(&channel, &deleted).await?; + events.push( + channel + .events() + .filter(Sequence::start_from(deleted.sequence)), + ); } tx.commit().await?; - for event in events { - self.events.broadcast(&event); - } + self.events.broadcast( + events + .into_iter() + .kmerge_by(Sequence::merge) + .map(Event::from) + .collect::<Vec<_>>(), + ); Ok(()) } @@ -75,6 +164,14 @@ pub enum CreateError { DatabaseError(#[from] sqlx::Error), } +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("channel {0} not found")] + NotFound(Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} + impl CreateError { fn from_duplicate_name(error: sqlx::Error, name: &str) -> Self { if let Some(error) = error.as_database_error() { diff --git a/src/channel/event.rs b/src/channel/event.rs new file mode 100644 index 0000000..9c54174 --- /dev/null +++ b/src/channel/event.rs @@ -0,0 +1,48 @@ +use super::Channel; +use crate::{ + channel, + event::{Instant, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + Created(Created), + Deleted(Deleted), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Created { + pub channel: Channel, +} + +impl From<Created> for Kind { + fn from(event: Created) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Deleted { + pub channel: channel::Id, +} + +impl From<Deleted> for Kind { + fn from(event: Deleted) -> Self { + Self::Deleted(event) + } +} diff --git a/src/channel/history.rs b/src/channel/history.rs new file mode 100644 index 0000000..3cc7d9d --- /dev/null +++ b/src/channel/history.rs @@ -0,0 +1,42 @@ +use super::{ + event::{Created, Deleted, Event}, + Channel, +}; +use crate::event::Instant; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct History { + pub channel: Channel, + pub created: Instant, + pub deleted: Option<Instant>, +} + +impl History { + fn created(&self) -> Event { + Event { + instant: self.created, + kind: Created { + channel: self.channel.clone(), + } + .into(), + } + } + + fn deleted(&self) -> Option<Event> { + self.deleted.map(|instant| Event { + instant, + kind: Deleted { + channel: self.channel.id.clone(), + } + .into(), + }) + } + + pub fn events(&self) -> impl Iterator<Item = Event> { + [self.created()].into_iter().chain(self.deleted()) + } + + pub fn snapshot(&self) -> Channel { + self.channel.clone() + } +} diff --git a/src/channel/id.rs b/src/channel/id.rs new file mode 100644 index 0000000..22a2700 --- /dev/null +++ b/src/channel/id.rs @@ -0,0 +1,38 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a [Channel]. Prefixed with `C`. +#[derive( + Clone, + Debug, + Eq, + Hash, + Ord, + PartialEq, + PartialOrd, + sqlx::Type, + serde::Deserialize, + serde::Serialize, +)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("C") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 9f79dbb..eb8200b 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,9 @@ pub mod app; +pub mod event; +mod history; +mod id; +pub mod repo; mod routes; +mod snapshot; -pub use self::routes::router; +pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs new file mode 100644 index 0000000..2b48436 --- /dev/null +++ b/src/channel/repo.rs @@ -0,0 +1,202 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{Channel, History, Id}, + clock::DateTime, + event::{Instant, Sequence}, +}; + +pub trait Provider { + fn channels(&mut self) -> Channels; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn channels(&mut self) -> Channels { + Channels(self) + } +} + +pub struct Channels<'t>(&'t mut SqliteConnection); + +impl<'c> Channels<'c> { + pub async fn create(&mut self, name: &str, created: &Instant) -> Result<History, sqlx::Error> { + let id = Id::generate(); + let channel = sqlx::query!( + r#" + insert + into channel (id, name, created_at, created_sequence) + values ($1, $2, $3, $4) + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + "#, + id, + name, + created.at, + created.sequence, + ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn by_id(&mut self, channel: &Id) -> Result<History, sqlx::Error> { + let channel = sqlx::query!( + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where id = $1 + "#, + channel, + ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn all(&mut self, resume_at: Option<Sequence>) -> Result<Vec<History>, sqlx::Error> { + let channels = sqlx::query!( + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence <= $1, true) + order by channel.name + "#, + resume_at, + ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn replay( + &mut self, + resume_at: Option<Sequence>, + ) -> Result<Vec<History>, sqlx::Error> { + let channels = sqlx::query!( + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence > $1, true) + "#, + resume_at, + ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn delete( + &mut self, + channel: &Id, + deleted: &Instant, + ) -> Result<History, sqlx::Error> { + let channel = sqlx::query!( + r#" + delete from channel + where id = $1 + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + "#, + channel, + ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: Some(*deleted), + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { + let channels = sqlx::query_scalar!( + r#" + select + channel.id as "id: Id" + from channel + left join message + where created_at < $1 + and message.id is null + "#, + expired_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } +} diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 1f8db5a..23c0602 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -2,20 +2,19 @@ use axum::{ extract::{Json, Path, State}, http::StatusCode, response::{IntoResponse, Response}, - routing::{get, post}, + routing::{delete, get, post}, Router, }; +use axum_extra::extract::Query; -use super::app; +use super::{app, Channel, Id}; use crate::{ app::App, clock::RequestedAt, error::Internal, - events::app::EventsError, - repo::{ - channel::{self, Channel}, - login::Login, - }, + event::{Instant, Sequence}, + login::Login, + message::{self, app::SendError}, }; #[cfg(test)] @@ -26,10 +25,21 @@ pub fn router() -> Router<App> { .route("/api/channels", get(list)) .route("/api/channels", post(on_create)) .route("/api/channels/:channel", post(on_send)) + .route("/api/channels/:channel", delete(on_delete)) + .route("/api/channels/:channel/messages", get(messages)) } -async fn list(State(app): State<App>, _: Login) -> Result<Channels, Internal> { - let channels = app.channels().all().await?; +#[derive(Default, serde::Deserialize)] +struct ResumeQuery { + resume_point: Option<Sequence>, +} + +async fn list( + State(app): State<App>, + _: Login, + Query(query): Query<ResumeQuery>, +) -> Result<Channels, Internal> { + let channels = app.channels().all(query.resume_point).await?; let response = Channels(channels); Ok(response) @@ -86,31 +96,107 @@ struct SendRequest { async fn on_send( State(app): State<App>, - Path(channel): Path<channel::Id>, + Path(channel): Path<Id>, RequestedAt(sent_at): RequestedAt, login: Login, Json(request): Json<SendRequest>, +) -> Result<StatusCode, SendErrorResponse> { + app.messages() + .send(&channel, &login, &sent_at, &request.message) + .await?; + + Ok(StatusCode::ACCEPTED) +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct SendErrorResponse(#[from] SendError); + +impl IntoResponse for SendErrorResponse { + fn into_response(self) -> Response { + let Self(error) = self; + match error { + not_found @ SendError::ChannelNotFound(_) => { + (StatusCode::NOT_FOUND, not_found.to_string()).into_response() + } + other => Internal::from(other).into_response(), + } + } +} + +async fn on_delete( + State(app): State<App>, + Path(channel): Path<Id>, + RequestedAt(deleted_at): RequestedAt, + _: Login, ) -> Result<StatusCode, ErrorResponse> { - app.events() - .send(&login, &channel, &request.message, &sent_at) - .await - // Could impl `From` here, but it's more code and this is used once. - .map_err(ErrorResponse)?; + app.channels().delete(&channel, &deleted_at).await?; Ok(StatusCode::ACCEPTED) } -#[derive(Debug)] -struct ErrorResponse(EventsError); +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct ErrorResponse(#[from] app::Error); impl IntoResponse for ErrorResponse { fn into_response(self) -> Response { let Self(error) = self; match error { - not_found @ EventsError::ChannelNotFound(_) => { + not_found @ app::Error::NotFound(_) => { (StatusCode::NOT_FOUND, not_found.to_string()).into_response() } other => Internal::from(other).into_response(), } } } + +async fn messages( + State(app): State<App>, + Path(channel): Path<Id>, + _: Login, + Query(query): Query<ResumeQuery>, +) -> Result<Messages, ErrorResponse> { + let messages = app + .channels() + .messages(&channel, query.resume_point) + .await?; + let response = Messages( + messages + .into_iter() + .map(|message| MessageView { + sent: message.sent, + sender: message.sender, + message: MessageInner { + id: message.id, + body: message.body, + }, + }) + .collect(), + ); + + Ok(response) +} + +struct Messages(Vec<MessageView>); + +#[derive(serde::Serialize)] +struct MessageView { + #[serde(flatten)] + sent: Instant, + sender: Login, + message: MessageInner, +} + +#[derive(serde::Serialize)] +struct MessageInner { + id: message::Id, + body: String, +} + +impl IntoResponse for Messages { + fn into_response(self) -> Response { + let Self(messages) = self; + Json(messages).into_response() + } +} diff --git a/src/channel/routes/test/list.rs b/src/channel/routes/test/list.rs index bc94024..f15a53c 100644 --- a/src/channel/routes/test/list.rs +++ b/src/channel/routes/test/list.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use axum_extra::extract::Query; use crate::{channel::routes, test::fixtures}; @@ -11,7 +12,7 @@ async fn empty_list() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -30,7 +31,7 @@ async fn one_channel() { // Call the endpoint - let routes::Channels(channels) = routes::list(State(app), viewer) + let routes::Channels(channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); @@ -52,7 +53,7 @@ async fn multiple_channels() { // Call the endpoint - let routes::Channels(response_channels) = routes::list(State(app), viewer) + let routes::Channels(response_channels) = routes::list(State(app), viewer, Query::default()) .await .expect("always succeeds"); diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index e2610a5..5733c9e 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -3,7 +3,7 @@ use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - events::types, + event, test::fixtures::{self, future::Immediately as _}, }; @@ -33,26 +33,25 @@ async fn new_channel() { // Verify the semantics - let channels = app.channels().all().await.expect("always succeeds"); + let channels = app.channels().all(None).await.expect("always succeeds"); assert!(channels.contains(&response_channel)); let mut events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing never fails") .filter(fixtures::filter::created()); - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await .expect("creation event published"); - assert_eq!(types::Sequence::default(), event.sequence); assert!(matches!( - event.data, - types::ChannelEventData::Created(event) + event.kind, + event::Kind::ChannelCreated(event) if event.channel == response_channel )); } diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 233518b..3297093 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -2,9 +2,10 @@ use axum::extract::{Json, Path, State}; use futures::stream::StreamExt; use crate::{ + channel, channel::routes, - events::{app, types}, - repo::channel, + event, + message::app::SendError, test::fixtures::{self, future::Immediately as _}, }; @@ -43,7 +44,7 @@ async fn messages_in_order() { let events = app .events() - .subscribe(types::ResumePoint::default()) + .subscribe(None) .await .expect("subscribing to a valid channel") .filter(fixtures::filter::messages()) @@ -51,13 +52,13 @@ async fn messages_in_order() { let events = events.collect::<Vec<_>>().immediately().await; - for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) { - assert_eq!(*sent_at, event.at); + for ((sent_at, message), event) in requests.into_iter().zip(events) { + assert_eq!(*sent_at, event.instant.at); assert!(matches!( - event.data, - types::ChannelEventData::Message(event_message) - if event_message.sender == sender - && event_message.message.body == message + event.kind, + event::Kind::MessageSent(event) + if event.message.sender == sender + && event.message.body == message )); } } @@ -76,7 +77,7 @@ async fn nonexistent_channel() { let request = routes::SendRequest { message: fixtures::message::propose(), }; - let routes::ErrorResponse(error) = routes::on_send( + let routes::SendErrorResponse(error) = routes::on_send( State(app), Path(channel.clone()), sent_at, @@ -90,6 +91,6 @@ async fn nonexistent_channel() { assert!(matches!( error, - app::EventsError::ChannelNotFound(error_channel) if channel == error_channel + SendError::ChannelNotFound(error_channel) if channel == error_channel )); } diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs new file mode 100644 index 0000000..6462f25 --- /dev/null +++ b/src/channel/snapshot.rs @@ -0,0 +1,38 @@ +use super::{ + event::{Created, Event, Kind}, + Id, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Channel { + pub id: Id, + pub name: String, +} + +impl Channel { + fn apply(state: Option<Self>, event: Event) -> Option<Self> { + match (state, event.kind) { + (None, Kind::Created(event)) => Some(event.into()), + (Some(channel), Kind::Deleted(event)) if channel.id == event.channel => None, + (state, event) => panic!("invalid channel event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator<Event> for Option<Channel> { + fn from_iter<I: IntoIterator<Item = Event>>(events: I) -> Self { + events.into_iter().fold(None, Channel::apply) + } +} + +impl From<&Created> for Channel { + fn from(event: &Created) -> Self { + event.channel.clone() + } +} + +impl From<Created> for Channel { + fn from(event: Created) -> Self { + event.channel + } +} @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, events, expire, login, repo::pool}; +use crate::{app::App, channel, clock, db, event, expire, login, message}; /// Command-line entry point for running the `hi` server. /// @@ -100,14 +100,19 @@ impl Args { } async fn pool(&self) -> sqlx::Result<SqlitePool> { - pool::prepare(&self.database_url).await + db::prepare(&self.database_url).await } } fn routers() -> Router<App> { - [channel::router(), events::router(), login::router()] - .into_iter() - .fold(Router::default(), Router::merge) + [ + channel::router(), + event::router(), + login::router(), + message::router(), + ] + .into_iter() + .fold(Router::default(), Router::merge) } fn started_msg(listener: &net::TcpListener) -> io::Result<String> { diff --git a/src/repo/pool.rs b/src/db.rs index b4aa6fc..93a1169 100644 --- a/src/repo/pool.rs +++ b/src/db.rs @@ -16,3 +16,27 @@ async fn create(database_url: &str) -> sqlx::Result<SqlitePool> { let pool = SqlitePoolOptions::new().connect_with(options).await?; Ok(pool) } + +pub trait NotFound { + type Ok; + fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E> + where + E: From<sqlx::Error>, + F: FnOnce() -> E; +} + +impl<T> NotFound for Result<T, sqlx::Error> { + type Ok = T; + + fn not_found<E, F>(self, map: F) -> Result<T, E> + where + E: From<sqlx::Error>, + F: FnOnce() -> E, + { + match self { + Err(sqlx::Error::RowNotFound) => Err(map()), + Err(other) => Err(other.into()), + Ok(value) => Ok(value), + } + } +} diff --git a/src/error.rs b/src/error.rs index 6e797b4..8792a1d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,3 +61,11 @@ impl fmt::Display for Id { self.0.fmt(f) } } + +pub struct Unauthorized; + +impl IntoResponse for Unauthorized { + fn into_response(self) -> Response { + (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + } +} diff --git a/src/event/app.rs b/src/event/app.rs new file mode 100644 index 0000000..d664ec7 --- /dev/null +++ b/src/event/app.rs @@ -0,0 +1,72 @@ +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; +use itertools::Itertools as _; +use sqlx::sqlite::SqlitePool; + +use super::{broadcaster::Broadcaster, Event, Sequence, Sequenced}; +use crate::{ + channel::{self, repo::Provider as _}, + message::{self, repo::Provider as _}, +}; + +pub struct Events<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Events<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn subscribe( + &self, + resume_at: Option<Sequence>, + ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, sqlx::Error> { + // Subscribe before retrieving, to catch messages broadcast while we're + // querying the DB. We'll prune out duplicates later. + let live_messages = self.events.subscribe(); + + let mut tx = self.db.begin().await?; + + let channels = tx.channels().replay(resume_at).await?; + let channel_events = channels + .iter() + .map(channel::History::events) + .kmerge_by(Sequence::merge) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let messages = tx.messages().replay(resume_at).await?; + let message_events = messages + .iter() + .map(message::History::events) + .kmerge_by(Sequence::merge) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let replay_events = channel_events + .merge_by(message_events, Sequence::merge) + .collect::<Vec<_>>(); + let resume_live_at = replay_events.last().map(Sequenced::sequence); + + let replay = stream::iter(replay_events); + + let live_messages = live_messages + // Filtering on the broadcast resume point filters out messages + // before resume_at, and filters out messages duplicated from + // `replay_events`. + .flat_map(stream::iter) + .filter(Self::resume(resume_live_at)); + + Ok(replay.chain(live_messages)) + } + + fn resume(resume_at: Option<Sequence>) -> impl for<'m> FnMut(&'m Event) -> future::Ready<bool> { + let filter = Sequence::after(resume_at); + move |event| future::ready(filter(event)) + } +} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs new file mode 100644 index 0000000..3c4efac --- /dev/null +++ b/src/event/broadcaster.rs @@ -0,0 +1,3 @@ +use crate::broadcast; + +pub type Broadcaster = broadcast::Broadcaster<Vec<super::Event>>; diff --git a/src/events/extract.rs b/src/event/extract.rs index e3021e2..e3021e2 100644 --- a/src/events/extract.rs +++ b/src/event/extract.rs diff --git a/src/event/mod.rs b/src/event/mod.rs new file mode 100644 index 0000000..1349fe6 --- /dev/null +++ b/src/event/mod.rs @@ -0,0 +1,75 @@ +use crate::{channel, message}; + +pub mod app; +pub mod broadcaster; +mod extract; +pub mod repo; +mod routes; +mod sequence; + +pub use self::{ + routes::router, + sequence::{Instant, Sequence, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +impl From<channel::Event> for Event { + fn from(event: channel::Event) -> Self { + Self { + instant: event.instant, + kind: event.kind.into(), + } + } +} + +impl From<message::Event> for Event { + fn from(event: message::Event) -> Self { + Self { + instant: event.instant(), + kind: event.kind.into(), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + #[serde(rename = "created")] + ChannelCreated(channel::event::Created), + #[serde(rename = "message")] + MessageSent(message::event::Sent), + MessageDeleted(message::event::Deleted), + #[serde(rename = "deleted")] + ChannelDeleted(channel::event::Deleted), +} + +impl From<channel::event::Kind> for Kind { + fn from(kind: channel::event::Kind) -> Self { + match kind { + channel::event::Kind::Created(created) => Self::ChannelCreated(created), + channel::event::Kind::Deleted(deleted) => Self::ChannelDeleted(deleted), + } + } +} + +impl From<message::event::Kind> for Kind { + fn from(kind: message::event::Kind) -> Self { + match kind { + message::event::Kind::Sent(created) => Self::MessageSent(created), + message::event::Kind::Deleted(deleted) => Self::MessageDeleted(deleted), + } + } +} diff --git a/src/event/repo.rs b/src/event/repo.rs new file mode 100644 index 0000000..40d6a53 --- /dev/null +++ b/src/event/repo.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + clock::DateTime, + event::{Instant, Sequence}, +}; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self, at: &DateTime) -> Result<Instant, sqlx::Error> { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(Instant { + at: *at, + sequence: next, + }) + } + + pub async fn current(&mut self) -> Result<Sequence, sqlx::Error> { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/event/routes.rs b/src/event/routes.rs new file mode 100644 index 0000000..5b9c7e3 --- /dev/null +++ b/src/event/routes.rs @@ -0,0 +1,92 @@ +use axum::{ + extract::State, + response::{ + sse::{self, Sse}, + IntoResponse, Response, + }, + routing::get, + Router, +}; +use axum_extra::extract::Query; +use futures::stream::{Stream, StreamExt as _}; + +use super::{extract::LastEventId, Event}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + event::{Sequence, Sequenced as _}, + token::{app::ValidateError, extract::Identity}, +}; + +#[cfg(test)] +mod test; + +pub fn router() -> Router<App> { + Router::new().route("/api/events", get(events)) +} + +#[derive(Default, serde::Deserialize)] +struct EventsQuery { + resume_point: Option<Sequence>, +} + +async fn events( + State(app): State<App>, + identity: Identity, + last_event_id: Option<LastEventId<Sequence>>, + Query(query): Query<EventsQuery>, +) -> Result<Events<impl Stream<Item = Event> + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); + + let stream = app.events().subscribe(resume_at).await?; + let stream = app.tokens().limit_stream(identity.token, stream).await?; + + Ok(Events(stream)) +} + +#[derive(Debug)] +struct Events<S>(S); + +impl<S> IntoResponse for Events<S> +where + S: Stream<Item = Event> + Send + 'static, +{ + fn into_response(self) -> Response { + let Self(stream) = self; + let stream = stream.map(sse::Event::try_from); + Sse::new(stream) + .keep_alive(sse::KeepAlive::default()) + .into_response() + } +} + +impl TryFrom<Event> for sse::Event { + type Error = serde_json::Error; + + fn try_from(event: Event) -> Result<Self, Self::Error> { + let id = serde_json::to_string(&event.sequence())?; + let data = serde_json::to_string_pretty(&event)?; + + let event = Self::default().id(id).data(data); + + Ok(event) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/events/routes/test.rs b/src/event/routes/test.rs index 820192d..ba9953e 100644 --- a/src/events/routes/test.rs +++ b/src/event/routes/test.rs @@ -1,11 +1,12 @@ use axum::extract::State; +use axum_extra::extract::Query; use futures::{ future, stream::{self, StreamExt as _}, }; use crate::{ - events::{routes, types}, + event::{routes, Sequenced as _}, test::fixtures::{self, future::Immediately as _}, }; @@ -16,26 +17,26 @@ async fn includes_historical_message() { let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app).await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; // Call the endpoint let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); // Verify the structure of the response. - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() .await .expect("delivered stored message"); - assert_eq!(message, event); + assert!(fixtures::event::message_sent(&event, &message)); } #[tokio::test] @@ -49,23 +50,24 @@ async fn includes_live_message() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the semantics let sender = fixtures::login::create(&app).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; - let types::ResumableEvent(_, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() .await .expect("delivered live message"); - assert_eq!(message, event); + assert!(fixtures::event::message_sent(&event, &message)); } #[tokio::test] @@ -85,7 +87,7 @@ async fn includes_multiple_channels() { let app = app.clone(); let sender = sender.clone(); let channel = channel.clone(); - async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } + async move { fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await } }) .collect::<Vec<_>>() .await; @@ -94,7 +96,7 @@ async fn includes_multiple_channels() { let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); @@ -110,7 +112,7 @@ async fn includes_multiple_channels() { for message in &messages { assert!(events .iter() - .any(|types::ResumableEvent(_, event)| { event == message })); + .any(|event| fixtures::event::message_sent(event, message))); } } @@ -123,33 +125,38 @@ async fn sequential_messages() { let sender = fixtures::login::create(&app).await; let messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; // Call the endpoint let subscriber_creds = fixtures::login::create_with_password(&app).await; let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None) + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) .await .expect("subscribe never fails"); // Verify the structure of the response. - let mut events = - events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))); + let mut events = events.filter(|event| { + future::ready( + messages + .iter() + .any(|message| fixtures::event::message_sent(event, message)), + ) + }); // Verify delivery in order for message in &messages { - let types::ResumableEvent(_, event) = events + let event = events .next() .immediately() .await .expect("undelivered messages remaining"); - assert_eq!(message, &event); + assert!(fixtures::event::message_sent(&event, message)); } } @@ -161,11 +168,11 @@ async fn resumes_from() { let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app).await; - let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + let initial_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; let later_messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; // Call the endpoint @@ -175,26 +182,36 @@ async fn resumes_from() { let resume_at = { // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); - let types::ResumableEvent(last_event_id, event) = events + let event = events .filter(fixtures::filter::messages()) .next() .immediately() .await .expect("delivered events"); - assert_eq!(initial_message, event); + assert!(fixtures::event::message_sent(&event, &initial_message)); - last_event_id + event.sequence() }; // Resume after disconnect - let routes::Events(resumed) = routes::events(State(app), subscriber, Some(resume_at.into())) - .await - .expect("subscribe never fails"); + let routes::Events(resumed) = routes::events( + State(app), + subscriber, + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -207,7 +224,7 @@ async fn resumes_from() { for message in &later_messages { assert!(events .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + .any(|event| fixtures::event::message_sent(event, message))); } } @@ -242,14 +259,19 @@ async fn serial_resume() { let resume_at = { let initial_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await, ]; // First subscription - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); let events = events .filter(fixtures::filter::messages()) @@ -261,12 +283,12 @@ async fn serial_resume() { for message in &initial_messages { assert!(events .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + .any(|event| fixtures::event::message_sent(event, message))); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence() }; // Resume after disconnect @@ -275,8 +297,8 @@ async fn serial_resume() { // Note that channel_b does not appear here. The buggy behaviour // would be masked if channel_b happened to send a new message // into the resumed event stream. - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, ]; // Second subscription @@ -284,6 +306,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -298,12 +321,12 @@ async fn serial_resume() { for message in &resume_messages { assert!(events .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + .any(|event| fixtures::event::message_sent(event, message))); } - let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); + let event = events.last().expect("this vec is non-empty"); - id.to_owned() + event.sequence() }; // Resume after disconnect a second time @@ -312,8 +335,8 @@ async fn serial_resume() { // problem. The resume point should before both of these messages, but // after _all_ prior messages. let final_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await, ]; // Third subscription @@ -321,6 +344,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), + Query::default(), ) .await .expect("subscribe never fails"); @@ -337,7 +361,7 @@ async fn serial_resume() { for message in &final_messages { assert!(events .iter() - .any(|types::ResumableEvent(_, event)| event == message)); + .any(|event| fixtures::event::message_sent(event, message))); } }; } @@ -356,26 +380,31 @@ async fn terminates_on_token_expiry() { let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour - app.logins() + app.tokens() .expire(&fixtures::now()) .await .expect("expiring tokens succeeds"); // These should not be delivered. let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready( + messages + .iter() + .any(|message| fixtures::event::message_sent(event, message)) + )) .next() .immediately() .await @@ -398,26 +427,35 @@ async fn terminates_on_logout() { let subscriber = fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) - .await - .expect("subscribe never fails"); + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour - app.logins() + app.tokens() .logout(&subscriber.token) .await .expect("expiring tokens succeeds"); // These should not be delivered. let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, ]; assert!(events - .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .filter(|event| future::ready( + messages + .iter() + .any(|message| fixtures::event::message_sent(event, message)) + )) .next() .immediately() .await diff --git a/src/event/sequence.rs b/src/event/sequence.rs new file mode 100644 index 0000000..fbe3711 --- /dev/null +++ b/src/event/sequence.rs @@ -0,0 +1,90 @@ +use std::fmt; + +use crate::clock::DateTime; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Instant { + pub at: DateTime, + #[serde(skip)] + pub sequence: Sequence, +} + +impl From<Instant> for Sequence { + fn from(instant: Instant) -> Self { + instant.sequence + } +} + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) + } +} + +impl Sequence { + pub fn up_to<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool + where + E: Sequenced, + { + move |event| resume_point.map_or(true, |resume_point| event.sequence() <= resume_point) + } + + pub fn after<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool + where + E: Sequenced, + { + move |event| resume_point < Some(event.sequence()) + } + + pub fn start_from<E>(resume_point: Self) -> impl for<'e> Fn(&'e E) -> bool + where + E: Sequenced, + { + move |event| resume_point <= event.sequence() + } + + pub fn merge<E>(a: &E, b: &E) -> bool + where + E: Sequenced, + { + a.sequence() < b.sequence() + } +} + +pub trait Sequenced { + fn instant(&self) -> Instant; + + fn sequence(&self) -> Sequence { + self.instant().into() + } +} + +impl<E> Sequenced for &E +where + E: Sequenced, +{ + fn instant(&self) -> Instant { + (*self).instant() + } + + fn sequence(&self) -> Sequence { + (*self).sequence() + } +} diff --git a/src/events/app.rs b/src/events/app.rs deleted file mode 100644 index db7f430..0000000 --- a/src/events/app.rs +++ /dev/null @@ -1,163 +0,0 @@ -use std::collections::BTreeMap; - -use chrono::TimeDelta; -use futures::{ - future, - stream::{self, StreamExt as _}, - Stream, -}; -use sqlx::sqlite::SqlitePool; - -use super::{ - broadcaster::Broadcaster, - repo::message::Provider as _, - types::{self, ChannelEvent, ResumePoint}, -}; -use crate::{ - clock::DateTime, - repo::{ - channel::{self, Provider as _}, - error::NotFound as _, - login::Login, - }, -}; - -pub struct Events<'a> { - db: &'a SqlitePool, - events: &'a Broadcaster, -} - -impl<'a> Events<'a> { - pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { - Self { db, events } - } - - pub async fn send( - &self, - login: &Login, - channel: &channel::Id, - body: &str, - sent_at: &DateTime, - ) -> Result<types::ChannelEvent, EventsError> { - let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let event = tx - .message_events() - .create(login, &channel, body, sent_at) - .await?; - tx.commit().await?; - - self.events.broadcast(&event); - Ok(event) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 90 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(90); - - let mut tx = self.db.begin().await?; - let expired = tx.message_events().expired(&expire_at).await?; - - let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { - let sequence = tx.message_events().assign_sequence(&channel).await?; - let event = tx - .message_events() - .delete_expired(&channel, &message, sequence, relative_to) - .await?; - events.push(event); - } - - tx.commit().await?; - - for event in events { - self.events.broadcast(&event); - } - - Ok(()) - } - - pub async fn subscribe( - &self, - resume_at: ResumePoint, - ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> { - let mut tx = self.db.begin().await?; - let channels = tx.channels().all().await?; - - let created_events = { - let resume_at = resume_at.clone(); - let channels = channels.clone(); - stream::iter( - channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| resume_at.not_after(event)), - ) - }; - - // Subscribe before retrieving, to catch messages broadcast while we're - // querying the DB. We'll prune out duplicates later. - let live_messages = self.events.subscribe(); - - let mut replays = BTreeMap::new(); - let mut resume_live_at = resume_at.clone(); - for channel in channels { - let replay = tx - .message_events() - .replay(&channel, resume_at.get(&channel.id)) - .await?; - - if let Some(last) = replay.last() { - resume_live_at.advance(last); - } - - replays.insert(channel.id.clone(), replay); - } - - let replay = stream::select_all(replays.into_values().map(stream::iter)); - - // no skip_expired or resume transforms for stored_messages, as it's - // constructed not to contain messages meeting either criterion. - // - // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; - // * resume is redundant with the resume_at argument to - // `tx.broadcasts().replay(…)`. - let live_messages = live_messages - // Filtering on the broadcast resume point filters out messages - // before resume_at, and filters out messages duplicated from - // stored_messages. - .filter(Self::resume(resume_live_at)); - - Ok(created_events.chain(replay).chain(live_messages).scan( - resume_at, - |resume_point, event| { - match event.data { - types::ChannelEventData::Deleted(_) => resume_point.forget(&event), - _ => resume_point.advance(&event), - } - - let event = types::ResumableEvent(resume_point.clone(), event); - - future::ready(Some(event)) - }, - )) - } - - fn resume( - resume_at: ResumePoint, - ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { - move |event| future::ready(resume_at.not_after(event)) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum EventsError { - #[error("channel {0} not found")] - ChannelNotFound(channel::Id), - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs deleted file mode 100644 index 6b664cb..0000000 --- a/src/events/broadcaster.rs +++ /dev/null @@ -1,3 +0,0 @@ -use crate::{broadcast, events::types}; - -pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/events/mod.rs b/src/events/mod.rs deleted file mode 100644 index 711ae64..0000000 --- a/src/events/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod app; -pub mod broadcaster; -mod extract; -pub mod repo; -mod routes; -pub mod types; - -pub use self::routes::router; diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs deleted file mode 100644 index f8bae2b..0000000 --- a/src/events/repo/message.rs +++ /dev/null @@ -1,198 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - events::types::{self, Sequence}, - repo::{ - channel::{self, Channel}, - login::{self, Login}, - message::{self, Message}, - }, -}; - -pub trait Provider { - fn message_events(&mut self) -> Events; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn message_events(&mut self) -> Events { - Events(self) - } -} - -pub struct Events<'t>(&'t mut SqliteConnection); - -impl<'c> Events<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - body: &str, - sent_at: &DateTime, - ) -> Result<types::ChannelEvent, sqlx::Error> { - let sequence = self.assign_sequence(channel).await?; - - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sequence, sender, body, sent_at) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sequence as "sequence: Sequence", - sender as "sender: login::Id", - body, - sent_at as "sent_at: DateTime" - "#, - id, - channel.id, - sequence, - sender.id, - body, - sent_at, - ) - .map(|row| types::ChannelEvent { - sequence: row.sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: channel.clone(), - sender: sender.clone(), - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn assign_sequence(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> { - let next = sqlx::query_scalar!( - r#" - update channel - set last_sequence = last_sequence + 1 - where id = $1 - returning last_sequence as "next_sequence: Sequence" - "#, - channel.id, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } - - pub async fn delete_expired( - &mut self, - channel: &Channel, - message: &message::Id, - sequence: Sequence, - deleted_at: &DateTime, - ) -> Result<types::ChannelEvent, sqlx::Error> { - sqlx::query_scalar!( - r#" - delete from message - where id = $1 - returning 1 as "row: i64" - "#, - message, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence, - at: *deleted_at, - data: types::MessageDeletedEvent { - channel: channel.clone(), - message: message.clone(), - } - .into(), - }) - } - - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - message.id as "message: message::Id" - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where sent_at < $1 - "#, - expire_at, - ) - .map(|row| { - ( - Channel { - id: row.channel_id, - name: row.channel_name, - created_at: row.channel_created_at, - }, - row.message, - ) - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } - - pub async fn replay( - &mut self, - channel: &Channel, - resume_at: Option<Sequence>, - ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { - let events = sqlx::query!( - r#" - select - message.id as "id: message::Id", - sequence as "sequence: Sequence", - login.id as "sender_id: login::Id", - login.name as sender_name, - message.body, - message.sent_at as "sent_at: DateTime" - from message - join login on message.sender = login.id - where channel = $1 - and coalesce(sequence > $2, true) - order by sequence asc - "#, - channel.id, - resume_at, - ) - .map(|row| types::ChannelEvent { - sequence: row.sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: channel.clone(), - sender: login::Login { - id: row.sender_id, - name: row.sender_name, - }, - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(events) - } -} diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs deleted file mode 100644 index e216a50..0000000 --- a/src/events/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod message; diff --git a/src/events/routes.rs b/src/events/routes.rs deleted file mode 100644 index ec9dae2..0000000 --- a/src/events/routes.rs +++ /dev/null @@ -1,69 +0,0 @@ -use axum::{ - extract::State, - response::{ - sse::{self, Sse}, - IntoResponse, Response, - }, - routing::get, - Router, -}; -use futures::stream::{Stream, StreamExt as _}; - -use super::{ - extract::LastEventId, - types::{self, ResumePoint}, -}; -use crate::{app::App, error::Internal, login::extract::Identity}; - -#[cfg(test)] -mod test; - -pub fn router() -> Router<App> { - Router::new().route("/api/events", get(events)) -} - -async fn events( - State(app): State<App>, - identity: Identity, - last_event_id: Option<LastEventId<ResumePoint>>, -) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { - let resume_at = last_event_id - .map(LastEventId::into_inner) - .unwrap_or_default(); - - let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream); - - Ok(Events(stream)) -} - -#[derive(Debug)] -struct Events<S>(S); - -impl<S> IntoResponse for Events<S> -where - S: Stream<Item = types::ResumableEvent> + Send + 'static, -{ - fn into_response(self) -> Response { - let Self(stream) = self; - let stream = stream.map(sse::Event::try_from); - Sse::new(stream) - .keep_alive(sse::KeepAlive::default()) - .into_response() - } -} - -impl TryFrom<types::ResumableEvent> for sse::Event { - type Error = serde_json::Error; - - fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> { - let types::ResumableEvent(resume_at, data) = value; - - let id = serde_json::to_string(&resume_at)?; - let data = serde_json::to_string_pretty(&data)?; - - let event = Self::default().id(id).data(data); - - Ok(event) - } -} diff --git a/src/events/types.rs b/src/events/types.rs deleted file mode 100644 index d954512..0000000 --- a/src/events/types.rs +++ /dev/null @@ -1,170 +0,0 @@ -use std::collections::BTreeMap; - -use crate::{ - clock::DateTime, - repo::{ - channel::{self, Channel}, - login::Login, - message, - }, -}; - -#[derive( - Debug, - Default, - Eq, - Ord, - PartialEq, - PartialOrd, - Clone, - Copy, - serde::Serialize, - serde::Deserialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); - -impl Sequence { - pub fn next(self) -> Self { - let Self(current) = self; - Self(current + 1) - } -} - -// For the purposes of event replay, a resume point is a vector of resume -// elements. A resume element associates a channel (by ID) with the latest event -// seen in that channel so far. Replaying the event stream can restart at a -// predictable point - hence the name. These values can be serialized and sent -// to the client as JSON dicts, then rehydrated to recover the resume point at a -// later time. -// -// Using a sorted map ensures that there is a canonical representation for -// each resume point. -#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)] -#[serde(transparent)] -pub struct ResumePoint(BTreeMap<channel::Id, Sequence>); - -impl ResumePoint { - pub fn advance<'e>(&mut self, event: impl Into<ResumeElement<'e>>) { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - elements.insert(channel.clone(), sequence); - } - - pub fn forget<'e>(&mut self, event: impl Into<ResumeElement<'e>>) { - let Self(elements) = self; - let ResumeElement(channel, _) = event.into(); - elements.remove(channel); - } - - pub fn get(&self, channel: &channel::Id) -> Option<Sequence> { - let Self(elements) = self; - elements.get(channel).copied() - } - - pub fn not_after<'e>(&self, event: impl Into<ResumeElement<'e>>) -> bool { - let Self(elements) = self; - let ResumeElement(channel, sequence) = event.into(); - - elements - .get(channel) - .map_or(true, |resume_at| resume_at < &sequence) - } -} - -pub struct ResumeElement<'i>(&'i channel::Id, Sequence); - -#[derive(Clone, Debug)] -pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent); - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct ChannelEvent { - #[serde(skip)] - pub sequence: Sequence, - pub at: DateTime, - #[serde(flatten)] - pub data: ChannelEventData, -} - -impl ChannelEvent { - pub fn created(channel: Channel) -> Self { - Self { - at: channel.created_at, - sequence: Sequence::default(), - data: CreatedEvent { channel }.into(), - } - } - - pub fn channel_id(&self) -> &channel::Id { - match &self.data { - ChannelEventData::Created(event) => &event.channel.id, - ChannelEventData::Message(event) => &event.channel.id, - ChannelEventData::MessageDeleted(event) => &event.channel.id, - ChannelEventData::Deleted(event) => &event.channel, - } - } -} - -impl<'c> From<&'c ChannelEvent> for ResumeElement<'c> { - fn from(event: &'c ChannelEvent) -> Self { - Self(event.channel_id(), event.sequence) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ChannelEventData { - Created(CreatedEvent), - Message(MessageEvent), - MessageDeleted(MessageDeletedEvent), - Deleted(DeletedEvent), -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct CreatedEvent { - pub channel: Channel, -} - -impl From<CreatedEvent> for ChannelEventData { - fn from(event: CreatedEvent) -> Self { - Self::Created(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageEvent { - pub channel: Channel, - pub sender: Login, - pub message: message::Message, -} - -impl From<MessageEvent> for ChannelEventData { - fn from(event: MessageEvent) -> Self { - Self::Message(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageDeletedEvent { - pub channel: Channel, - pub message: message::Id, -} - -impl From<MessageDeletedEvent> for ChannelEventData { - fn from(event: MessageDeletedEvent) -> Self { - Self::MessageDeleted(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct DeletedEvent { - pub channel: channel::Id, -} - -impl From<DeletedEvent> for ChannelEventData { - fn from(event: DeletedEvent) -> Self { - Self::Deleted(event) - } -} diff --git a/src/expire.rs b/src/expire.rs index 16006d1..e50bcb4 100644 --- a/src/expire.rs +++ b/src/expire.rs @@ -13,8 +13,8 @@ pub async fn middleware( req: Request, next: Next, ) -> Result<Response, Internal> { - app.logins().expire(&expired_at).await?; - app.events().expire(&expired_at).await?; + app.tokens().expire(&expired_at).await?; + app.messages().expire(&expired_at).await?; app.channels().expire(&expired_at).await?; Ok(next.run(req).await) } @@ -7,12 +7,13 @@ mod broadcast; mod channel; pub mod cli; mod clock; +mod db; mod error; -mod events; +mod event; mod expire; mod id; mod login; -mod password; -mod repo; +mod message; #[cfg(test)] mod test; +mod token; diff --git a/src/login/app.rs b/src/login/app.rs index 182c62c..15adb31 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,57 +1,25 @@ -use chrono::TimeDelta; -use futures::{ - future, - stream::{self, StreamExt as _}, - Stream, -}; use sqlx::sqlite::SqlitePool; -use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types}; -use crate::{ - clock::DateTime, - password::Password, - repo::{ - error::NotFound as _, - login::{Login, Provider as _}, - token::{self, Provider as _}, - }, -}; +use crate::event::{repo::Provider as _, Sequence}; + +#[cfg(test)] +use super::{repo::Provider as _, Login, Password}; pub struct Logins<'a> { db: &'a SqlitePool, - logins: &'a Broadcaster, } impl<'a> Logins<'a> { - pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self { - Self { db, logins } + pub const fn new(db: &'a SqlitePool) -> Self { + Self { db } } - pub async fn login( - &self, - name: &str, - password: &Password, - login_at: &DateTime, - ) -> Result<IdentitySecret, LoginError> { + pub async fn boot_point(&self) -> Result<Sequence, sqlx::Error> { let mut tx = self.db.begin().await?; - - let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { - if stored_hash.verify(password)? { - // Password verified; use the login. - login - } else { - // Password NOT verified. - return Err(LoginError::Rejected); - } - } else { - let password_hash = password.hash()?; - tx.logins().create(name, &password_hash).await? - }; - - let token = tx.tokens().issue(&login, login_at).await?; + let sequence = tx.sequence().current().await?; tx.commit().await?; - Ok(token) + Ok(sequence) } #[cfg(test)] @@ -64,82 +32,6 @@ impl<'a> Logins<'a> { Ok(login) } - - pub async fn validate( - &self, - secret: &IdentitySecret, - used_at: &DateTime, - ) -> Result<(token::Id, Login), ValidateError> { - let mut tx = self.db.begin().await?; - let login = tx - .tokens() - .validate(secret, used_at) - .await - .not_found(|| ValidateError::InvalidToken)?; - tx.commit().await?; - - Ok(login) - } - - pub fn limit_stream<E>( - &self, - token: token::Id, - events: impl Stream<Item = E> + std::fmt::Debug, - ) -> impl Stream<Item = E> + std::fmt::Debug - where - E: std::fmt::Debug, - { - let token_events = self - .logins - .subscribe() - .filter(move |event| future::ready(event.token == token)) - .map(|_| GuardedEvent::TokenRevoked); - - let events = events.map(|event| GuardedEvent::Event(event)); - - stream::select(token_events, events).scan((), |(), event| { - future::ready(match event { - GuardedEvent::Event(event) => Some(event), - GuardedEvent::TokenRevoked => None, - }) - }) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 7 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(7); - - let mut tx = self.db.begin().await?; - let tokens = tx.tokens().expire(&expire_at).await?; - tx.commit().await?; - - for event in tokens.into_iter().map(types::TokenRevoked::from) { - self.logins.broadcast(&event); - } - - Ok(()) - } - - pub async fn logout(&self, token: &token::Id) -> Result<(), ValidateError> { - let mut tx = self.db.begin().await?; - tx.tokens().revoke(token).await?; - tx.commit().await?; - - self.logins - .broadcast(&types::TokenRevoked::from(token.clone())); - - Ok(()) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum LoginError { - #[error("invalid login")] - Rejected, - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), - #[error(transparent)] - PasswordHashError(#[from] password_hash::Error), } #[cfg(test)] @@ -149,17 +41,3 @@ pub enum CreateError { DatabaseError(#[from] sqlx::Error), PasswordHashError(#[from] password_hash::Error), } - -#[derive(Debug, thiserror::Error)] -pub enum ValidateError { - #[error("invalid token")] - InvalidToken, - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} - -#[derive(Debug)] -enum GuardedEvent<E> { - TokenRevoked, - Event(E), -} diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs deleted file mode 100644 index 8e1fb3a..0000000 --- a/src/login/broadcaster.rs +++ /dev/null @@ -1,3 +0,0 @@ -use crate::{broadcast, login::types}; - -pub type Broadcaster = broadcast::Broadcaster<types::TokenRevoked>; diff --git a/src/login/extract.rs b/src/login/extract.rs index b585565..c2d97f2 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -1,182 +1,15 @@ -use std::fmt; +use axum::{extract::FromRequestParts, http::request::Parts}; -use axum::{ - extract::{FromRequestParts, State}, - http::{request::Parts, StatusCode}, - response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, -}; -use axum_extra::extract::cookie::{Cookie, CookieJar}; - -use crate::{ - app::App, - clock::RequestedAt, - error::Internal, - login::app::ValidateError, - repo::{login::Login, token}, -}; - -// The usage pattern here - receive the extractor as an argument, return it in -// the response - is heavily modelled after CookieJar's own intended usage. -#[derive(Clone)] -pub struct IdentityToken { - cookies: CookieJar, -} - -impl fmt::Debug for IdentityToken { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IdentityToken") - .field( - "identity", - &self.cookies.get(IDENTITY_COOKIE).map(|_| "********"), - ) - .finish() - } -} - -impl IdentityToken { - // Creates a new, unpopulated identity token store. - #[cfg(test)] - pub fn new() -> Self { - Self { - cookies: CookieJar::new(), - } - } - - // Get the identity secret sent in the request, if any. If the identity - // was not sent, or if it has previously been [clear]ed, then this will - // return [None]. If the identity has previously been [set], then this - // will return that secret, regardless of what the request originally - // included. - pub fn secret(&self) -> Option<IdentitySecret> { - self.cookies - .get(IDENTITY_COOKIE) - .map(Cookie::value) - .map(IdentitySecret::from) - } - - // Positively set the identity secret, and ensure that it will be sent - // back to the client when this extractor is included in a response. - pub fn set(self, secret: impl Into<IdentitySecret>) -> Self { - let IdentitySecret(secret) = secret.into(); - let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) - .http_only(true) - .path("/api/") - .permanent() - .build(); - - Self { - cookies: self.cookies.add(identity_cookie), - } - } - - // Remove the identity secret and ensure that it will be cleared when this - // extractor is included in a response. - pub fn clear(self) -> Self { - Self { - cookies: self.cookies.remove(IDENTITY_COOKIE), - } - } -} - -const IDENTITY_COOKIE: &str = "identity"; - -#[async_trait::async_trait] -impl<S> FromRequestParts<S> for IdentityToken -where - S: Send + Sync, -{ - type Rejection = <CookieJar as FromRequestParts<S>>::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { - let cookies = CookieJar::from_request_parts(parts, state).await?; - Ok(Self { cookies }) - } -} - -impl IntoResponseParts for IdentityToken { - type Error = <CookieJar as IntoResponseParts>::Error; - - fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> { - let Self { cookies } = self; - cookies.into_response_parts(res) - } -} - -#[derive(sqlx::Type)] -#[sqlx(transparent)] -pub struct IdentitySecret(String); - -impl fmt::Debug for IdentitySecret { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("IdentityToken").field(&"********").finish() - } -} - -impl<S> From<S> for IdentitySecret -where - S: Into<String>, -{ - fn from(value: S) -> Self { - Self(value.into()) - } -} - -#[derive(Clone, Debug)] -pub struct Identity { - pub token: token::Id, - pub login: Login, -} +use super::Login; +use crate::{app::App, token::extract::Identity}; #[async_trait::async_trait] -impl FromRequestParts<App> for Identity { - type Rejection = LoginError<Internal>; +impl FromRequestParts<App> for Login { + type Rejection = <Identity as FromRequestParts<App>>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { - // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on - // stable), the following can be replaced: - // - // ``` - // let Ok(identity_token) = IdentityToken::from_request_parts( - // parts, - // state, - // ).await; - // ``` - let identity_token = IdentityToken::from_request_parts(parts, state).await?; - let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; - - let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; - - let app = State::<App>::from_request_parts(parts, state).await?; - match app.logins().validate(&secret, &used_at).await { - Ok((token, login)) => Ok(Identity { token, login }), - Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), - Err(other) => Err(other.into()), - } - } -} - -pub enum LoginError<E> { - Failure(E), - Unauthorized, -} - -impl<E> IntoResponse for LoginError<E> -where - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), - Self::Failure(e) => e.into_response(), - } - } -} + let identity = Identity::from_request_parts(parts, state).await?; -impl<E> From<E> for LoginError<Internal> -where - E: Into<Internal>, -{ - fn from(err: E) -> Self { - Self::Failure(err.into()) + Ok(identity.login) } } diff --git a/src/login/id.rs b/src/login/id.rs new file mode 100644 index 0000000..c46d697 --- /dev/null +++ b/src/login/id.rs @@ -0,0 +1,24 @@ +use crate::id::Id as BaseId; + +// Stable identifier for a [Login]. Prefixed with `L`. +#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)] +#[sqlx(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("L") + } +} + +impl std::fmt::Display for Id { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/login/mod.rs b/src/login/mod.rs index 6ae82ac..65e3ada 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,8 +1,21 @@ -pub use self::routes::router; - pub mod app; -pub mod broadcaster; pub mod extract; -mod repo; +mod id; +pub mod password; +pub mod repo; mod routes; -pub mod types; + +pub use self::{id::Id, password::Password, routes::router}; + +// This also implements FromRequestParts (see `./extract.rs`). As a result, it +// can be used as an extractor for endpoints that want to require login, or for +// endpoints that need to behave differently depending on whether the client is +// or is not logged in. +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Login { + pub id: Id, + pub name: String, + // The omission of the hashed password is deliberate, to minimize the + // chance that it ends up tangled up in debug output or in some other chunk + // of logic elsewhere. +} diff --git a/src/password.rs b/src/login/password.rs index da3930f..da3930f 100644 --- a/src/password.rs +++ b/src/login/password.rs diff --git a/src/login/repo.rs b/src/login/repo.rs new file mode 100644 index 0000000..d1a02c4 --- /dev/null +++ b/src/login/repo.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::login::{password::StoredHash, Id, Login}; + +pub trait Provider { + fn logins(&mut self) -> Logins; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn logins(&mut self) -> Logins { + Logins(self) + } +} + +pub struct Logins<'t>(&'t mut SqliteConnection); + +impl<'c> Logins<'c> { + pub async fn create( + &mut self, + name: &str, + password_hash: &StoredHash, + ) -> Result<Login, sqlx::Error> { + let id = Id::generate(); + + let login = sqlx::query_as!( + Login, + r#" + insert or fail + into login (id, name, password_hash) + values ($1, $2, $3) + returning + id as "id: Id", + name + "#, + id, + name, + password_hash, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } +} + +impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { + fn from(tx: &'t mut SqliteConnection) -> Self { + Self(tx) + } +} diff --git a/src/login/repo/mod.rs b/src/login/repo/mod.rs deleted file mode 100644 index 0e4a05d..0000000 --- a/src/login/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod auth; diff --git a/src/login/routes.rs b/src/login/routes.rs index 8d9e938..0874cc3 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -7,11 +7,13 @@ use axum::{ }; use crate::{ - app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login, + app::App, + clock::RequestedAt, + error::{Internal, Unauthorized}, + login::{Login, Password}, + token::{app, extract::IdentityToken}, }; -use super::{app, extract::IdentityToken}; - #[cfg(test)] mod test; @@ -22,13 +24,18 @@ pub fn router() -> Router<App> { .route("/api/auth/logout", post(on_logout)) } -async fn boot(login: Login) -> Boot { - Boot { login } +async fn boot(State(app): State<App>, login: Login) -> Result<Boot, Internal> { + let resume_point = app.logins().boot_point().await?; + Ok(Boot { + login, + resume_point: resume_point.to_string(), + }) } #[derive(serde::Serialize)] struct Boot { login: Login, + resume_point: String, } impl IntoResponse for Boot { @@ -50,7 +57,7 @@ async fn on_login( Json(request): Json<LoginRequest>, ) -> Result<(IdentityToken, StatusCode), LoginError> { let token = app - .logins() + .tokens() .login(&request.name, &request.password, &now) .await .map_err(LoginError)?; @@ -66,6 +73,7 @@ impl IntoResponse for LoginError { let Self(error) = self; match error { app::LoginError::Rejected => { + // not error::Unauthorized due to differing messaging (StatusCode::UNAUTHORIZED, "invalid name or password").into_response() } other => Internal::from(other).into_response(), @@ -85,8 +93,8 @@ async fn on_logout( Json(LogoutRequest {}): Json<LogoutRequest>, ) -> Result<(IdentityToken, StatusCode), LogoutError> { if let Some(secret) = identity.secret() { - let (token, _) = app.logins().validate(&secret, &now).await?; - app.logins().logout(&token).await?; + let (token, _) = app.tokens().validate(&secret, &now).await?; + app.tokens().logout(&token).await?; } let identity = identity.clear(); @@ -103,9 +111,7 @@ enum LogoutError { impl IntoResponse for LogoutError { fn into_response(self) -> Response { match self { - error @ Self::ValidateError(app::ValidateError::InvalidToken) => { - (StatusCode::UNAUTHORIZED, error.to_string()).into_response() - } + Self::ValidateError(app::ValidateError::InvalidToken) => Unauthorized.into_response(), other => Internal::from(other).into_response(), } } diff --git a/src/login/routes/test/boot.rs b/src/login/routes/test/boot.rs index dee554f..9655354 100644 --- a/src/login/routes/test/boot.rs +++ b/src/login/routes/test/boot.rs @@ -1,9 +1,14 @@ +use axum::extract::State; + use crate::{login::routes, test::fixtures}; #[tokio::test] async fn returns_identity() { + let app = fixtures::scratch_app().await; let login = fixtures::login::fictitious(); - let response = routes::boot(login.clone()).await; + let response = routes::boot(State(app), login.clone()) + .await + .expect("boot always succeeds"); assert_eq!(login, response.login); } diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs index 81653ff..3c82738 100644 --- a/src/login/routes/test/login.rs +++ b/src/login/routes/test/login.rs @@ -3,10 +3,7 @@ use axum::{ http::StatusCode, }; -use crate::{ - login::{app, routes}, - test::fixtures, -}; +use crate::{login::routes, test::fixtures, token::app}; #[tokio::test] async fn new_identity() { @@ -37,7 +34,7 @@ async fn new_identity() { let validated_at = fixtures::now(); let (_, validated) = app - .logins() + .tokens() .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -74,7 +71,7 @@ async fn existing_identity() { let validated_at = fixtures::now(); let (_, validated_login) = app - .logins() + .tokens() .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -127,14 +124,14 @@ async fn token_expires() { // Verify the semantics let expired_at = fixtures::now(); - app.logins() + app.tokens() .expire(&expired_at) .await .expect("expiring tokens never fails"); let verified_at = fixtures::now(); let error = app - .logins() + .tokens() .validate(&secret, &verified_at) .await .expect_err("validating an expired token"); diff --git a/src/login/routes/test/logout.rs b/src/login/routes/test/logout.rs index 20b0d55..42b2534 100644 --- a/src/login/routes/test/logout.rs +++ b/src/login/routes/test/logout.rs @@ -3,10 +3,7 @@ use axum::{ http::StatusCode, }; -use crate::{ - login::{app, routes}, - test::fixtures, -}; +use crate::{login::routes, test::fixtures, token::app}; #[tokio::test] async fn successful() { @@ -37,7 +34,7 @@ async fn successful() { // Verify the semantics let error = app - .logins() + .tokens() .validate(&secret, &now) .await .expect_err("secret is invalid"); diff --git a/src/message/app.rs b/src/message/app.rs new file mode 100644 index 0000000..33ea8ad --- /dev/null +++ b/src/message/app.rs @@ -0,0 +1,115 @@ +use chrono::TimeDelta; +use itertools::Itertools; +use sqlx::sqlite::SqlitePool; + +use super::{repo::Provider as _, Id, Message}; +use crate::{ + channel::{self, repo::Provider as _}, + clock::DateTime, + db::NotFound as _, + event::{broadcaster::Broadcaster, repo::Provider as _, Event, Sequence}, + login::Login, +}; + +pub struct Messages<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Messages<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn send( + &self, + channel: &channel::Id, + sender: &Login, + sent_at: &DateTime, + body: &str, + ) -> Result<Message, SendError> { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| SendError::ChannelNotFound(channel.clone()))?; + let sent = tx.sequence().next(sent_at).await?; + let message = tx + .messages() + .create(&channel.snapshot(), sender, &sent, body) + .await?; + tx.commit().await?; + + self.events + .broadcast(message.events().map(Event::from).collect::<Vec<_>>()); + + Ok(message.snapshot()) + } + + pub async fn delete(&self, message: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { + let mut tx = self.db.begin().await?; + let deleted = tx.sequence().next(deleted_at).await?; + let message = tx.messages().delete(message, &deleted).await?; + tx.commit().await?; + + self.events.broadcast( + message + .events() + .filter(Sequence::start_from(deleted.sequence)) + .map(Event::from) + .collect::<Vec<_>>(), + ); + + Ok(()) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + + let expired = tx.messages().expired(&expire_at).await?; + let mut events = Vec::with_capacity(expired.len()); + for message in expired { + let deleted = tx.sequence().next(relative_to).await?; + let message = tx.messages().delete(&message, &deleted).await?; + events.push( + message + .events() + .filter(Sequence::start_from(deleted.sequence)), + ); + } + + tx.commit().await?; + + self.events.broadcast( + events + .into_iter() + .kmerge_by(Sequence::merge) + .map(Event::from) + .collect::<Vec<_>>(), + ); + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum SendError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum DeleteError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error("message {0} not found")] + NotFound(Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} diff --git a/src/message/event.rs b/src/message/event.rs new file mode 100644 index 0000000..66db9b0 --- /dev/null +++ b/src/message/event.rs @@ -0,0 +1,71 @@ +use super::{snapshot::Message, Id}; +use crate::{ + channel::Channel, + event::{Instant, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.kind.instant() + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + Sent(Sent), + Deleted(Deleted), +} + +impl Sequenced for Kind { + fn instant(&self) -> Instant { + match self { + Self::Sent(sent) => sent.instant(), + Self::Deleted(deleted) => deleted.instant(), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Sent { + #[serde(flatten)] + pub message: Message, +} + +impl Sequenced for Sent { + fn instant(&self) -> Instant { + self.message.sent + } +} + +impl From<Sent> for Kind { + fn from(event: Sent) -> Self { + Self::Sent(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Deleted { + #[serde(flatten)] + pub instant: Instant, + pub channel: Channel, + pub message: Id, +} + +impl Sequenced for Deleted { + fn instant(&self) -> Instant { + self.instant + } +} + +impl From<Deleted> for Kind { + fn from(event: Deleted) -> Self { + Self::Deleted(event) + } +} diff --git a/src/message/history.rs b/src/message/history.rs new file mode 100644 index 0000000..89fc6b1 --- /dev/null +++ b/src/message/history.rs @@ -0,0 +1,41 @@ +use super::{ + event::{Deleted, Event, Sent}, + Message, +}; +use crate::event::Instant; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct History { + pub message: Message, + pub deleted: Option<Instant>, +} + +impl History { + fn sent(&self) -> Event { + Event { + kind: Sent { + message: self.message.clone(), + } + .into(), + } + } + + fn deleted(&self) -> Option<Event> { + self.deleted.map(|instant| Event { + kind: Deleted { + instant, + channel: self.message.channel.clone(), + message: self.message.id.clone(), + } + .into(), + }) + } + + pub fn events(&self) -> impl Iterator<Item = Event> { + [self.sent()].into_iter().chain(self.deleted()) + } + + pub fn snapshot(&self) -> Message { + self.message.clone() + } +} diff --git a/src/repo/message.rs b/src/message/id.rs index a1f73d5..385b103 100644 --- a/src/repo/message.rs +++ b/src/message/id.rs @@ -25,9 +25,3 @@ impl fmt::Display for Id { self.0.fmt(f) } } - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: Id, - pub body: String, -} diff --git a/src/message/mod.rs b/src/message/mod.rs new file mode 100644 index 0000000..a8f51ab --- /dev/null +++ b/src/message/mod.rs @@ -0,0 +1,9 @@ +pub mod app; +pub mod event; +mod history; +mod id; +pub mod repo; +mod routes; +mod snapshot; + +pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Message}; diff --git a/src/message/repo.rs b/src/message/repo.rs new file mode 100644 index 0000000..fc835c8 --- /dev/null +++ b/src/message/repo.rs @@ -0,0 +1,247 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use super::{snapshot::Message, History, Id}; +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::{Instant, Sequence}, + login::{self, Login}, +}; + +pub trait Provider { + fn messages(&mut self) -> Messages; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn messages(&mut self) -> Messages { + Messages(self) + } +} + +pub struct Messages<'t>(&'t mut SqliteConnection); + +impl<'c> Messages<'c> { + pub async fn create( + &mut self, + channel: &Channel, + sender: &Login, + sent: &Instant, + body: &str, + ) -> Result<History, sqlx::Error> { + let id = Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: Id", + body + "#, + id, + channel.id, + sender.id, + sent.at, + sent.sequence, + body, + ) + .map(|row| History { + message: Message { + sent: *sent, + channel: channel.clone(), + sender: sender.clone(), + id: row.id, + body: row.body, + }, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn in_channel( + &mut self, + channel: &Channel, + resume_at: Option<Sequence>, + ) -> Result<Vec<History>, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where channel.id = $1 + and coalesce(message.sent_sequence <= $2, true) + order by message.sent_sequence + "#, + channel.id, + resume_at, + ) + .map(|row| History { + message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + async fn by_id(&mut self, message: &Id) -> Result<History, sqlx::Error> { + let message = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where message.id = $1 + "#, + message, + ) + .map(|row| History { + message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + deleted: None, + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn delete( + &mut self, + message: &Id, + deleted: &Instant, + ) -> Result<History, sqlx::Error> { + let history = self.by_id(message).await?; + + sqlx::query_scalar!( + r#" + delete from message + where + id = $1 + returning 1 as "deleted: i64" + "#, + history.message.id, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(History { + deleted: Some(*deleted), + ..history + }) + } + + pub async fn expired(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { + let messages = sqlx::query_scalar!( + r#" + select + id as "message: Id" + from message + where sent_at < $1 + "#, + expire_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + pub async fn replay( + &mut self, + resume_at: Option<Sequence>, + ) -> Result<Vec<History>, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + "#, + resume_at, + ) + .map(|row| History { + message: Message { + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} diff --git a/src/message/routes.rs b/src/message/routes.rs new file mode 100644 index 0000000..29fe3d7 --- /dev/null +++ b/src/message/routes.rs @@ -0,0 +1,46 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::delete, + Router, +}; + +use crate::{ + app::App, + clock::RequestedAt, + error::Internal, + login::Login, + message::{self, app::DeleteError}, +}; + +pub fn router() -> Router<App> { + Router::new().route("/api/messages/:message", delete(on_delete)) +} + +async fn on_delete( + State(app): State<App>, + Path(message): Path<message::Id>, + RequestedAt(deleted_at): RequestedAt, + _: Login, +) -> Result<StatusCode, ErrorResponse> { + app.messages().delete(&message, &deleted_at).await?; + + Ok(StatusCode::ACCEPTED) +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +struct ErrorResponse(#[from] DeleteError); + +impl IntoResponse for ErrorResponse { + fn into_response(self) -> Response { + let Self(error) = self; + match error { + not_found @ (DeleteError::ChannelNotFound(_) | DeleteError::NotFound(_)) => { + (StatusCode::NOT_FOUND, not_found.to_string()).into_response() + } + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs new file mode 100644 index 0000000..522c1aa --- /dev/null +++ b/src/message/snapshot.rs @@ -0,0 +1,76 @@ +use super::{ + event::{Event, Kind, Sent}, + Id, +}; +use crate::{channel::Channel, event::Instant, login::Login}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(into = "self::serialize::Message")] +pub struct Message { + #[serde(skip)] + pub sent: Instant, + pub channel: Channel, + pub sender: Login, + pub id: Id, + pub body: String, +} + +mod serialize { + use crate::{channel::Channel, login::Login, message::Id}; + + #[derive(serde::Serialize)] + pub struct Message { + channel: Channel, + sender: Login, + #[allow(clippy::struct_field_names)] + // Deliberately redundant with the module path; this produces a specific serialization. + message: MessageData, + } + + #[derive(serde::Serialize)] + pub struct MessageData { + id: Id, + body: String, + } + + impl From<super::Message> for Message { + fn from(message: super::Message) -> Self { + Self { + channel: message.channel, + sender: message.sender, + message: MessageData { + id: message.id, + body: message.body, + }, + } + } + } +} + +impl Message { + fn apply(state: Option<Self>, event: Event) -> Option<Self> { + match (state, event.kind) { + (None, Kind::Sent(event)) => Some(event.into()), + (Some(message), Kind::Deleted(event)) if message.id == event.message => None, + (state, event) => panic!("invalid message event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator<Event> for Option<Message> { + fn from_iter<I: IntoIterator<Item = Event>>(events: I) -> Self { + events.into_iter().fold(None, Message::apply) + } +} + +impl From<&Sent> for Message { + fn from(event: &Sent) -> Self { + event.message.clone() + } +} + +impl From<Sent> for Message { + fn from(event: Sent) -> Self { + event.message + } +} diff --git a/src/repo/channel.rs b/src/repo/channel.rs deleted file mode 100644 index 3c7468f..0000000 --- a/src/repo/channel.rs +++ /dev/null @@ -1,179 +0,0 @@ -use std::fmt; - -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - events::types::{self, Sequence}, - id::Id as BaseId, -}; - -pub trait Provider { - fn channels(&mut self) -> Channels; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn channels(&mut self) -> Channels { - Channels(self) - } -} - -pub struct Channels<'t>(&'t mut SqliteConnection); - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Channel { - pub id: Id, - pub name: String, - #[serde(skip)] - pub created_at: DateTime, -} - -impl<'c> Channels<'c> { - pub async fn create( - &mut self, - name: &str, - created_at: &DateTime, - ) -> Result<Channel, sqlx::Error> { - let id = Id::generate(); - let sequence = Sequence::default(); - - let channel = sqlx::query_as!( - Channel, - r#" - insert - into channel (id, name, created_at, last_sequence) - values ($1, $2, $3, $4) - returning - id as "id: Id", - name, - created_at as "created_at: DateTime" - "#, - id, - name, - created_at, - sequence, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn by_id(&mut self, channel: &Id) -> Result<Channel, sqlx::Error> { - let channel = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime" - from channel - where id = $1 - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn all(&mut self) -> Result<Vec<Channel>, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime" - from channel - order by channel.name - "#, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } - - pub async fn delete_expired( - &mut self, - channel: &Channel, - sequence: Sequence, - deleted_at: &DateTime, - ) -> Result<types::ChannelEvent, sqlx::Error> { - let channel = channel.id.clone(); - sqlx::query_scalar!( - r#" - delete from channel - where id = $1 - returning 1 as "row: i64" - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence, - at: *deleted_at, - data: types::DeletedEvent { channel }.into(), - }) - } - - pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Channel>, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - channel.id as "id: Id", - channel.name, - channel.created_at as "created_at: DateTime" - from channel - left join message - where created_at < $1 - and message.id is null - "#, - expired_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } -} - -// Stable identifier for a [Channel]. Prefixed with `C`. -#[derive( - Clone, - Debug, - Eq, - Hash, - Ord, - PartialEq, - PartialOrd, - sqlx::Type, - serde::Deserialize, - serde::Serialize, -)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From<BaseId> for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("C") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/repo/error.rs b/src/repo/error.rs deleted file mode 100644 index a5961e2..0000000 --- a/src/repo/error.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub trait NotFound { - type Ok; - fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E> - where - E: From<sqlx::Error>, - F: FnOnce() -> E; -} - -impl<T> NotFound for Result<T, sqlx::Error> { - type Ok = T; - - fn not_found<E, F>(self, map: F) -> Result<T, E> - where - E: From<sqlx::Error>, - F: FnOnce() -> E, - { - match self { - Err(sqlx::Error::RowNotFound) => Err(map()), - Err(other) => Err(other.into()), - Ok(value) => Ok(value), - } - } -} diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs deleted file mode 100644 index ab61106..0000000 --- a/src/repo/login/extract.rs +++ /dev/null @@ -1,15 +0,0 @@ -use axum::{extract::FromRequestParts, http::request::Parts}; - -use super::Login; -use crate::{app::App, login::extract::Identity}; - -#[async_trait::async_trait] -impl FromRequestParts<App> for Login { - type Rejection = <Identity as FromRequestParts<App>>::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { - let identity = Identity::from_request_parts(parts, state).await?; - - Ok(identity.login) - } -} diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs deleted file mode 100644 index a1b4c6f..0000000 --- a/src/repo/login/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod extract; -mod store; - -pub use self::store::{Id, Login, Provider}; diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs deleted file mode 100644 index b485941..0000000 --- a/src/repo/login/store.rs +++ /dev/null @@ -1,86 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{id::Id as BaseId, password::StoredHash}; - -pub trait Provider { - fn logins(&mut self) -> Logins; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn logins(&mut self) -> Logins { - Logins(self) - } -} - -pub struct Logins<'t>(&'t mut SqliteConnection); - -// This also implements FromRequestParts (see `./extract.rs`). As a result, it -// can be used as an extractor for endpoints that want to require login, or for -// endpoints that need to behave differently depending on whether the client is -// or is not logged in. -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Login { - pub id: Id, - pub name: String, - // The omission of the hashed password is deliberate, to minimize the - // chance that it ends up tangled up in debug output or in some other chunk - // of logic elsewhere. -} - -impl<'c> Logins<'c> { - pub async fn create( - &mut self, - name: &str, - password_hash: &StoredHash, - ) -> Result<Login, sqlx::Error> { - let id = Id::generate(); - - let login = sqlx::query_as!( - Login, - r#" - insert or fail - into login (id, name, password_hash) - values ($1, $2, $3) - returning - id as "id: Id", - name - "#, - id, - name, - password_hash, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(login) - } -} - -impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { - fn from(tx: &'t mut SqliteConnection) -> Self { - Self(tx) - } -} - -// Stable identifier for a [Login]. Prefixed with `L`. -#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)] -#[sqlx(transparent)] -pub struct Id(BaseId); - -impl From<BaseId> for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("L") - } -} - -impl std::fmt::Display for Id { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs deleted file mode 100644 index cb9d7c8..0000000 --- a/src/repo/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod channel; -pub mod error; -pub mod login; -pub mod message; -pub mod pool; -pub mod token; diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs index 8744470..b678717 100644 --- a/src/test/fixtures/channel.rs +++ b/src/test/fixtures/channel.rs @@ -4,7 +4,7 @@ use faker_rand::{ }; use rand; -use crate::{app::App, clock::RequestedAt, repo::channel::Channel}; +use crate::{app::App, channel::Channel, clock::RequestedAt}; pub async fn create(app: &App, created_at: &RequestedAt) -> Channel { let name = propose(); diff --git a/src/test/fixtures/event.rs b/src/test/fixtures/event.rs new file mode 100644 index 0000000..09f0490 --- /dev/null +++ b/src/test/fixtures/event.rs @@ -0,0 +1,11 @@ +use crate::{ + event::{Event, Kind}, + message::Message, +}; + +pub fn message_sent(event: &Event, message: &Message) -> bool { + matches!( + &event.kind, + Kind::MessageSent(event) if message == &event.into() + ) +} diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index fbebced..6e62aea 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -1,15 +1,11 @@ use futures::future; -use crate::events::types; +use crate::event::{Event, Kind}; -pub fn messages() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Message(_))) - } +pub fn messages() -> impl FnMut(&Event) -> future::Ready<bool> { + |event| future::ready(matches!(event.kind, Kind::MessageSent(_))) } -pub fn created() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> { - |types::ResumableEvent(_, event)| { - future::ready(matches!(event.data, types::ChannelEventData::Created(_))) - } +pub fn created() -> impl FnMut(&Event) -> future::Ready<bool> { + |event| future::ready(matches!(event.kind, Kind::ChannelCreated(_))) } diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 633fb8a..56b4ffa 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -3,8 +3,11 @@ use uuid::Uuid; use crate::{ app::App, clock::RequestedAt, - login::extract::{Identity, IdentitySecret, IdentityToken}, - password::Password, + login::Password, + token::{ + extract::{Identity, IdentityToken}, + Secret, + }, }; pub fn not_logged_in() -> IdentityToken { @@ -14,7 +17,7 @@ pub fn not_logged_in() -> IdentityToken { pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) -> IdentityToken { let (name, password) = login; let token = app - .logins() + .tokens() .login(name, password, now) .await .expect("should succeed given known-valid credentials"); @@ -25,7 +28,7 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) pub async fn from_token(app: &App, token: &IdentityToken, issued_at: &RequestedAt) -> Identity { let secret = token.secret().expect("identity token has a secret"); let (token, login) = app - .logins() + .tokens() .validate(&secret, issued_at) .await .expect("always validates newly-issued secret"); @@ -38,7 +41,7 @@ pub async fn identity(app: &App, login: &(String, Password), issued_at: &Request from_token(app, &secret, issued_at).await } -pub fn secret(identity: &IdentityToken) -> IdentitySecret { +pub fn secret(identity: &IdentityToken) -> Secret { identity.secret().expect("identity contained a secret") } diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index d6a321b..00c2789 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -3,8 +3,7 @@ use uuid::Uuid; use crate::{ app::App, - password::Password, - repo::login::{self, Login}, + login::{self, Login, Password}, }; pub async fn create_with_password(app: &App) -> (String, Password) { diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index bfca8cd..381b10b 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -1,22 +1,12 @@ use faker_rand::lorem::Paragraphs; -use crate::{ - app::App, - clock::RequestedAt, - events::types, - repo::{channel::Channel, login::Login}, -}; +use crate::{app::App, channel::Channel, clock::RequestedAt, login::Login, message::Message}; -pub async fn send( - app: &App, - login: &Login, - channel: &Channel, - sent_at: &RequestedAt, -) -> types::ChannelEvent { +pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &RequestedAt) -> Message { let body = propose(); - app.events() - .send(login, &channel.id, &body, sent_at) + app.messages() + .send(&channel.id, login, sent_at, &body) .await .expect("should succeed if the channel exists") } diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index d1dd0c3..c5efa9b 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -1,8 +1,9 @@ use chrono::{TimeDelta, Utc}; -use crate::{app::App, clock::RequestedAt, repo::pool}; +use crate::{app::App, clock::RequestedAt, db}; pub mod channel; +pub mod event; pub mod filter; pub mod future; pub mod identity; @@ -10,7 +11,7 @@ pub mod login; pub mod message; pub async fn scratch_app() -> App { - let pool = pool::prepare("sqlite::memory:") + let pool = db::prepare("sqlite::memory:") .await .expect("setting up in-memory sqlite database"); App::from(pool) diff --git a/src/token/app.rs b/src/token/app.rs new file mode 100644 index 0000000..5c4fcd5 --- /dev/null +++ b/src/token/app.rs @@ -0,0 +1,170 @@ +use chrono::TimeDelta; +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; +use sqlx::sqlite::SqlitePool; + +use super::{ + broadcaster::Broadcaster, event, repo::auth::Provider as _, repo::Provider as _, Id, Secret, +}; +use crate::{ + clock::DateTime, + db::NotFound as _, + login::{repo::Provider as _, Login, Password}, +}; + +pub struct Tokens<'a> { + db: &'a SqlitePool, + tokens: &'a Broadcaster, +} + +impl<'a> Tokens<'a> { + pub const fn new(db: &'a SqlitePool, tokens: &'a Broadcaster) -> Self { + Self { db, tokens } + } + pub async fn login( + &self, + name: &str, + password: &Password, + login_at: &DateTime, + ) -> Result<Secret, LoginError> { + let mut tx = self.db.begin().await?; + + let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { + if stored_hash.verify(password)? { + // Password verified; use the login. + login + } else { + // Password NOT verified. + return Err(LoginError::Rejected); + } + } else { + let password_hash = password.hash()?; + tx.logins().create(name, &password_hash).await? + }; + + let token = tx.tokens().issue(&login, login_at).await?; + tx.commit().await?; + + Ok(token) + } + + pub async fn validate( + &self, + secret: &Secret, + used_at: &DateTime, + ) -> Result<(Id, Login), ValidateError> { + let mut tx = self.db.begin().await?; + let login = tx + .tokens() + .validate(secret, used_at) + .await + .not_found(|| ValidateError::InvalidToken)?; + tx.commit().await?; + + Ok(login) + } + + pub async fn limit_stream<E>( + &self, + token: Id, + events: impl Stream<Item = E> + std::fmt::Debug, + ) -> Result<impl Stream<Item = E> + std::fmt::Debug, ValidateError> + where + E: std::fmt::Debug, + { + // Subscribe, first. + let token_events = self.tokens.subscribe(); + + // Check that the token is valid at this point in time, second. If it is, then + // any future revocations will appear in the subscription. If not, bail now. + // + // It's possible, otherwise, to get to this point with a token that _was_ valid + // at the start of the request, but which was invalided _before_ the + // `subscribe()` call. In that case, the corresponding revocation event will + // simply be missed, since the `token_events` stream subscribed after the fact. + // This check cancels guarding the stream here. + // + // Yes, this is a weird niche edge case. Most things don't double-check, because + // they aren't expected to run long enough for the token's revocation to + // matter. Supervising a stream, on the other hand, will run for a + // _long_ time; if we miss the race here, we'll never actually carry out the + // supervision. + let mut tx = self.db.begin().await?; + tx.tokens() + .require(&token) + .await + .not_found(|| ValidateError::InvalidToken)?; + tx.commit().await?; + + // Then construct the guarded stream. First, project both streams into + // `GuardedEvent`. + let token_events = token_events + .filter(move |event| future::ready(event.token == token)) + .map(|_| GuardedEvent::TokenRevoked); + let events = events.map(|event| GuardedEvent::Event(event)); + + // Merge the two streams, then unproject them, stopping at + // `GuardedEvent::TokenRevoked`. + let stream = stream::select(token_events, events).scan((), |(), event| { + future::ready(match event { + GuardedEvent::Event(event) => Some(event), + GuardedEvent::TokenRevoked => None, + }) + }); + + Ok(stream) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 7 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(7); + + let mut tx = self.db.begin().await?; + let tokens = tx.tokens().expire(&expire_at).await?; + tx.commit().await?; + + for event in tokens.into_iter().map(event::TokenRevoked::from) { + self.tokens.broadcast(event); + } + + Ok(()) + } + + pub async fn logout(&self, token: &Id) -> Result<(), ValidateError> { + let mut tx = self.db.begin().await?; + tx.tokens().revoke(token).await?; + tx.commit().await?; + + self.tokens + .broadcast(event::TokenRevoked::from(token.clone())); + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum LoginError { + #[error("invalid login")] + Rejected, + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), + #[error(transparent)] + PasswordHashError(#[from] password_hash::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ValidateError { + #[error("invalid token")] + InvalidToken, + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} + +#[derive(Debug)] +enum GuardedEvent<E> { + TokenRevoked, + Event(E), +} diff --git a/src/token/broadcaster.rs b/src/token/broadcaster.rs new file mode 100644 index 0000000..8e2e006 --- /dev/null +++ b/src/token/broadcaster.rs @@ -0,0 +1,4 @@ +use super::event; +use crate::broadcast; + +pub type Broadcaster = broadcast::Broadcaster<event::TokenRevoked>; diff --git a/src/login/types.rs b/src/token/event.rs index 7c7cbf9..d53d436 100644 --- a/src/login/types.rs +++ b/src/token/event.rs @@ -1,4 +1,4 @@ -use crate::repo::token; +use crate::token; #[derive(Clone, Debug)] pub struct TokenRevoked { diff --git a/src/token/extract/identity.rs b/src/token/extract/identity.rs new file mode 100644 index 0000000..60ad220 --- /dev/null +++ b/src/token/extract/identity.rs @@ -0,0 +1,75 @@ +use axum::{ + extract::{FromRequestParts, State}, + http::request::Parts, + response::{IntoResponse, Response}, +}; + +use super::IdentityToken; + +use crate::{ + app::App, + clock::RequestedAt, + error::{Internal, Unauthorized}, + login::Login, + token::{self, app::ValidateError}, +}; + +#[derive(Clone, Debug)] +pub struct Identity { + pub token: token::Id, + pub login: Login, +} + +#[async_trait::async_trait] +impl FromRequestParts<App> for Identity { + type Rejection = LoginError<Internal>; + + async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { + // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on + // stable), the following can be replaced: + // + // ``` + // let Ok(identity_token) = IdentityToken::from_request_parts( + // parts, + // state, + // ).await; + // ``` + let identity_token = IdentityToken::from_request_parts(parts, state).await?; + let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; + + let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; + + let app = State::<App>::from_request_parts(parts, state).await?; + match app.tokens().validate(&secret, &used_at).await { + Ok((token, login)) => Ok(Identity { token, login }), + Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), + Err(other) => Err(other.into()), + } + } +} + +pub enum LoginError<E> { + Failure(E), + Unauthorized, +} + +impl<E> IntoResponse for LoginError<E> +where + E: IntoResponse, +{ + fn into_response(self) -> Response { + match self { + Self::Unauthorized => Unauthorized.into_response(), + Self::Failure(e) => e.into_response(), + } + } +} + +impl<E> From<E> for LoginError<Internal> +where + E: Into<Internal>, +{ + fn from(err: E) -> Self { + Self::Failure(err.into()) + } +} diff --git a/src/token/extract/identity_token.rs b/src/token/extract/identity_token.rs new file mode 100644 index 0000000..0a47a43 --- /dev/null +++ b/src/token/extract/identity_token.rs @@ -0,0 +1,94 @@ +use std::fmt; + +use axum::{ + extract::FromRequestParts, + http::request::Parts, + response::{IntoResponseParts, ResponseParts}, +}; +use axum_extra::extract::cookie::{Cookie, CookieJar}; + +use crate::token::Secret; + +// The usage pattern here - receive the extractor as an argument, return it in +// the response - is heavily modelled after CookieJar's own intended usage. +#[derive(Clone)] +pub struct IdentityToken { + cookies: CookieJar, +} + +impl fmt::Debug for IdentityToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IdentityToken") + .field("identity", &self.secret()) + .finish() + } +} + +impl IdentityToken { + // Creates a new, unpopulated identity token store. + #[cfg(test)] + pub fn new() -> Self { + Self { + cookies: CookieJar::new(), + } + } + + // Get the identity secret sent in the request, if any. If the identity + // was not sent, or if it has previously been [clear]ed, then this will + // return [None]. If the identity has previously been [set], then this + // will return that secret, regardless of what the request originally + // included. + pub fn secret(&self) -> Option<Secret> { + self.cookies + .get(IDENTITY_COOKIE) + .map(Cookie::value) + .map(Secret::from) + } + + // Positively set the identity secret, and ensure that it will be sent + // back to the client when this extractor is included in a response. + pub fn set(self, secret: impl Into<Secret>) -> Self { + let secret = secret.into().reveal(); + let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) + .http_only(true) + .path("/api/") + .permanent() + .build(); + + Self { + cookies: self.cookies.add(identity_cookie), + } + } + + // Remove the identity secret and ensure that it will be cleared when this + // extractor is included in a response. + pub fn clear(self) -> Self { + Self { + cookies: self.cookies.remove(IDENTITY_COOKIE), + } + } +} + +const IDENTITY_COOKIE: &str = "identity"; + +#[async_trait::async_trait] +impl<S> FromRequestParts<S> for IdentityToken +where + S: Send + Sync, +{ + type Rejection = <CookieJar as FromRequestParts<S>>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + let cookies = CookieJar::from_request_parts(parts, state).await?; + Ok(Self { cookies }) + } +} + +impl IntoResponseParts for IdentityToken { + type Error = <CookieJar as IntoResponseParts>::Error; + + fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> { + let Self { cookies } = self; + cookies.into_response_parts(res) + } +} diff --git a/src/token/extract/mod.rs b/src/token/extract/mod.rs new file mode 100644 index 0000000..b4800ae --- /dev/null +++ b/src/token/extract/mod.rs @@ -0,0 +1,4 @@ +mod identity; +mod identity_token; + +pub use self::{identity::Identity, identity_token::IdentityToken}; diff --git a/src/token/id.rs b/src/token/id.rs new file mode 100644 index 0000000..9ef063c --- /dev/null +++ b/src/token/id.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a token. Prefixed with `T`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("T") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/token/mod.rs b/src/token/mod.rs new file mode 100644 index 0000000..d122611 --- /dev/null +++ b/src/token/mod.rs @@ -0,0 +1,9 @@ +pub mod app; +pub mod broadcaster; +mod event; +pub mod extract; +mod id; +mod repo; +mod secret; + +pub use self::{id::Id, secret::Secret}; diff --git a/src/login/repo/auth.rs b/src/token/repo/auth.rs index 3033c8f..b299697 100644 --- a/src/login/repo/auth.rs +++ b/src/token/repo/auth.rs @@ -1,9 +1,6 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::{ - password::StoredHash, - repo::login::{self, Login}, -}; +use crate::login::{self, password::StoredHash, Login}; pub trait Provider { fn auth(&mut self) -> Auth; diff --git a/src/token/repo/mod.rs b/src/token/repo/mod.rs new file mode 100644 index 0000000..9169743 --- /dev/null +++ b/src/token/repo/mod.rs @@ -0,0 +1,4 @@ +pub mod auth; +mod token; + +pub use self::token::Provider; diff --git a/src/repo/token.rs b/src/token/repo/token.rs index d96c094..5f64dac 100644 --- a/src/repo/token.rs +++ b/src/token/repo/token.rs @@ -1,10 +1,11 @@ -use std::fmt; - use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; -use super::login::{self, Login}; -use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret}; +use crate::{ + clock::DateTime, + login::{self, Login}, + token::{Id, Secret}, +}; pub trait Provider { fn tokens(&mut self) -> Tokens; @@ -25,7 +26,7 @@ impl<'c> Tokens<'c> { &mut self, login: &Login, issued_at: &DateTime, - ) -> Result<IdentitySecret, sqlx::Error> { + ) -> Result<Secret, sqlx::Error> { let id = Id::generate(); let secret = Uuid::new_v4().to_string(); @@ -34,7 +35,7 @@ impl<'c> Tokens<'c> { insert into token (id, secret, login, issued_at, last_used_at) values ($1, $2, $3, $4, $4) - returning secret as "secret!: IdentitySecret" + returning secret as "secret!: Secret" "#, id, secret, @@ -47,6 +48,21 @@ impl<'c> Tokens<'c> { Ok(secret) } + pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> { + sqlx::query_scalar!( + r#" + select id as "id: Id" + from token + where id = $1 + "#, + token, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(()) + } + // Revoke a token by its secret. pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> { sqlx::query_scalar!( @@ -87,7 +103,7 @@ impl<'c> Tokens<'c> { // timestamp will be set to `used_at`. pub async fn validate( &mut self, - secret: &IdentitySecret, + secret: &Secret, used_at: &DateTime, ) -> Result<(Id, Login), sqlx::Error> { // I would use `update … returning` to do this in one query, but @@ -133,27 +149,3 @@ impl<'c> Tokens<'c> { Ok(login) } } - -// Stable identifier for a token. Prefixed with `T`. -#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From<BaseId> for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("T") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/token/secret.rs b/src/token/secret.rs new file mode 100644 index 0000000..28c93bb --- /dev/null +++ b/src/token/secret.rs @@ -0,0 +1,27 @@ +use std::fmt; + +#[derive(sqlx::Type)] +#[sqlx(transparent)] +pub struct Secret(String); + +impl Secret { + pub fn reveal(self) -> String { + let Self(secret) = self; + secret + } +} + +impl fmt::Debug for Secret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IdentityToken").field(&"********").finish() + } +} + +impl<S> From<S> for Secret +where + S: Into<String>, +{ + fn from(value: S) -> Self { + Self(value.into()) + } +} |
