diff options
| -rw-r--r-- | .sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json | 20 | ||||
| -rw-r--r-- | src/error.rs | 8 | ||||
| -rw-r--r-- | src/events/routes.rs | 26 | ||||
| -rw-r--r-- | src/login/app.rs | 43 | ||||
| -rw-r--r-- | src/login/extract.rs | 6 | ||||
| -rw-r--r-- | src/login/routes.rs | 11 | ||||
| -rw-r--r-- | src/repo/token.rs | 15 |
7 files changed, 111 insertions, 18 deletions
diff --git a/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json new file mode 100644 index 0000000..e07ad25 --- /dev/null +++ b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select id as \"id: Id\"\n from token\n where id = $1\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c" +} diff --git a/src/error.rs b/src/error.rs index 6e797b4..8792a1d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,3 +61,11 @@ impl fmt::Display for Id { self.0.fmt(f) } } + +pub struct Unauthorized; + +impl IntoResponse for Unauthorized { + fn into_response(self) -> Response { + (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + } +} diff --git a/src/events/routes.rs b/src/events/routes.rs index ec9dae2..f09474c 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -13,7 +13,11 @@ use super::{ extract::LastEventId, types::{self, ResumePoint}, }; -use crate::{app::App, error::Internal, login::extract::Identity}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + login::{app::ValidateError, extract::Identity}, +}; #[cfg(test)] mod test; @@ -26,13 +30,13 @@ async fn events( State(app): State<App>, identity: Identity, last_event_id: Option<LastEventId<ResumePoint>>, -) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { +) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, EventsError> { let resume_at = last_event_id .map(LastEventId::into_inner) .unwrap_or_default(); let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream); + let stream = app.logins().limit_stream(identity.token, stream).await?; Ok(Events(stream)) } @@ -67,3 +71,19 @@ impl TryFrom<types::ResumableEvent> for sse::Event { Ok(event) } } + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/login/app.rs b/src/login/app.rs index 182c62c..95f0a07 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -81,28 +81,55 @@ impl<'a> Logins<'a> { Ok(login) } - pub fn limit_stream<E>( + pub async fn limit_stream<E>( &self, token: token::Id, events: impl Stream<Item = E> + std::fmt::Debug, - ) -> impl Stream<Item = E> + std::fmt::Debug + ) -> Result<impl Stream<Item = E> + std::fmt::Debug, ValidateError> where E: std::fmt::Debug, { - let token_events = self - .logins - .subscribe() + // Subscribe, first. + let token_events = self.logins.subscribe(); + + // Check that the token is valid at this point in time, second. If it is, then + // any future revocations will appear in the subscription. If not, bail now. + // + // It's possible, otherwise, to get to this point with a token that _was_ valid + // at the start of the request, but which was invalided _before_ the + // `subscribe()` call. In that case, the corresponding revocation event will + // simply be missed, since the `token_events` stream subscribed after the fact. + // This check cancels guarding the stream here. + // + // Yes, this is a weird niche edge case. Most things don't double-check, because + // they aren't expected to run long enough for the token's revocation to + // matter. Supervising a stream, on the other hand, will run for a + // _long_ time; if we miss the race here, we'll never actually carry out the + // supervision. + let mut tx = self.db.begin().await?; + tx.tokens() + .require(&token) + .await + .not_found(|| ValidateError::InvalidToken)?; + tx.commit().await?; + + // Then construct the guarded stream. First, project both streams into + // `GuardedEvent`. + let token_events = token_events .filter(move |event| future::ready(event.token == token)) .map(|_| GuardedEvent::TokenRevoked); - let events = events.map(|event| GuardedEvent::Event(event)); - stream::select(token_events, events).scan((), |(), event| { + // Merge the two streams, then unproject them, stopping at + // `GuardedEvent::TokenRevoked`. + let stream = stream::select(token_events, events).scan((), |(), event| { future::ready(match event { GuardedEvent::Event(event) => Some(event), GuardedEvent::TokenRevoked => None, }) - }) + }); + + Ok(stream) } pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { diff --git a/src/login/extract.rs b/src/login/extract.rs index b585565..bfdbe8d 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -2,7 +2,7 @@ use std::fmt; use axum::{ extract::{FromRequestParts, State}, - http::{request::Parts, StatusCode}, + http::request::Parts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use axum_extra::extract::cookie::{Cookie, CookieJar}; @@ -10,7 +10,7 @@ use axum_extra::extract::cookie::{Cookie, CookieJar}; use crate::{ app::App, clock::RequestedAt, - error::Internal, + error::{Internal, Unauthorized}, login::app::ValidateError, repo::{login::Login, token}, }; @@ -166,7 +166,7 @@ where { fn into_response(self) -> Response { match self { - Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::Unauthorized => Unauthorized.into_response(), Self::Failure(e) => e.into_response(), } } diff --git a/src/login/routes.rs b/src/login/routes.rs index 8d9e938..d7cb9b1 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -7,7 +7,11 @@ use axum::{ }; use crate::{ - app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login, + app::App, + clock::RequestedAt, + error::{Internal, Unauthorized}, + password::Password, + repo::login::Login, }; use super::{app, extract::IdentityToken}; @@ -66,6 +70,7 @@ impl IntoResponse for LoginError { let Self(error) = self; match error { app::LoginError::Rejected => { + // not error::Unauthorized due to differing messaging (StatusCode::UNAUTHORIZED, "invalid name or password").into_response() } other => Internal::from(other).into_response(), @@ -103,9 +108,7 @@ enum LogoutError { impl IntoResponse for LogoutError { fn into_response(self) -> Response { match self { - error @ Self::ValidateError(app::ValidateError::InvalidToken) => { - (StatusCode::UNAUTHORIZED, error.to_string()).into_response() - } + Self::ValidateError(app::ValidateError::InvalidToken) => Unauthorized.into_response(), other => Internal::from(other).into_response(), } } diff --git a/src/repo/token.rs b/src/repo/token.rs index d96c094..1663f5e 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -47,6 +47,21 @@ impl<'c> Tokens<'c> { Ok(secret) } + pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> { + sqlx::query_scalar!( + r#" + select id as "id: Id" + from token + where id = $1 + "#, + token, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(()) + } + // Revoke a token by its secret. pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> { sqlx::query_scalar!( |
