diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/app.rs | 18 | ||||
| -rw-r--r-- | src/broadcast.rs | 78 | ||||
| -rw-r--r-- | src/channel/app.rs | 10 | ||||
| -rw-r--r-- | src/events/app.rs | 12 | ||||
| -rw-r--r-- | src/events/broadcaster.rs | 75 | ||||
| -rw-r--r-- | src/events/routes.rs | 5 | ||||
| -rw-r--r-- | src/events/routes/test.rs | 57 | ||||
| -rw-r--r-- | src/lib.rs | 1 | ||||
| -rw-r--r-- | src/login/app.rs | 58 | ||||
| -rw-r--r-- | src/login/broadcaster.rs | 3 | ||||
| -rw-r--r-- | src/login/extract.rs | 74 | ||||
| -rw-r--r-- | src/login/mod.rs | 2 | ||||
| -rw-r--r-- | src/login/routes/test/login.rs | 4 | ||||
| -rw-r--r-- | src/login/types.rs | 12 | ||||
| -rw-r--r-- | src/repo/login/extract.rs | 62 | ||||
| -rw-r--r-- | src/repo/token.rs | 70 | ||||
| -rw-r--r-- | src/test/fixtures/identity.rs | 16 |
17 files changed, 370 insertions, 187 deletions
@@ -2,33 +2,35 @@ use sqlx::sqlite::SqlitePool; use crate::{ channel::app::Channels, - events::{app::Events, broadcaster::Broadcaster}, - login::app::Logins, + events::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, + login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster}, }; #[derive(Clone)] pub struct App { db: SqlitePool, - broadcaster: Broadcaster, + events: EventBroadcaster, + logins: LoginBroadcaster, } impl App { pub fn from(db: SqlitePool) -> Self { - let broadcaster = Broadcaster::default(); - Self { db, broadcaster } + let events = EventBroadcaster::default(); + let logins = LoginBroadcaster::default(); + Self { db, events, logins } } } impl App { pub const fn logins(&self) -> Logins { - Logins::new(&self.db) + Logins::new(&self.db, &self.logins) } pub const fn events(&self) -> Events { - Events::new(&self.db, &self.broadcaster) + Events::new(&self.db, &self.events) } pub const fn channels(&self) -> Channels { - Channels::new(&self.db, &self.broadcaster) + Channels::new(&self.db, &self.events) } } diff --git a/src/broadcast.rs b/src/broadcast.rs new file mode 100644 index 0000000..083a301 --- /dev/null +++ b/src/broadcast.rs @@ -0,0 +1,78 @@ +use std::sync::{Arc, Mutex}; + +use futures::{future, stream::StreamExt as _, Stream}; +use tokio::sync::broadcast::{channel, Sender}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; + +// Clones will share the same sender. +#[derive(Clone)] +pub struct Broadcaster<M> { + // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's + // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that + // lock it must be sync. + senders: Arc<Mutex<Sender<M>>>, +} + +impl<M> Default for Broadcaster<M> +where + M: Clone + Send + std::fmt::Debug + 'static, +{ + fn default() -> Self { + let sender = Self::make_sender(); + + Self { + senders: Arc::new(Mutex::new(sender)), + } + } +} + +impl<M> Broadcaster<M> +where + M: Clone + Send + std::fmt::Debug + 'static, +{ + // panic: if ``message.channel.id`` has not been previously registered, + // and was not part of the initial set of channels. + pub fn broadcast(&self, message: &M) { + let tx = self.sender(); + + // Per the Tokio docs, the returned error is only used to indicate that + // there are no receivers. In this use case, that's fine; a lack of + // listening consumers (chat clients) when a message is sent isn't an + // error. + // + // The successful return value, which includes the number of active + // receivers, also isn't that interesting to us. + let _ = tx.send(message.clone()); + } + + // panic: if ``channel`` has not been previously registered, and was not + // part of the initial set of channels. + pub fn subscribe(&self) -> impl Stream<Item = M> + std::fmt::Debug { + let rx = self.sender().subscribe(); + + BroadcastStream::from(rx).scan((), |(), r| { + future::ready(match r { + Ok(event) => Some(event), + // Stop the stream here. This will disconnect SSE clients + // (see `routes.rs`), who will then resume from + // `Last-Event-ID`, allowing them to catch up by reading + // the skipped messages from the database. + // + // See also: + // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854> + Err(BroadcastStreamRecvError::Lagged(_)) => None, + }) + }) + } + + fn sender(&self) -> Sender<M> { + self.senders.lock().unwrap().clone() + } + + fn make_sender() -> Sender<M> { + // Queue depth of 16 chosen entirely arbitrarily. Don't read too much + // into it. + let (tx, _) = channel(16); + tx + } +} diff --git a/src/channel/app.rs b/src/channel/app.rs index d7312e4..70cda47 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -9,12 +9,12 @@ use crate::{ pub struct Channels<'a> { db: &'a SqlitePool, - broadcaster: &'a Broadcaster, + events: &'a Broadcaster, } impl<'a> Channels<'a> { - pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { - Self { db, broadcaster } + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } } pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> { @@ -26,7 +26,7 @@ impl<'a> Channels<'a> { .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; - self.broadcaster + self.events .broadcast(&ChannelEvent::created(channel.clone())); Ok(channel) @@ -60,7 +60,7 @@ impl<'a> Channels<'a> { tx.commit().await?; for event in events { - self.broadcaster.broadcast(&event); + self.events.broadcast(&event); } Ok(()) diff --git a/src/events/app.rs b/src/events/app.rs index 0cdc641..db7f430 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -24,12 +24,12 @@ use crate::{ pub struct Events<'a> { db: &'a SqlitePool, - broadcaster: &'a Broadcaster, + events: &'a Broadcaster, } impl<'a> Events<'a> { - pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { - Self { db, broadcaster } + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } } pub async fn send( @@ -51,7 +51,7 @@ impl<'a> Events<'a> { .await?; tx.commit().await?; - self.broadcaster.broadcast(&event); + self.events.broadcast(&event); Ok(event) } @@ -75,7 +75,7 @@ impl<'a> Events<'a> { tx.commit().await?; for event in events { - self.broadcaster.broadcast(&event); + self.events.broadcast(&event); } Ok(()) @@ -101,7 +101,7 @@ impl<'a> Events<'a> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. - let live_messages = self.broadcaster.subscribe(); + let live_messages = self.events.subscribe(); let mut replays = BTreeMap::new(); let mut resume_live_at = resume_at.clone(); diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs index 9697c0a..6b664cb 100644 --- a/src/events/broadcaster.rs +++ b/src/events/broadcaster.rs @@ -1,74 +1,3 @@ -use std::sync::{Arc, Mutex}; +use crate::{broadcast, events::types}; -use futures::{future, stream::StreamExt as _, Stream}; -use tokio::sync::broadcast::{channel, Sender}; -use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; - -use crate::events::types; - -// Clones will share the same sender. -#[derive(Clone)] -pub struct Broadcaster { - // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's - // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that - // lock it must be sync. - senders: Arc<Mutex<Sender<types::ChannelEvent>>>, -} - -impl Default for Broadcaster { - fn default() -> Self { - let sender = Self::make_sender(); - - Self { - senders: Arc::new(Mutex::new(sender)), - } - } -} - -impl Broadcaster { - // panic: if ``message.channel.id`` has not been previously registered, - // and was not part of the initial set of channels. - pub fn broadcast(&self, message: &types::ChannelEvent) { - let tx = self.sender(); - - // Per the Tokio docs, the returned error is only used to indicate that - // there are no receivers. In this use case, that's fine; a lack of - // listening consumers (chat clients) when a message is sent isn't an - // error. - // - // The successful return value, which includes the number of active - // receivers, also isn't that interesting to us. - let _ = tx.send(message.clone()); - } - - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug { - let rx = self.sender().subscribe(); - - BroadcastStream::from(rx).scan((), |(), r| { - future::ready(match r { - Ok(event) => Some(event), - // Stop the stream here. This will disconnect SSE clients - // (see `routes.rs`), who will then resume from - // `Last-Event-ID`, allowing them to catch up by reading - // the skipped messages from the database. - // - // See also: - // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854> - Err(BroadcastStreamRecvError::Lagged(_)) => None, - }) - }) - } - - fn sender(&self) -> Sender<types::ChannelEvent> { - self.senders.lock().unwrap().clone() - } - - fn make_sender() -> Sender<types::ChannelEvent> { - // Queue depth of 16 chosen entirely arbitrarily. Don't read too much - // into it. - let (tx, _) = channel(16); - tx - } -} +pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/events/routes.rs b/src/events/routes.rs index 89c942c..ec9dae2 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -13,7 +13,7 @@ use super::{ extract::LastEventId, types::{self, ResumePoint}, }; -use crate::{app::App, error::Internal, repo::login::Login}; +use crate::{app::App, error::Internal, login::extract::Identity}; #[cfg(test)] mod test; @@ -24,7 +24,7 @@ pub fn router() -> Router<App> { async fn events( State(app): State<App>, - _: Login, // requires auth, but doesn't actually care who you are + identity: Identity, last_event_id: Option<LastEventId<ResumePoint>>, ) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { let resume_at = last_event_id @@ -32,6 +32,7 @@ async fn events( .unwrap_or_default(); let stream = app.events().subscribe(resume_at).await?; + let stream = app.logins().limit_stream(identity.token, stream); Ok(Events(stream)) } diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index a6e2275..0b62b5b 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -20,7 +20,8 @@ async fn includes_historical_message() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -46,7 +47,8 @@ async fn includes_live_message() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) .await .expect("subscribe never fails"); @@ -90,7 +92,8 @@ async fn includes_multiple_channels() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -127,7 +130,8 @@ async fn sequential_messages() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let routes::Events(events) = routes::events(State(app), subscriber, None) .await .expect("subscribe never fails"); @@ -166,7 +170,8 @@ async fn resumes_from() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let resume_at = { // First subscription @@ -232,7 +237,8 @@ async fn serial_resume() { // Call the endpoint - let subscriber = fixtures::login::create(&app).await; + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; let resume_at = { let initial_messages = [ @@ -335,3 +341,42 @@ async fn serial_resume() { } }; } + +#[tokio::test] +async fn terminates_on_token_expiry() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = + fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; + let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .expire(&fixtures::now()) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} @@ -3,6 +3,7 @@ #![warn(clippy::pedantic)] mod app; +mod broadcast; mod channel; pub mod cli; mod clock; diff --git a/src/login/app.rs b/src/login/app.rs index f7fec88..b8916a8 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,24 +1,30 @@ use chrono::TimeDelta; +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; use sqlx::sqlite::SqlitePool; -use super::{extract::IdentitySecret, repo::auth::Provider as _}; +use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types}; use crate::{ clock::DateTime, password::Password, repo::{ error::NotFound as _, login::{Login, Provider as _}, - token::Provider as _, + token::{self, Provider as _}, }, }; pub struct Logins<'a> { db: &'a SqlitePool, + logins: &'a Broadcaster, } impl<'a> Logins<'a> { - pub const fn new(db: &'a SqlitePool) -> Self { - Self { db } + pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self { + Self { db, logins } } pub async fn login( @@ -63,7 +69,7 @@ impl<'a> Logins<'a> { &self, secret: &IdentitySecret, used_at: &DateTime, - ) -> Result<Login, ValidateError> { + ) -> Result<(token::Id, Login), ValidateError> { let mut tx = self.db.begin().await?; let login = tx .tokens() @@ -75,26 +81,56 @@ impl<'a> Logins<'a> { Ok(login) } + pub fn limit_stream<E>( + &self, + token: token::Id, + events: impl Stream<Item = E> + std::fmt::Debug, + ) -> impl Stream<Item = E> + std::fmt::Debug + where + E: std::fmt::Debug, + { + let token_events = self + .logins + .subscribe() + .filter(move |event| future::ready(event.token == token)) + .map(|_| GuardedEvent::TokenRevoked); + + let events = events.map(|event| GuardedEvent::Event(event)); + + stream::select(token_events, events).scan((), |(), event| { + future::ready(match event { + GuardedEvent::Event(event) => Some(event), + GuardedEvent::TokenRevoked => None, + }) + }) + } + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { // Somewhat arbitrarily, expire after 7 days. let expire_at = relative_to.to_owned() - TimeDelta::days(7); let mut tx = self.db.begin().await?; - tx.tokens().expire(&expire_at).await?; + let tokens = tx.tokens().expire(&expire_at).await?; tx.commit().await?; + for event in tokens.into_iter().map(types::TokenRevoked::from) { + self.logins.broadcast(&event); + } + Ok(()) } pub async fn logout(&self, secret: &IdentitySecret) -> Result<(), ValidateError> { let mut tx = self.db.begin().await?; - tx.tokens() + let token = tx + .tokens() .revoke(secret) .await .not_found(|| ValidateError::InvalidToken)?; - tx.commit().await?; + self.logins.broadcast(&types::TokenRevoked::from(token)); + Ok(()) } } @@ -124,3 +160,9 @@ pub enum ValidateError { #[error(transparent)] DatabaseError(#[from] sqlx::Error), } + +#[derive(Debug)] +enum GuardedEvent<E> { + TokenRevoked, + Event(E), +} diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs new file mode 100644 index 0000000..8e1fb3a --- /dev/null +++ b/src/login/broadcaster.rs @@ -0,0 +1,3 @@ +use crate::{broadcast, login::types}; + +pub type Broadcaster = broadcast::Broadcaster<types::TokenRevoked>; diff --git a/src/login/extract.rs b/src/login/extract.rs index 3b31d4c..b585565 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -1,12 +1,20 @@ use std::fmt; use axum::{ - extract::FromRequestParts, - http::request::Parts, - response::{IntoResponseParts, ResponseParts}, + extract::{FromRequestParts, State}, + http::{request::Parts, StatusCode}, + response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use axum_extra::extract::cookie::{Cookie, CookieJar}; +use crate::{ + app::App, + clock::RequestedAt, + error::Internal, + login::app::ValidateError, + repo::{login::Login, token}, +}; + // The usage pattern here - receive the extractor as an argument, return it in // the response - is heavily modelled after CookieJar's own intended usage. #[derive(Clone)] @@ -112,3 +120,63 @@ where Self(value.into()) } } + +#[derive(Clone, Debug)] +pub struct Identity { + pub token: token::Id, + pub login: Login, +} + +#[async_trait::async_trait] +impl FromRequestParts<App> for Identity { + type Rejection = LoginError<Internal>; + + async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { + // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on + // stable), the following can be replaced: + // + // ``` + // let Ok(identity_token) = IdentityToken::from_request_parts( + // parts, + // state, + // ).await; + // ``` + let identity_token = IdentityToken::from_request_parts(parts, state).await?; + let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; + + let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; + + let app = State::<App>::from_request_parts(parts, state).await?; + match app.logins().validate(&secret, &used_at).await { + Ok((token, login)) => Ok(Identity { token, login }), + Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), + Err(other) => Err(other.into()), + } + } +} + +pub enum LoginError<E> { + Failure(E), + Unauthorized, +} + +impl<E> IntoResponse for LoginError<E> +where + E: IntoResponse, +{ + fn into_response(self) -> Response { + match self { + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::Failure(e) => e.into_response(), + } + } +} + +impl<E> From<E> for LoginError<Internal> +where + E: Into<Internal>, +{ + fn from(err: E) -> Self { + Self::Failure(err.into()) + } +} diff --git a/src/login/mod.rs b/src/login/mod.rs index 191cce0..6ae82ac 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,6 +1,8 @@ pub use self::routes::router; pub mod app; +pub mod broadcaster; pub mod extract; mod repo; mod routes; +pub mod types; diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs index 10c17d6..81653ff 100644 --- a/src/login/routes/test/login.rs +++ b/src/login/routes/test/login.rs @@ -36,7 +36,7 @@ async fn new_identity() { // Verify the semantics let validated_at = fixtures::now(); - let validated = app + let (_, validated) = app .logins() .validate(&secret, &validated_at) .await @@ -73,7 +73,7 @@ async fn existing_identity() { // Verify the semantics let validated_at = fixtures::now(); - let validated_login = app + let (_, validated_login) = app .logins() .validate(&secret, &validated_at) .await diff --git a/src/login/types.rs b/src/login/types.rs new file mode 100644 index 0000000..7c7cbf9 --- /dev/null +++ b/src/login/types.rs @@ -0,0 +1,12 @@ +use crate::repo::token; + +#[derive(Clone, Debug)] +pub struct TokenRevoked { + pub token: token::Id, +} + +impl From<token::Id> for TokenRevoked { + fn from(token: token::Id) -> Self { + Self { token } + } +} diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs index c127078..ab61106 100644 --- a/src/repo/login/extract.rs +++ b/src/repo/login/extract.rs @@ -1,67 +1,15 @@ -use axum::{ - extract::{FromRequestParts, State}, - http::{request::Parts, StatusCode}, - response::{IntoResponse, Response}, -}; +use axum::{extract::FromRequestParts, http::request::Parts}; use super::Login; -use crate::{ - app::App, - clock::RequestedAt, - error::Internal, - login::{app::ValidateError, extract::IdentityToken}, -}; +use crate::{app::App, login::extract::Identity}; #[async_trait::async_trait] impl FromRequestParts<App> for Login { - type Rejection = LoginError<Internal>; + type Rejection = <Identity as FromRequestParts<App>>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { - // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on - // stable), the following can be replaced: - // - // ``` - // let Ok(identity_token) = IdentityToken::from_request_parts( - // parts, - // state, - // ).await; - // ``` - let identity_token = IdentityToken::from_request_parts(parts, state).await?; - let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; + let identity = Identity::from_request_parts(parts, state).await?; - let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; - - let app = State::<App>::from_request_parts(parts, state).await?; - match app.logins().validate(&secret, &used_at).await { - Ok(login) => Ok(login), - Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), - Err(other) => Err(other.into()), - } - } -} - -pub enum LoginError<E> { - Failure(E), - Unauthorized, -} - -impl<E> IntoResponse for LoginError<E> -where - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), - Self::Failure(e) => e.into_response(), - } - } -} - -impl<E> From<E> for LoginError<Internal> -where - E: Into<Internal>, -{ - fn from(err: E) -> Self { - Self::Failure(err.into()) + Ok(identity.login) } } diff --git a/src/repo/token.rs b/src/repo/token.rs index 15eef48..5f39e1d 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -1,8 +1,10 @@ +use std::fmt; + use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; use super::login::{self, Login}; -use crate::{clock::DateTime, login::extract::IdentitySecret}; +use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret}; pub trait Provider { fn tokens(&mut self) -> Tokens; @@ -24,15 +26,17 @@ impl<'c> Tokens<'c> { login: &Login, issued_at: &DateTime, ) -> Result<IdentitySecret, sqlx::Error> { + let id = Id::generate(); let secret = Uuid::new_v4().to_string(); let secret = sqlx::query_scalar!( r#" insert - into token (secret, login, issued_at, last_used_at) - values ($1, $2, $3, $3) + into token (id, secret, login, issued_at, last_used_at) + values ($1, $2, $3, $4, $4) returning secret as "secret!: IdentitySecret" "#, + id, secret, login.id, issued_at, @@ -44,37 +48,38 @@ impl<'c> Tokens<'c> { } // Revoke a token by its secret. - pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<(), sqlx::Error> { - sqlx::query!( + pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<Id, sqlx::Error> { + let token = sqlx::query_scalar!( r#" delete from token where secret = $1 - returning 1 as "found: u32" + returning id as "id: Id" "#, secret, ) .fetch_one(&mut *self.0) .await?; - Ok(()) + Ok(token) } // Expire and delete all tokens that haven't been used more recently than // `expire_at`. - pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { - sqlx::query!( + pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { + let tokens = sqlx::query_scalar!( r#" delete from token where last_used_at < $1 + returning id as "id: Id" "#, expire_at, ) - .execute(&mut *self.0) + .fetch_all(&mut *self.0) .await?; - Ok(()) + Ok(tokens) } // Validate a token by its secret, retrieving the associated Login record. @@ -84,7 +89,7 @@ impl<'c> Tokens<'c> { &mut self, secret: &IdentitySecret, used_at: &DateTime, - ) -> Result<Login, sqlx::Error> { + ) -> Result<(Id, Login), sqlx::Error> { // I would use `update … returning` to do this in one query, but // sqlite3, as of this writing, does not allow an update's `returning` // clause to reference columns from tables joined into the update. Two @@ -101,21 +106,54 @@ impl<'c> Tokens<'c> { .execute(&mut *self.0) .await?; - let login = sqlx::query_as!( - Login, + let login = sqlx::query!( r#" select - login.id as "id: login::Id", - name + token.id as "token_id: Id", + login.id as "login_id: login::Id", + name as "login_name" from login join token on login.id = token.login where token.secret = $1 "#, secret, ) + .map(|row| { + ( + row.token_id, + Login { + id: row.login_id, + name: row.login_name, + }, + ) + }) .fetch_one(&mut *self.0) .await?; Ok(login) } } + +// Stable identifier for a token. Prefixed with `T`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("T") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 69b5f4c..bdd7881 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -3,7 +3,7 @@ use uuid::Uuid; use crate::{ app::App, clock::RequestedAt, - login::extract::{IdentitySecret, IdentityToken}, + login::extract::{Identity, IdentitySecret, IdentityToken}, password::Password, }; @@ -22,6 +22,20 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) IdentityToken::new().set(token) } +pub async fn identity(app: &App, login: &(String, Password), issued_at: &RequestedAt) -> Identity { + let secret = logged_in(app, login, issued_at) + .await + .secret() + .expect("successful login generates a secret"); + let (token, login) = app + .logins() + .validate(&secret, issued_at) + .await + .expect("always validates newly-issued secret"); + + Identity { token, login } +} + pub fn secret(identity: &IdentityToken) -> IdentitySecret { identity.secret().expect("identity contained a secret") } |
