summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-28 21:55:20 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-29 01:19:19 -0400
commit0b1cb80dd0b0f90c4892de7e7a2d18a076ecbdf2 (patch)
treeb41313dbd92811ffcc87b0af576dc570b5802a1e /src
parent4d0bb0709b168a24ab6a8dbc86da45d7503596ee (diff)
Shut down the `/api/events` stream when the user logs out or their token expires.
When tokens are revoked (logout or expiry), the server now publishes an internal event via the new `logins` event broadcaster. These events are used to guard the `/api/events` stream. When a token revocation event arrives for the token used to subscribe to the stream, the stream is cut short, disconnecting the client. In service of this, tokens now have IDs, which are non-confidential values that can be used to discuss tokens without their secrets being passed around unnecessarily. These IDs are not (at this time) exposed to clients, but they could be.
Diffstat (limited to 'src')
-rw-r--r--src/app.rs18
-rw-r--r--src/broadcast.rs78
-rw-r--r--src/channel/app.rs10
-rw-r--r--src/events/app.rs12
-rw-r--r--src/events/broadcaster.rs75
-rw-r--r--src/events/routes.rs5
-rw-r--r--src/events/routes/test.rs57
-rw-r--r--src/lib.rs1
-rw-r--r--src/login/app.rs58
-rw-r--r--src/login/broadcaster.rs3
-rw-r--r--src/login/extract.rs74
-rw-r--r--src/login/mod.rs2
-rw-r--r--src/login/routes/test/login.rs4
-rw-r--r--src/login/types.rs12
-rw-r--r--src/repo/login/extract.rs62
-rw-r--r--src/repo/token.rs70
-rw-r--r--src/test/fixtures/identity.rs16
17 files changed, 370 insertions, 187 deletions
diff --git a/src/app.rs b/src/app.rs
index 245feb1..c13f52f 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -2,33 +2,35 @@ use sqlx::sqlite::SqlitePool;
use crate::{
channel::app::Channels,
- events::{app::Events, broadcaster::Broadcaster},
- login::app::Logins,
+ events::{app::Events, broadcaster::Broadcaster as EventBroadcaster},
+ login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster},
};
#[derive(Clone)]
pub struct App {
db: SqlitePool,
- broadcaster: Broadcaster,
+ events: EventBroadcaster,
+ logins: LoginBroadcaster,
}
impl App {
pub fn from(db: SqlitePool) -> Self {
- let broadcaster = Broadcaster::default();
- Self { db, broadcaster }
+ let events = EventBroadcaster::default();
+ let logins = LoginBroadcaster::default();
+ Self { db, events, logins }
}
}
impl App {
pub const fn logins(&self) -> Logins {
- Logins::new(&self.db)
+ Logins::new(&self.db, &self.logins)
}
pub const fn events(&self) -> Events {
- Events::new(&self.db, &self.broadcaster)
+ Events::new(&self.db, &self.events)
}
pub const fn channels(&self) -> Channels {
- Channels::new(&self.db, &self.broadcaster)
+ Channels::new(&self.db, &self.events)
}
}
diff --git a/src/broadcast.rs b/src/broadcast.rs
new file mode 100644
index 0000000..083a301
--- /dev/null
+++ b/src/broadcast.rs
@@ -0,0 +1,78 @@
+use std::sync::{Arc, Mutex};
+
+use futures::{future, stream::StreamExt as _, Stream};
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
+
+// Clones will share the same sender.
+#[derive(Clone)]
+pub struct Broadcaster<M> {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<Sender<M>>>,
+}
+
+impl<M> Default for Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ fn default() -> Self {
+ let sender = Self::make_sender();
+
+ Self {
+ senders: Arc::new(Mutex::new(sender)),
+ }
+ }
+}
+
+impl<M> Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ // Broadcast `message` to all current subscribers. A lack of listening
+ // receivers is not treated as an error (see comment below).
+ pub fn broadcast(&self, message: &M) {
+ let tx = self.sender();
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ //
+ // The successful return value, which includes the number of active
+ // receivers, also isn't that interesting to us.
+ let _ = tx.send(message.clone());
+ }
+
+ // Subscribe to messages broadcast from this point on. The returned
+ // stream ends if the subscriber lags behind the broadcast queue.
+ pub fn subscribe(&self) -> impl Stream<Item = M> + std::fmt::Debug {
+ let rx = self.sender().subscribe();
+
+ BroadcastStream::from(rx).scan((), |(), r| {
+ future::ready(match r {
+ Ok(event) => Some(event),
+ // Stop the stream here. This will disconnect SSE clients
+ // (see `routes.rs`), who will then resume from
+ // `Last-Event-ID`, allowing them to catch up by reading
+ // the skipped messages from the database.
+ //
+ // See also:
+ // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854>
+ Err(BroadcastStreamRecvError::Lagged(_)) => None,
+ })
+ })
+ }
+
+ fn sender(&self) -> Sender<M> {
+ self.senders.lock().unwrap().clone()
+ }
+
+ fn make_sender() -> Sender<M> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index d7312e4..70cda47 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -9,12 +9,12 @@ use crate::{
pub struct Channels<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
+ events: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
}
pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> {
@@ -26,7 +26,7 @@ impl<'a> Channels<'a> {
.map_err(|err| CreateError::from_duplicate_name(err, name))?;
tx.commit().await?;
- self.broadcaster
+ self.events
.broadcast(&ChannelEvent::created(channel.clone()));
Ok(channel)
@@ -60,7 +60,7 @@ impl<'a> Channels<'a> {
tx.commit().await?;
for event in events {
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
}
Ok(())
diff --git a/src/events/app.rs b/src/events/app.rs
index 0cdc641..db7f430 100644
--- a/src/events/app.rs
+++ b/src/events/app.rs
@@ -24,12 +24,12 @@ use crate::{
pub struct Events<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
+ events: &'a Broadcaster,
}
impl<'a> Events<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
}
pub async fn send(
@@ -51,7 +51,7 @@ impl<'a> Events<'a> {
.await?;
tx.commit().await?;
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
Ok(event)
}
@@ -75,7 +75,7 @@ impl<'a> Events<'a> {
tx.commit().await?;
for event in events {
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
}
Ok(())
@@ -101,7 +101,7 @@ impl<'a> Events<'a> {
// Subscribe before retrieving, to catch messages broadcast while we're
// querying the DB. We'll prune out duplicates later.
- let live_messages = self.broadcaster.subscribe();
+ let live_messages = self.events.subscribe();
let mut replays = BTreeMap::new();
let mut resume_live_at = resume_at.clone();
diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs
index 9697c0a..6b664cb 100644
--- a/src/events/broadcaster.rs
+++ b/src/events/broadcaster.rs
@@ -1,74 +1,3 @@
-use std::sync::{Arc, Mutex};
+use crate::{broadcast, events::types};
-use futures::{future, stream::StreamExt as _, Stream};
-use tokio::sync::broadcast::{channel, Sender};
-use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
-
-use crate::events::types;
-
-// Clones will share the same sender.
-#[derive(Clone)]
-pub struct Broadcaster {
- // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
- // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
- // lock it must be sync.
- senders: Arc<Mutex<Sender<types::ChannelEvent>>>,
-}
-
-impl Default for Broadcaster {
- fn default() -> Self {
- let sender = Self::make_sender();
-
- Self {
- senders: Arc::new(Mutex::new(sender)),
- }
- }
-}
-
-impl Broadcaster {
- // panic: if ``message.channel.id`` has not been previously registered,
- // and was not part of the initial set of channels.
- pub fn broadcast(&self, message: &types::ChannelEvent) {
- let tx = self.sender();
-
- // Per the Tokio docs, the returned error is only used to indicate that
- // there are no receivers. In this use case, that's fine; a lack of
- // listening consumers (chat clients) when a message is sent isn't an
- // error.
- //
- // The successful return value, which includes the number of active
- // receivers, also isn't that interesting to us.
- let _ = tx.send(message.clone());
- }
-
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug {
- let rx = self.sender().subscribe();
-
- BroadcastStream::from(rx).scan((), |(), r| {
- future::ready(match r {
- Ok(event) => Some(event),
- // Stop the stream here. This will disconnect SSE clients
- // (see `routes.rs`), who will then resume from
- // `Last-Event-ID`, allowing them to catch up by reading
- // the skipped messages from the database.
- //
- // See also:
- // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854>
- Err(BroadcastStreamRecvError::Lagged(_)) => None,
- })
- })
- }
-
- fn sender(&self) -> Sender<types::ChannelEvent> {
- self.senders.lock().unwrap().clone()
- }
-
- fn make_sender() -> Sender<types::ChannelEvent> {
- // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
- // into it.
- let (tx, _) = channel(16);
- tx
- }
-}
+pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>;
diff --git a/src/events/routes.rs b/src/events/routes.rs
index 89c942c..ec9dae2 100644
--- a/src/events/routes.rs
+++ b/src/events/routes.rs
@@ -13,7 +13,7 @@ use super::{
extract::LastEventId,
types::{self, ResumePoint},
};
-use crate::{app::App, error::Internal, repo::login::Login};
+use crate::{app::App, error::Internal, login::extract::Identity};
#[cfg(test)]
mod test;
@@ -24,7 +24,7 @@ pub fn router() -> Router<App> {
async fn events(
State(app): State<App>,
- _: Login, // requires auth, but doesn't actually care who you are
+ identity: Identity,
last_event_id: Option<LastEventId<ResumePoint>>,
) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
let resume_at = last_event_id
@@ -32,6 +32,7 @@ async fn events(
.unwrap_or_default();
let stream = app.events().subscribe(resume_at).await?;
+ let stream = app.logins().limit_stream(identity.token, stream);
Ok(Events(stream))
}
diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs
index a6e2275..0b62b5b 100644
--- a/src/events/routes/test.rs
+++ b/src/events/routes/test.rs
@@ -20,7 +20,8 @@ async fn includes_historical_message() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -46,7 +47,8 @@ async fn includes_live_message() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
.await
.expect("subscribe never fails");
@@ -90,7 +92,8 @@ async fn includes_multiple_channels() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -127,7 +130,8 @@ async fn sequential_messages() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -166,7 +170,8 @@ async fn resumes_from() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let resume_at = {
// First subscription
@@ -232,7 +237,8 @@ async fn serial_resume() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let resume_at = {
let initial_messages = [
@@ -335,3 +341,42 @@ async fn serial_resume() {
}
};
}
+
+#[tokio::test]
+async fn terminates_on_token_expiry() {
+ // Set up the environment
+
+ let app = fixtures::scratch_app().await;
+ let channel = fixtures::channel::create(&app, &fixtures::now()).await;
+ let sender = fixtures::login::create(&app).await;
+
+ // Subscribe via the endpoint
+
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber =
+ fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await;
+ let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
+ .await
+ .expect("subscribe never fails");
+
+ // Verify the resulting stream's behaviour
+
+ app.logins()
+ .expire(&fixtures::now())
+ .await
+ .expect("expiring tokens succeeds");
+
+ // These should not be delivered.
+ let messages = [
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ ];
+
+ assert!(events
+ .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)))
+ .next()
+ .immediately()
+ .await
+ .is_none());
+}
diff --git a/src/lib.rs b/src/lib.rs
index 4139d4d..271118b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
#![warn(clippy::pedantic)]
mod app;
+mod broadcast;
mod channel;
pub mod cli;
mod clock;
diff --git a/src/login/app.rs b/src/login/app.rs
index f7fec88..b8916a8 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -1,24 +1,30 @@
use chrono::TimeDelta;
+use futures::{
+ future,
+ stream::{self, StreamExt as _},
+ Stream,
+};
use sqlx::sqlite::SqlitePool;
-use super::{extract::IdentitySecret, repo::auth::Provider as _};
+use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types};
use crate::{
clock::DateTime,
password::Password,
repo::{
error::NotFound as _,
login::{Login, Provider as _},
- token::Provider as _,
+ token::{self, Provider as _},
},
};
pub struct Logins<'a> {
db: &'a SqlitePool,
+ logins: &'a Broadcaster,
}
impl<'a> Logins<'a> {
- pub const fn new(db: &'a SqlitePool) -> Self {
- Self { db }
+ pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self {
+ Self { db, logins }
}
pub async fn login(
@@ -63,7 +69,7 @@ impl<'a> Logins<'a> {
&self,
secret: &IdentitySecret,
used_at: &DateTime,
- ) -> Result<Login, ValidateError> {
+ ) -> Result<(token::Id, Login), ValidateError> {
let mut tx = self.db.begin().await?;
let login = tx
.tokens()
@@ -75,26 +81,56 @@ impl<'a> Logins<'a> {
Ok(login)
}
+ pub fn limit_stream<E>(
+ &self,
+ token: token::Id,
+ events: impl Stream<Item = E> + std::fmt::Debug,
+ ) -> impl Stream<Item = E> + std::fmt::Debug
+ where
+ E: std::fmt::Debug,
+ {
+ let token_events = self
+ .logins
+ .subscribe()
+ .filter(move |event| future::ready(event.token == token))
+ .map(|_| GuardedEvent::TokenRevoked);
+
+ let events = events.map(|event| GuardedEvent::Event(event));
+
+ stream::select(token_events, events).scan((), |(), event| {
+ future::ready(match event {
+ GuardedEvent::Event(event) => Some(event),
+ GuardedEvent::TokenRevoked => None,
+ })
+ })
+ }
+
pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
// Somewhat arbitrarily, expire after 7 days.
let expire_at = relative_to.to_owned() - TimeDelta::days(7);
let mut tx = self.db.begin().await?;
- tx.tokens().expire(&expire_at).await?;
+ let tokens = tx.tokens().expire(&expire_at).await?;
tx.commit().await?;
+ for event in tokens.into_iter().map(types::TokenRevoked::from) {
+ self.logins.broadcast(&event);
+ }
+
Ok(())
}
pub async fn logout(&self, secret: &IdentitySecret) -> Result<(), ValidateError> {
let mut tx = self.db.begin().await?;
- tx.tokens()
+ let token = tx
+ .tokens()
.revoke(secret)
.await
.not_found(|| ValidateError::InvalidToken)?;
-
tx.commit().await?;
+ self.logins.broadcast(&types::TokenRevoked::from(token));
+
Ok(())
}
}
@@ -124,3 +160,9 @@ pub enum ValidateError {
#[error(transparent)]
DatabaseError(#[from] sqlx::Error),
}
+
+#[derive(Debug)]
+enum GuardedEvent<E> {
+ TokenRevoked,
+ Event(E),
+}
diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs
new file mode 100644
index 0000000..8e1fb3a
--- /dev/null
+++ b/src/login/broadcaster.rs
@@ -0,0 +1,3 @@
+use crate::{broadcast, login::types};
+
+pub type Broadcaster = broadcast::Broadcaster<types::TokenRevoked>;
diff --git a/src/login/extract.rs b/src/login/extract.rs
index 3b31d4c..b585565 100644
--- a/src/login/extract.rs
+++ b/src/login/extract.rs
@@ -1,12 +1,20 @@
use std::fmt;
use axum::{
- extract::FromRequestParts,
- http::request::Parts,
- response::{IntoResponseParts, ResponseParts},
+ extract::{FromRequestParts, State},
+ http::{request::Parts, StatusCode},
+ response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use axum_extra::extract::cookie::{Cookie, CookieJar};
+use crate::{
+ app::App,
+ clock::RequestedAt,
+ error::Internal,
+ login::app::ValidateError,
+ repo::{login::Login, token},
+};
+
// The usage pattern here - receive the extractor as an argument, return it in
// the response - is heavily modelled after CookieJar's own intended usage.
#[derive(Clone)]
@@ -112,3 +120,63 @@ where
Self(value.into())
}
}
+
+#[derive(Clone, Debug)]
+pub struct Identity {
+ pub token: token::Id,
+ pub login: Login,
+}
+
+#[async_trait::async_trait]
+impl FromRequestParts<App> for Identity {
+ type Rejection = LoginError<Internal>;
+
+ async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
+ // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
+ // stable), the following can be replaced:
+ //
+ // ```
+ // let Ok(identity_token) = IdentityToken::from_request_parts(
+ // parts,
+ // state,
+ // ).await;
+ // ```
+ let identity_token = IdentityToken::from_request_parts(parts, state).await?;
+ let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+
+ let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
+
+ let app = State::<App>::from_request_parts(parts, state).await?;
+ match app.logins().validate(&secret, &used_at).await {
+ Ok((token, login)) => Ok(Identity { token, login }),
+ Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
+ Err(other) => Err(other.into()),
+ }
+ }
+}
+
+pub enum LoginError<E> {
+ Failure(E),
+ Unauthorized,
+}
+
+impl<E> IntoResponse for LoginError<E>
+where
+ E: IntoResponse,
+{
+ fn into_response(self) -> Response {
+ match self {
+ Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
+ Self::Failure(e) => e.into_response(),
+ }
+ }
+}
+
+impl<E> From<E> for LoginError<Internal>
+where
+ E: Into<Internal>,
+{
+ fn from(err: E) -> Self {
+ Self::Failure(err.into())
+ }
+}
diff --git a/src/login/mod.rs b/src/login/mod.rs
index 191cce0..6ae82ac 100644
--- a/src/login/mod.rs
+++ b/src/login/mod.rs
@@ -1,6 +1,8 @@
pub use self::routes::router;
pub mod app;
+pub mod broadcaster;
pub mod extract;
mod repo;
mod routes;
+pub mod types;
diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs
index 10c17d6..81653ff 100644
--- a/src/login/routes/test/login.rs
+++ b/src/login/routes/test/login.rs
@@ -36,7 +36,7 @@ async fn new_identity() {
// Verify the semantics
let validated_at = fixtures::now();
- let validated = app
+ let (_, validated) = app
.logins()
.validate(&secret, &validated_at)
.await
@@ -73,7 +73,7 @@ async fn existing_identity() {
// Verify the semantics
let validated_at = fixtures::now();
- let validated_login = app
+ let (_, validated_login) = app
.logins()
.validate(&secret, &validated_at)
.await
diff --git a/src/login/types.rs b/src/login/types.rs
new file mode 100644
index 0000000..7c7cbf9
--- /dev/null
+++ b/src/login/types.rs
@@ -0,0 +1,12 @@
+use crate::repo::token;
+
+#[derive(Clone, Debug)]
+pub struct TokenRevoked {
+ pub token: token::Id,
+}
+
+impl From<token::Id> for TokenRevoked {
+ fn from(token: token::Id) -> Self {
+ Self { token }
+ }
+}
diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs
index c127078..ab61106 100644
--- a/src/repo/login/extract.rs
+++ b/src/repo/login/extract.rs
@@ -1,67 +1,15 @@
-use axum::{
- extract::{FromRequestParts, State},
- http::{request::Parts, StatusCode},
- response::{IntoResponse, Response},
-};
+use axum::{extract::FromRequestParts, http::request::Parts};
use super::Login;
-use crate::{
- app::App,
- clock::RequestedAt,
- error::Internal,
- login::{app::ValidateError, extract::IdentityToken},
-};
+use crate::{app::App, login::extract::Identity};
#[async_trait::async_trait]
impl FromRequestParts<App> for Login {
- type Rejection = LoginError<Internal>;
+ type Rejection = <Identity as FromRequestParts<App>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
- // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
- // stable), the following can be replaced:
- //
- // ```
- // let Ok(identity_token) = IdentityToken::from_request_parts(
- // parts,
- // state,
- // ).await;
- // ```
- let identity_token = IdentityToken::from_request_parts(parts, state).await?;
- let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+ let identity = Identity::from_request_parts(parts, state).await?;
- let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
-
- let app = State::<App>::from_request_parts(parts, state).await?;
- match app.logins().validate(&secret, &used_at).await {
- Ok(login) => Ok(login),
- Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
- Err(other) => Err(other.into()),
- }
- }
-}
-
-pub enum LoginError<E> {
- Failure(E),
- Unauthorized,
-}
-
-impl<E> IntoResponse for LoginError<E>
-where
- E: IntoResponse,
-{
- fn into_response(self) -> Response {
- match self {
- Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
- Self::Failure(e) => e.into_response(),
- }
- }
-}
-
-impl<E> From<E> for LoginError<Internal>
-where
- E: Into<Internal>,
-{
- fn from(err: E) -> Self {
- Self::Failure(err.into())
+ Ok(identity.login)
}
}
diff --git a/src/repo/token.rs b/src/repo/token.rs
index 15eef48..5f39e1d 100644
--- a/src/repo/token.rs
+++ b/src/repo/token.rs
@@ -1,8 +1,10 @@
+use std::fmt;
+
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use uuid::Uuid;
use super::login::{self, Login};
-use crate::{clock::DateTime, login::extract::IdentitySecret};
+use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret};
pub trait Provider {
fn tokens(&mut self) -> Tokens;
@@ -24,15 +26,17 @@ impl<'c> Tokens<'c> {
login: &Login,
issued_at: &DateTime,
) -> Result<IdentitySecret, sqlx::Error> {
+ let id = Id::generate();
let secret = Uuid::new_v4().to_string();
let secret = sqlx::query_scalar!(
r#"
insert
- into token (secret, login, issued_at, last_used_at)
- values ($1, $2, $3, $3)
+ into token (id, secret, login, issued_at, last_used_at)
+ values ($1, $2, $3, $4, $4)
returning secret as "secret!: IdentitySecret"
"#,
+ id,
secret,
login.id,
issued_at,
@@ -44,37 +48,38 @@ impl<'c> Tokens<'c> {
}
// Revoke a token by its secret.
- pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<(), sqlx::Error> {
- sqlx::query!(
+ pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<Id, sqlx::Error> {
+ let token = sqlx::query_scalar!(
r#"
delete
from token
where secret = $1
- returning 1 as "found: u32"
+ returning id as "id: Id"
"#,
secret,
)
.fetch_one(&mut *self.0)
.await?;
- Ok(())
+ Ok(token)
}
// Expire and delete all tokens that haven't been used more recently than
// `expire_at`.
- pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> {
- sqlx::query!(
+ pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ let tokens = sqlx::query_scalar!(
r#"
delete
from token
where last_used_at < $1
+ returning id as "id: Id"
"#,
expire_at,
)
- .execute(&mut *self.0)
+ .fetch_all(&mut *self.0)
.await?;
- Ok(())
+ Ok(tokens)
}
// Validate a token by its secret, retrieving the associated Login record.
@@ -84,7 +89,7 @@ impl<'c> Tokens<'c> {
&mut self,
secret: &IdentitySecret,
used_at: &DateTime,
- ) -> Result<Login, sqlx::Error> {
+ ) -> Result<(Id, Login), sqlx::Error> {
// I would use `update … returning` to do this in one query, but
// sqlite3, as of this writing, does not allow an update's `returning`
// clause to reference columns from tables joined into the update. Two
@@ -101,21 +106,54 @@ impl<'c> Tokens<'c> {
.execute(&mut *self.0)
.await?;
- let login = sqlx::query_as!(
- Login,
+ let login = sqlx::query!(
r#"
select
- login.id as "id: login::Id",
- name
+ token.id as "token_id: Id",
+ login.id as "login_id: login::Id",
+ name as "login_name"
from login
join token on login.id = token.login
where token.secret = $1
"#,
secret,
)
+ .map(|row| {
+ (
+ row.token_id,
+ Login {
+ id: row.login_id,
+ name: row.login_name,
+ },
+ )
+ })
.fetch_one(&mut *self.0)
.await?;
Ok(login)
}
}
+
+// Stable identifier for a token. Prefixed with `T`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("T")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 69b5f4c..bdd7881 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -3,7 +3,7 @@ use uuid::Uuid;
use crate::{
app::App,
clock::RequestedAt,
- login::extract::{IdentitySecret, IdentityToken},
+ login::extract::{Identity, IdentitySecret, IdentityToken},
password::Password,
};
@@ -22,6 +22,20 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt)
IdentityToken::new().set(token)
}
+pub async fn identity(app: &App, login: &(String, Password), issued_at: &RequestedAt) -> Identity {
+ let secret = logged_in(app, login, issued_at)
+ .await
+ .secret()
+ .expect("successful login generates a secret");
+ let (token, login) = app
+ .logins()
+ .validate(&secret, issued_at)
+ .await
+ .expect("always validates newly-issued secret");
+
+ Identity { token, login }
+}
+
pub fn secret(identity: &IdentityToken) -> IdentitySecret {
identity.secret().expect("identity contained a secret")
}