summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-10-01 20:32:57 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-10-01 20:32:57 -0400
commit7645411bcf7201e3a4927566da78080dc6a84ccf (patch)
tree2711922bfeab6dc8b6494e9b0976f3f051dff4a9
parent6c054c5b8d43a818ccfa9087960dc19b286e6bb7 (diff)
Prevent racing between `limit_stream` and logging out.
-rw-r--r--.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json20
-rw-r--r--src/error.rs8
-rw-r--r--src/events/routes.rs26
-rw-r--r--src/login/app.rs43
-rw-r--r--src/login/extract.rs6
-rw-r--r--src/login/routes.rs11
-rw-r--r--src/repo/token.rs15
7 files changed, 111 insertions, 18 deletions
diff --git a/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json
new file mode 100644
index 0000000..e07ad25
--- /dev/null
+++ b/.sqlx/query-cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c.json
@@ -0,0 +1,20 @@
+{
+ "db_name": "SQLite",
+ "query": "\n select id as \"id: Id\"\n from token\n where id = $1\n ",
+ "describe": {
+ "columns": [
+ {
+ "name": "id: Id",
+ "ordinal": 0,
+ "type_info": "Text"
+ }
+ ],
+ "parameters": {
+ "Right": 1
+ },
+ "nullable": [
+ false
+ ]
+ },
+ "hash": "cd1d5a52fad0c2f6a9eaa489c4147c994df46347a9ce2030ae04a52ccfc0c40c"
+}
diff --git a/src/error.rs b/src/error.rs
index 6e797b4..8792a1d 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -61,3 +61,11 @@ impl fmt::Display for Id {
self.0.fmt(f)
}
}
+
+pub struct Unauthorized;
+
+impl IntoResponse for Unauthorized {
+ fn into_response(self) -> Response {
+ (StatusCode::UNAUTHORIZED, "unauthorized").into_response()
+ }
+}
diff --git a/src/events/routes.rs b/src/events/routes.rs
index ec9dae2..f09474c 100644
--- a/src/events/routes.rs
+++ b/src/events/routes.rs
@@ -13,7 +13,11 @@ use super::{
extract::LastEventId,
types::{self, ResumePoint},
};
-use crate::{app::App, error::Internal, login::extract::Identity};
+use crate::{
+ app::App,
+ error::{Internal, Unauthorized},
+ login::{app::ValidateError, extract::Identity},
+};
#[cfg(test)]
mod test;
@@ -26,13 +30,13 @@ async fn events(
State(app): State<App>,
identity: Identity,
last_event_id: Option<LastEventId<ResumePoint>>,
-) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
+) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, EventsError> {
let resume_at = last_event_id
.map(LastEventId::into_inner)
.unwrap_or_default();
let stream = app.events().subscribe(resume_at).await?;
- let stream = app.logins().limit_stream(identity.token, stream);
+ let stream = app.logins().limit_stream(identity.token, stream).await?;
Ok(Events(stream))
}
@@ -67,3 +71,19 @@ impl TryFrom<types::ResumableEvent> for sse::Event {
Ok(event)
}
}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum EventsError {
+ DatabaseError(#[from] sqlx::Error),
+ ValidateError(#[from] ValidateError),
+}
+
+impl IntoResponse for EventsError {
+ fn into_response(self) -> Response {
+ match self {
+ Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(),
+ other => Internal::from(other).into_response(),
+ }
+ }
+}
diff --git a/src/login/app.rs b/src/login/app.rs
index 182c62c..95f0a07 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -81,28 +81,55 @@ impl<'a> Logins<'a> {
Ok(login)
}
- pub fn limit_stream<E>(
+ pub async fn limit_stream<E>(
&self,
token: token::Id,
events: impl Stream<Item = E> + std::fmt::Debug,
- ) -> impl Stream<Item = E> + std::fmt::Debug
+ ) -> Result<impl Stream<Item = E> + std::fmt::Debug, ValidateError>
where
E: std::fmt::Debug,
{
- let token_events = self
- .logins
- .subscribe()
+ // Subscribe, first.
+ let token_events = self.logins.subscribe();
+
+ // Check that the token is valid at this point in time, second. If it is, then
+ // any future revocations will appear in the subscription. If not, bail now.
+ //
+ // It's possible, otherwise, to get to this point with a token that _was_ valid
+ // at the start of the request, but which was invalided _before_ the
+ // `subscribe()` call. In that case, the corresponding revocation event will
+ // simply be missed, since the `token_events` stream subscribed after the fact.
+ // This check cancels guarding the stream here.
+ //
+ // Yes, this is a weird niche edge case. Most things don't double-check, because
+ // they aren't expected to run long enough for the token's revocation to
+ // matter. Supervising a stream, on the other hand, will run for a
+ // _long_ time; if we miss the race here, we'll never actually carry out the
+ // supervision.
+ let mut tx = self.db.begin().await?;
+ tx.tokens()
+ .require(&token)
+ .await
+ .not_found(|| ValidateError::InvalidToken)?;
+ tx.commit().await?;
+
+ // Then construct the guarded stream. First, project both streams into
+ // `GuardedEvent`.
+ let token_events = token_events
.filter(move |event| future::ready(event.token == token))
.map(|_| GuardedEvent::TokenRevoked);
-
let events = events.map(|event| GuardedEvent::Event(event));
- stream::select(token_events, events).scan((), |(), event| {
+ // Merge the two streams, then unproject them, stopping at
+ // `GuardedEvent::TokenRevoked`.
+ let stream = stream::select(token_events, events).scan((), |(), event| {
future::ready(match event {
GuardedEvent::Event(event) => Some(event),
GuardedEvent::TokenRevoked => None,
})
- })
+ });
+
+ Ok(stream)
}
pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
diff --git a/src/login/extract.rs b/src/login/extract.rs
index b585565..bfdbe8d 100644
--- a/src/login/extract.rs
+++ b/src/login/extract.rs
@@ -2,7 +2,7 @@ use std::fmt;
use axum::{
extract::{FromRequestParts, State},
- http::{request::Parts, StatusCode},
+ http::request::Parts,
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use axum_extra::extract::cookie::{Cookie, CookieJar};
@@ -10,7 +10,7 @@ use axum_extra::extract::cookie::{Cookie, CookieJar};
use crate::{
app::App,
clock::RequestedAt,
- error::Internal,
+ error::{Internal, Unauthorized},
login::app::ValidateError,
repo::{login::Login, token},
};
@@ -166,7 +166,7 @@ where
{
fn into_response(self) -> Response {
match self {
- Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
+ Self::Unauthorized => Unauthorized.into_response(),
Self::Failure(e) => e.into_response(),
}
}
diff --git a/src/login/routes.rs b/src/login/routes.rs
index 8d9e938..d7cb9b1 100644
--- a/src/login/routes.rs
+++ b/src/login/routes.rs
@@ -7,7 +7,11 @@ use axum::{
};
use crate::{
- app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login,
+ app::App,
+ clock::RequestedAt,
+ error::{Internal, Unauthorized},
+ password::Password,
+ repo::login::Login,
};
use super::{app, extract::IdentityToken};
@@ -66,6 +70,7 @@ impl IntoResponse for LoginError {
let Self(error) = self;
match error {
app::LoginError::Rejected => {
+ // not error::Unauthorized due to differing messaging
(StatusCode::UNAUTHORIZED, "invalid name or password").into_response()
}
other => Internal::from(other).into_response(),
@@ -103,9 +108,7 @@ enum LogoutError {
impl IntoResponse for LogoutError {
fn into_response(self) -> Response {
match self {
- error @ Self::ValidateError(app::ValidateError::InvalidToken) => {
- (StatusCode::UNAUTHORIZED, error.to_string()).into_response()
- }
+ Self::ValidateError(app::ValidateError::InvalidToken) => Unauthorized.into_response(),
other => Internal::from(other).into_response(),
}
}
diff --git a/src/repo/token.rs b/src/repo/token.rs
index d96c094..1663f5e 100644
--- a/src/repo/token.rs
+++ b/src/repo/token.rs
@@ -47,6 +47,21 @@ impl<'c> Tokens<'c> {
Ok(secret)
}
+ pub async fn require(&mut self, token: &Id) -> Result<(), sqlx::Error> {
+ sqlx::query_scalar!(
+ r#"
+ select id as "id: Id"
+ from token
+ where id = $1
+ "#,
+ token,
+ )
+ .fetch_one(&mut *self.0)
+ .await?;
+
+ Ok(())
+ }
+
// Revoke a token by its secret.
pub async fn revoke(&mut self, token: &Id) -> Result<(), sqlx::Error> {
sqlx::query_scalar!(