diff options
Diffstat (limited to 'src/repo/login')
| -rw-r--r-- | src/repo/login/extract.rs | 55 | ||||
| -rw-r--r-- | src/repo/login/mod.rs | 4 | ||||
| -rw-r--r-- | src/repo/login/store.rs | 104 |
3 files changed, 163 insertions, 0 deletions
diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs new file mode 100644 index 0000000..a068bc0 --- /dev/null +++ b/src/repo/login/extract.rs @@ -0,0 +1,55 @@ +use axum::{ + extract::{FromRequestParts, State}, + http::{request::Parts, StatusCode}, + response::{IntoResponse, Response}, +}; + +use super::Login; +use crate::{app::App, clock::RequestedAt, error::InternalError, login::extract::IdentityToken}; + +#[async_trait::async_trait] +impl FromRequestParts<App> for Login { + type Rejection = LoginError<InternalError>; + + async fn from_request_parts(parts: &mut Parts, state: &App) -> Result<Self, Self::Rejection> { + // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on + // stable), the following can be replaced: + // + // let Ok(identity_token) = IdentityToken::from_request_parts(parts, state).await; + let identity_token = IdentityToken::from_request_parts(parts, state).await?; + let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; + + let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; + + let app = State::<App>::from_request_parts(parts, state).await?; + let login = app.logins().validate(secret, used_at).await?; + + login.ok_or(LoginError::Unauthorized) + } +} + +pub enum LoginError<E> { + Failure(E), + Unauthorized, +} + +impl<E> IntoResponse for LoginError<E> +where + E: IntoResponse, +{ + fn into_response(self) -> Response { + match self { + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::Failure(e) => e.into_response(), + } + } +} + +impl<E> From<E> for LoginError<InternalError> +where + E: Into<InternalError>, +{ + fn from(err: E) -> Self { + Self::Failure(err.into()) + } +} diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs new file mode 100644 index 0000000..e23a7b7 --- /dev/null +++ b/src/repo/login/mod.rs @@ -0,0 +1,4 @@ +mod extract; +mod store; + +pub use self::store::{Id, Login, Logins, Provider}; diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs new file mode 100644 index 0000000..24dd744 --- /dev/null +++ b/src/repo/login/store.rs @@ -0,0 +1,104 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{id::Id as BaseId, password::StoredHash}; + +pub trait Provider { + fn logins(&mut self) -> Logins; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn logins(&mut self) -> Logins { + Logins(self) + } +} + +pub struct Logins<'t>(&'t mut SqliteConnection); + +// This also implements FromRequestParts (see `./extract.rs`). As a result, it +// can be used as an extractor for endpoints that want to require login, or for +// endpoints that need to behave differently depending on whether the client is +// or is not logged in. +#[derive(Clone, Debug, serde::Serialize)] +pub struct Login { + pub id: Id, + pub name: String, + // The omission of the hashed password is deliberate, to minimize the + // chance that it ends up tangled up in debug output or in some other chunk + // of logic elsewhere. +} + +impl<'c> Logins<'c> { + pub async fn create( + &mut self, + name: &str, + password_hash: &StoredHash, + ) -> Result<Login, sqlx::Error> { + let id = Id::generate(); + + let login = sqlx::query_as!( + Login, + r#" + insert or fail + into login (id, name, password_hash) + values ($1, $2, $3) + returning + id as "id: Id", + name + "#, + id, + name, + password_hash, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } + + pub async fn by_id(&mut self, id: &Id) -> Result<Login, sqlx::Error> { + let login = sqlx::query_as!( + Login, + r#" + select + id as "id: Id", + name + from login + where id = $1 + "#, + id, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } +} + +impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { + fn from(tx: &'t mut SqliteConnection) -> Self { + Self(tx) + } +} + +/// Stable identifier for a [Login]. Prefixed with `L`. +#[derive(Clone, Debug, Eq, PartialEq, sqlx::Type, serde::Serialize)] +#[sqlx(transparent)] +pub struct Id(BaseId); + +impl From<BaseId> for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("L") + } +} + +impl std::fmt::Display for Id { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} |
