summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/app.rs18
-rw-r--r--src/broadcast.rs78
-rw-r--r--src/channel/app.rs10
-rw-r--r--src/events/app.rs12
-rw-r--r--src/events/broadcaster.rs75
-rw-r--r--src/events/routes.rs5
-rw-r--r--src/events/routes/test.rs57
-rw-r--r--src/lib.rs1
-rw-r--r--src/login/app.rs58
-rw-r--r--src/login/broadcaster.rs3
-rw-r--r--src/login/extract.rs74
-rw-r--r--src/login/mod.rs2
-rw-r--r--src/login/routes/test/login.rs4
-rw-r--r--src/login/types.rs12
-rw-r--r--src/repo/login/extract.rs62
-rw-r--r--src/repo/token.rs70
-rw-r--r--src/test/fixtures/identity.rs16
17 files changed, 370 insertions, 187 deletions
diff --git a/src/app.rs b/src/app.rs
index 245feb1..c13f52f 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -2,33 +2,35 @@ use sqlx::sqlite::SqlitePool;
use crate::{
channel::app::Channels,
- events::{app::Events, broadcaster::Broadcaster},
- login::app::Logins,
+ events::{app::Events, broadcaster::Broadcaster as EventBroadcaster},
+ login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster},
};
#[derive(Clone)]
pub struct App {
db: SqlitePool,
- broadcaster: Broadcaster,
+ events: EventBroadcaster,
+ logins: LoginBroadcaster,
}
impl App {
pub fn from(db: SqlitePool) -> Self {
- let broadcaster = Broadcaster::default();
- Self { db, broadcaster }
+ let events = EventBroadcaster::default();
+ let logins = LoginBroadcaster::default();
+ Self { db, events, logins }
}
}
impl App {
pub const fn logins(&self) -> Logins {
- Logins::new(&self.db)
+ Logins::new(&self.db, &self.logins)
}
pub const fn events(&self) -> Events {
- Events::new(&self.db, &self.broadcaster)
+ Events::new(&self.db, &self.events)
}
pub const fn channels(&self) -> Channels {
- Channels::new(&self.db, &self.broadcaster)
+ Channels::new(&self.db, &self.events)
}
}
diff --git a/src/broadcast.rs b/src/broadcast.rs
new file mode 100644
index 0000000..083a301
--- /dev/null
+++ b/src/broadcast.rs
@@ -0,0 +1,78 @@
+use std::sync::{Arc, Mutex};
+
+use futures::{future, stream::StreamExt as _, Stream};
+use tokio::sync::broadcast::{channel, Sender};
+use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
+
+// Clones will share the same sender.
+#[derive(Clone)]
+pub struct Broadcaster<M> {
+ // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
+ // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
+ // lock it must be sync.
+ senders: Arc<Mutex<Sender<M>>>,
+}
+
+impl<M> Default for Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ fn default() -> Self {
+ let sender = Self::make_sender();
+
+ Self {
+ senders: Arc::new(Mutex::new(sender)),
+ }
+ }
+}
+
+impl<M> Broadcaster<M>
+where
+ M: Clone + Send + std::fmt::Debug + 'static,
+{
+ // panic: if ``message.channel.id`` has not been previously registered,
+ // and was not part of the initial set of channels.
+ pub fn broadcast(&self, message: &M) {
+ let tx = self.sender();
+
+ // Per the Tokio docs, the returned error is only used to indicate that
+ // there are no receivers. In this use case, that's fine; a lack of
+ // listening consumers (chat clients) when a message is sent isn't an
+ // error.
+ //
+ // The successful return value, which includes the number of active
+ // receivers, also isn't that interesting to us.
+ let _ = tx.send(message.clone());
+ }
+
+ // panic: if ``channel`` has not been previously registered, and was not
+ // part of the initial set of channels.
+ pub fn subscribe(&self) -> impl Stream<Item = M> + std::fmt::Debug {
+ let rx = self.sender().subscribe();
+
+ BroadcastStream::from(rx).scan((), |(), r| {
+ future::ready(match r {
+ Ok(event) => Some(event),
+ // Stop the stream here. This will disconnect SSE clients
+ // (see `routes.rs`), who will then resume from
+ // `Last-Event-ID`, allowing them to catch up by reading
+ // the skipped messages from the database.
+ //
+ // See also:
+ // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854>
+ Err(BroadcastStreamRecvError::Lagged(_)) => None,
+ })
+ })
+ }
+
+ fn sender(&self) -> Sender<M> {
+ self.senders.lock().unwrap().clone()
+ }
+
+ fn make_sender() -> Sender<M> {
+ // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
+ // into it.
+ let (tx, _) = channel(16);
+ tx
+ }
+}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index d7312e4..70cda47 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -9,12 +9,12 @@ use crate::{
pub struct Channels<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
+ events: &'a Broadcaster,
}
impl<'a> Channels<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
}
pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> {
@@ -26,7 +26,7 @@ impl<'a> Channels<'a> {
.map_err(|err| CreateError::from_duplicate_name(err, name))?;
tx.commit().await?;
- self.broadcaster
+ self.events
.broadcast(&ChannelEvent::created(channel.clone()));
Ok(channel)
@@ -60,7 +60,7 @@ impl<'a> Channels<'a> {
tx.commit().await?;
for event in events {
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
}
Ok(())
diff --git a/src/events/app.rs b/src/events/app.rs
index 0cdc641..db7f430 100644
--- a/src/events/app.rs
+++ b/src/events/app.rs
@@ -24,12 +24,12 @@ use crate::{
pub struct Events<'a> {
db: &'a SqlitePool,
- broadcaster: &'a Broadcaster,
+ events: &'a Broadcaster,
}
impl<'a> Events<'a> {
- pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self {
- Self { db, broadcaster }
+ pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+ Self { db, events }
}
pub async fn send(
@@ -51,7 +51,7 @@ impl<'a> Events<'a> {
.await?;
tx.commit().await?;
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
Ok(event)
}
@@ -75,7 +75,7 @@ impl<'a> Events<'a> {
tx.commit().await?;
for event in events {
- self.broadcaster.broadcast(&event);
+ self.events.broadcast(&event);
}
Ok(())
@@ -101,7 +101,7 @@ impl<'a> Events<'a> {
// Subscribe before retrieving, to catch messages broadcast while we're
// querying the DB. We'll prune out duplicates later.
- let live_messages = self.broadcaster.subscribe();
+ let live_messages = self.events.subscribe();
let mut replays = BTreeMap::new();
let mut resume_live_at = resume_at.clone();
diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs
index 9697c0a..6b664cb 100644
--- a/src/events/broadcaster.rs
+++ b/src/events/broadcaster.rs
@@ -1,74 +1,3 @@
-use std::sync::{Arc, Mutex};
+use crate::{broadcast, events::types};
-use futures::{future, stream::StreamExt as _, Stream};
-use tokio::sync::broadcast::{channel, Sender};
-use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
-
-use crate::events::types;
-
-// Clones will share the same sender.
-#[derive(Clone)]
-pub struct Broadcaster {
- // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's
- // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that
- // lock it must be sync.
- senders: Arc<Mutex<Sender<types::ChannelEvent>>>,
-}
-
-impl Default for Broadcaster {
- fn default() -> Self {
- let sender = Self::make_sender();
-
- Self {
- senders: Arc::new(Mutex::new(sender)),
- }
- }
-}
-
-impl Broadcaster {
- // panic: if ``message.channel.id`` has not been previously registered,
- // and was not part of the initial set of channels.
- pub fn broadcast(&self, message: &types::ChannelEvent) {
- let tx = self.sender();
-
- // Per the Tokio docs, the returned error is only used to indicate that
- // there are no receivers. In this use case, that's fine; a lack of
- // listening consumers (chat clients) when a message is sent isn't an
- // error.
- //
- // The successful return value, which includes the number of active
- // receivers, also isn't that interesting to us.
- let _ = tx.send(message.clone());
- }
-
- // panic: if ``channel`` has not been previously registered, and was not
- // part of the initial set of channels.
- pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug {
- let rx = self.sender().subscribe();
-
- BroadcastStream::from(rx).scan((), |(), r| {
- future::ready(match r {
- Ok(event) => Some(event),
- // Stop the stream here. This will disconnect SSE clients
- // (see `routes.rs`), who will then resume from
- // `Last-Event-ID`, allowing them to catch up by reading
- // the skipped messages from the database.
- //
- // See also:
- // <https://users.rust-lang.org/t/taking-from-stream-while-ok/48854>
- Err(BroadcastStreamRecvError::Lagged(_)) => None,
- })
- })
- }
-
- fn sender(&self) -> Sender<types::ChannelEvent> {
- self.senders.lock().unwrap().clone()
- }
-
- fn make_sender() -> Sender<types::ChannelEvent> {
- // Queue depth of 16 chosen entirely arbitrarily. Don't read too much
- // into it.
- let (tx, _) = channel(16);
- tx
- }
-}
+pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>;
diff --git a/src/events/routes.rs b/src/events/routes.rs
index 89c942c..ec9dae2 100644
--- a/src/events/routes.rs
+++ b/src/events/routes.rs
@@ -13,7 +13,7 @@ use super::{
extract::LastEventId,
types::{self, ResumePoint},
};
-use crate::{app::App, error::Internal, repo::login::Login};
+use crate::{app::App, error::Internal, login::extract::Identity};
#[cfg(test)]
mod test;
@@ -24,7 +24,7 @@ pub fn router() -> Router<App> {
async fn events(
State(app): State<App>,
- _: Login, // requires auth, but doesn't actually care who you are
+ identity: Identity,
last_event_id: Option<LastEventId<ResumePoint>>,
) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> {
let resume_at = last_event_id
@@ -32,6 +32,7 @@ async fn events(
.unwrap_or_default();
let stream = app.events().subscribe(resume_at).await?;
+ let stream = app.logins().limit_stream(identity.token, stream);
Ok(Events(stream))
}
diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs
index a6e2275..0b62b5b 100644
--- a/src/events/routes/test.rs
+++ b/src/events/routes/test.rs
@@ -20,7 +20,8 @@ async fn includes_historical_message() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -46,7 +47,8 @@ async fn includes_live_message() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
.await
.expect("subscribe never fails");
@@ -90,7 +92,8 @@ async fn includes_multiple_channels() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -127,7 +130,8 @@ async fn sequential_messages() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let routes::Events(events) = routes::events(State(app), subscriber, None)
.await
.expect("subscribe never fails");
@@ -166,7 +170,8 @@ async fn resumes_from() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let resume_at = {
// First subscription
@@ -232,7 +237,8 @@ async fn serial_resume() {
// Call the endpoint
- let subscriber = fixtures::login::create(&app).await;
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
let resume_at = {
let initial_messages = [
@@ -335,3 +341,42 @@ async fn serial_resume() {
}
};
}
+
+#[tokio::test]
+async fn terminates_on_token_expiry() {
+ // Set up the environment
+
+ let app = fixtures::scratch_app().await;
+ let channel = fixtures::channel::create(&app, &fixtures::now()).await;
+ let sender = fixtures::login::create(&app).await;
+
+ // Subscribe via the endpoint
+
+ let subscriber_creds = fixtures::login::create_with_password(&app).await;
+ let subscriber =
+ fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await;
+ let routes::Events(events) = routes::events(State(app.clone()), subscriber, None)
+ .await
+ .expect("subscribe never fails");
+
+ // Verify the resulting stream's behaviour
+
+ app.logins()
+ .expire(&fixtures::now())
+ .await
+ .expect("expiring tokens succeeds");
+
+ // These should not be delivered.
+ let messages = [
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+ ];
+
+ assert!(events
+ .filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event)))
+ .next()
+ .immediately()
+ .await
+ .is_none());
+}
diff --git a/src/lib.rs b/src/lib.rs
index 4139d4d..271118b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
#![warn(clippy::pedantic)]
mod app;
+mod broadcast;
mod channel;
pub mod cli;
mod clock;
diff --git a/src/login/app.rs b/src/login/app.rs
index f7fec88..b8916a8 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -1,24 +1,30 @@
use chrono::TimeDelta;
+use futures::{
+ future,
+ stream::{self, StreamExt as _},
+ Stream,
+};
use sqlx::sqlite::SqlitePool;
-use super::{extract::IdentitySecret, repo::auth::Provider as _};
+use super::{broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, types};
use crate::{
clock::DateTime,
password::Password,
repo::{
error::NotFound as _,
login::{Login, Provider as _},
- token::Provider as _,
+ token::{self, Provider as _},
},
};
pub struct Logins<'a> {
db: &'a SqlitePool,
+ logins: &'a Broadcaster,
}
impl<'a> Logins<'a> {
- pub const fn new(db: &'a SqlitePool) -> Self {
- Self { db }
+ pub const fn new(db: &'a SqlitePool, logins: &'a Broadcaster) -> Self {
+ Self { db, logins }
}
pub async fn login(
@@ -63,7 +69,7 @@ impl<'a> Logins<'a> {
&self,
secret: &IdentitySecret,
used_at: &DateTime,
- ) -> Result<Login, ValidateError> {
+ ) -> Result<(token::Id, Login), ValidateError> {
let mut tx = self.db.begin().await?;
let login = tx
.tokens()
@@ -75,26 +81,56 @@ impl<'a> Logins<'a> {
Ok(login)
}
+ pub fn limit_stream<E>(
+ &self,
+ token: token::Id,
+ events: impl Stream<Item = E> + std::fmt::Debug,
+ ) -> impl Stream<Item = E> + std::fmt::Debug
+ where
+ E: std::fmt::Debug,
+ {
+ let token_events = self
+ .logins
+ .subscribe()
+ .filter(move |event| future::ready(event.token == token))
+ .map(|_| GuardedEvent::TokenRevoked);
+
+ let events = events.map(|event| GuardedEvent::Event(event));
+
+ stream::select(token_events, events).scan((), |(), event| {
+ future::ready(match event {
+ GuardedEvent::Event(event) => Some(event),
+ GuardedEvent::TokenRevoked => None,
+ })
+ })
+ }
+
pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
// Somewhat arbitrarily, expire after 7 days.
let expire_at = relative_to.to_owned() - TimeDelta::days(7);
let mut tx = self.db.begin().await?;
- tx.tokens().expire(&expire_at).await?;
+ let tokens = tx.tokens().expire(&expire_at).await?;
tx.commit().await?;
+ for event in tokens.into_iter().map(types::TokenRevoked::from) {
+ self.logins.broadcast(&event);
+ }
+
Ok(())
}
pub async fn logout(&self, secret: &IdentitySecret) -> Result<(), ValidateError> {
let mut tx = self.db.begin().await?;
- tx.tokens()
+ let token = tx
+ .tokens()
.revoke(secret)
.await
.not_found(|| ValidateError::InvalidToken)?;
-
tx.commit().await?;
+ self.logins.broadcast(&types::TokenRevoked::from(token));
+
Ok(())
}
}
@@ -124,3 +160,9 @@ pub enum ValidateError {
#[error(transparent)]
DatabaseError(#[from] sqlx::Error),
}
+
+#[derive(Debug)]
+enum GuardedEvent<E> {
+ TokenRevoked,
+ Event(E),
+}
diff --git a/src/login/broadcaster.rs b/src/login/broadcaster.rs
new file mode 100644
index 0000000..8e1fb3a
--- /dev/null
+++ b/src/login/broadcaster.rs
@@ -0,0 +1,3 @@
+use crate::{broadcast, login::types};
+
+pub type Broadcaster = broadcast::Broadcaster<types::TokenRevoked>;
diff --git a/src/login/extract.rs b/src/login/extract.rs
index 3b31d4c..b585565 100644
--- a/src/login/extract.rs
+++ b/src/login/extract.rs
@@ -1,12 +1,20 @@
use std::fmt;
use axum::{
- extract::FromRequestParts,
- http::request::Parts,
- response::{IntoResponseParts, ResponseParts},
+ extract::{FromRequestParts, State},
+ http::{request::Parts, StatusCode},
+ response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use axum_extra::extract::cookie::{Cookie, CookieJar};
+use crate::{
+ app::App,
+ clock::RequestedAt,
+ error::Internal,
+ login::app::ValidateError,
+ repo::{login::Login, token},
+};
+
// The usage pattern here - receive the extractor as an argument, return it in
// the response - is heavily modelled after CookieJar's own intended usage.
#[derive(Clone)]
@@ -112,3 +120,63 @@ where
Self(value.into())
}
}
+
+#[derive(Clone, Debug)]
+pub struct Identity {
+ pub token: token::Id,
+ pub login: Login,
+}
+
+#[async_trait::async_trait]
+impl FromRequestParts<App> for Identity {
+ type Rejection = LoginError<Internal>;
+
+ async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
+ // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
+ // stable), the following can be replaced:
+ //
+ // ```
+ // let Ok(identity_token) = IdentityToken::from_request_parts(
+ // parts,
+ // state,
+ // ).await;
+ // ```
+ let identity_token = IdentityToken::from_request_parts(parts, state).await?;
+ let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+
+ let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
+
+ let app = State::<App>::from_request_parts(parts, state).await?;
+ match app.logins().validate(&secret, &used_at).await {
+ Ok((token, login)) => Ok(Identity { token, login }),
+ Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
+ Err(other) => Err(other.into()),
+ }
+ }
+}
+
+pub enum LoginError<E> {
+ Failure(E),
+ Unauthorized,
+}
+
+impl<E> IntoResponse for LoginError<E>
+where
+ E: IntoResponse,
+{
+ fn into_response(self) -> Response {
+ match self {
+ Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
+ Self::Failure(e) => e.into_response(),
+ }
+ }
+}
+
+impl<E> From<E> for LoginError<Internal>
+where
+ E: Into<Internal>,
+{
+ fn from(err: E) -> Self {
+ Self::Failure(err.into())
+ }
+}
diff --git a/src/login/mod.rs b/src/login/mod.rs
index 191cce0..6ae82ac 100644
--- a/src/login/mod.rs
+++ b/src/login/mod.rs
@@ -1,6 +1,8 @@
pub use self::routes::router;
pub mod app;
+pub mod broadcaster;
pub mod extract;
mod repo;
mod routes;
+pub mod types;
diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs
index 10c17d6..81653ff 100644
--- a/src/login/routes/test/login.rs
+++ b/src/login/routes/test/login.rs
@@ -36,7 +36,7 @@ async fn new_identity() {
// Verify the semantics
let validated_at = fixtures::now();
- let validated = app
+ let (_, validated) = app
.logins()
.validate(&secret, &validated_at)
.await
@@ -73,7 +73,7 @@ async fn existing_identity() {
// Verify the semantics
let validated_at = fixtures::now();
- let validated_login = app
+ let (_, validated_login) = app
.logins()
.validate(&secret, &validated_at)
.await
diff --git a/src/login/types.rs b/src/login/types.rs
new file mode 100644
index 0000000..7c7cbf9
--- /dev/null
+++ b/src/login/types.rs
@@ -0,0 +1,12 @@
+use crate::repo::token;
+
+#[derive(Clone, Debug)]
+pub struct TokenRevoked {
+ pub token: token::Id,
+}
+
+impl From<token::Id> for TokenRevoked {
+ fn from(token: token::Id) -> Self {
+ Self { token }
+ }
+}
diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs
index c127078..ab61106 100644
--- a/src/repo/login/extract.rs
+++ b/src/repo/login/extract.rs
@@ -1,67 +1,15 @@
-use axum::{
- extract::{FromRequestParts, State},
- http::{request::Parts, StatusCode},
- response::{IntoResponse, Response},
-};
+use axum::{extract::FromRequestParts, http::request::Parts};
use super::Login;
-use crate::{
- app::App,
- clock::RequestedAt,
- error::Internal,
- login::{app::ValidateError, extract::IdentityToken},
-};
+use crate::{app::App, login::extract::Identity};
#[async_trait::async_trait]
impl FromRequestParts<App> for Login {
- type Rejection = LoginError<Internal>;
+ type Rejection = <Identity as FromRequestParts<App>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> {
- // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on
- // stable), the following can be replaced:
- //
- // ```
- // let Ok(identity_token) = IdentityToken::from_request_parts(
- // parts,
- // state,
- // ).await;
- // ```
- let identity_token = IdentityToken::from_request_parts(parts, state).await?;
- let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?;
+ let identity = Identity::from_request_parts(parts, state).await?;
- let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
-
- let app = State::<App>::from_request_parts(parts, state).await?;
- match app.logins().validate(&secret, &used_at).await {
- Ok(login) => Ok(login),
- Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
- Err(other) => Err(other.into()),
- }
- }
-}
-
-pub enum LoginError<E> {
- Failure(E),
- Unauthorized,
-}
-
-impl<E> IntoResponse for LoginError<E>
-where
- E: IntoResponse,
-{
- fn into_response(self) -> Response {
- match self {
- Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(),
- Self::Failure(e) => e.into_response(),
- }
- }
-}
-
-impl<E> From<E> for LoginError<Internal>
-where
- E: Into<Internal>,
-{
- fn from(err: E) -> Self {
- Self::Failure(err.into())
+ Ok(identity.login)
}
}
diff --git a/src/repo/token.rs b/src/repo/token.rs
index 15eef48..5f39e1d 100644
--- a/src/repo/token.rs
+++ b/src/repo/token.rs
@@ -1,8 +1,10 @@
+use std::fmt;
+
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use uuid::Uuid;
use super::login::{self, Login};
-use crate::{clock::DateTime, login::extract::IdentitySecret};
+use crate::{clock::DateTime, id::Id as BaseId, login::extract::IdentitySecret};
pub trait Provider {
fn tokens(&mut self) -> Tokens;
@@ -24,15 +26,17 @@ impl<'c> Tokens<'c> {
login: &Login,
issued_at: &DateTime,
) -> Result<IdentitySecret, sqlx::Error> {
+ let id = Id::generate();
let secret = Uuid::new_v4().to_string();
let secret = sqlx::query_scalar!(
r#"
insert
- into token (secret, login, issued_at, last_used_at)
- values ($1, $2, $3, $3)
+ into token (id, secret, login, issued_at, last_used_at)
+ values ($1, $2, $3, $4, $4)
returning secret as "secret!: IdentitySecret"
"#,
+ id,
secret,
login.id,
issued_at,
@@ -44,37 +48,38 @@ impl<'c> Tokens<'c> {
}
// Revoke a token by its secret.
- pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<(), sqlx::Error> {
- sqlx::query!(
+ pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<Id, sqlx::Error> {
+ let token = sqlx::query_scalar!(
r#"
delete
from token
where secret = $1
- returning 1 as "found: u32"
+ returning id as "id: Id"
"#,
secret,
)
.fetch_one(&mut *self.0)
.await?;
- Ok(())
+ Ok(token)
}
// Expire and delete all tokens that haven't been used more recently than
// `expire_at`.
- pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> {
- sqlx::query!(
+ pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> {
+ let tokens = sqlx::query_scalar!(
r#"
delete
from token
where last_used_at < $1
+ returning id as "id: Id"
"#,
expire_at,
)
- .execute(&mut *self.0)
+ .fetch_all(&mut *self.0)
.await?;
- Ok(())
+ Ok(tokens)
}
// Validate a token by its secret, retrieving the associated Login record.
@@ -84,7 +89,7 @@ impl<'c> Tokens<'c> {
&mut self,
secret: &IdentitySecret,
used_at: &DateTime,
- ) -> Result<Login, sqlx::Error> {
+ ) -> Result<(Id, Login), sqlx::Error> {
// I would use `update … returning` to do this in one query, but
// sqlite3, as of this writing, does not allow an update's `returning`
// clause to reference columns from tables joined into the update. Two
@@ -101,21 +106,54 @@ impl<'c> Tokens<'c> {
.execute(&mut *self.0)
.await?;
- let login = sqlx::query_as!(
- Login,
+ let login = sqlx::query!(
r#"
select
- login.id as "id: login::Id",
- name
+ token.id as "token_id: Id",
+ login.id as "login_id: login::Id",
+ name as "login_name"
from login
join token on login.id = token.login
where token.secret = $1
"#,
secret,
)
+ .map(|row| {
+ (
+ row.token_id,
+ Login {
+ id: row.login_id,
+ name: row.login_name,
+ },
+ )
+ })
.fetch_one(&mut *self.0)
.await?;
Ok(login)
}
}
+
+// Stable identifier for a token. Prefixed with `T`.
+#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)]
+#[sqlx(transparent)]
+#[serde(transparent)]
+pub struct Id(BaseId);
+
+impl From<BaseId> for Id {
+ fn from(id: BaseId) -> Self {
+ Self(id)
+ }
+}
+
+impl Id {
+ pub fn generate() -> Self {
+ BaseId::generate("T")
+ }
+}
+
+impl fmt::Display for Id {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs
index 69b5f4c..bdd7881 100644
--- a/src/test/fixtures/identity.rs
+++ b/src/test/fixtures/identity.rs
@@ -3,7 +3,7 @@ use uuid::Uuid;
use crate::{
app::App,
clock::RequestedAt,
- login::extract::{IdentitySecret, IdentityToken},
+ login::extract::{Identity, IdentitySecret, IdentityToken},
password::Password,
};
@@ -22,6 +22,20 @@ pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt)
IdentityToken::new().set(token)
}
+pub async fn identity(app: &App, login: &(String, Password), issued_at: &RequestedAt) -> Identity {
+ let secret = logged_in(app, login, issued_at)
+ .await
+ .secret()
+ .expect("successful login generates a secret");
+ let (token, login) = app
+ .logins()
+ .validate(&secret, issued_at)
+ .await
+ .expect("always validates newly-issued secret");
+
+ Identity { token, login }
+}
+
pub fn secret(identity: &IdentityToken) -> IdentitySecret {
identity.secret().expect("identity contained a secret")
}