diff options
Diffstat (limited to 'src/cli/mod.rs')
| -rw-r--r-- | src/cli/mod.rs | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..c75ce2b --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,172 @@ +//! The `hi` command-line interface. +//! +//! This module supports running `hi` as a freestanding program, via the +//! [`Args`] struct. + +use std::{future, io}; + +use axum::{ + http::header, + middleware, + response::{IntoResponse, Response}, + Router, +}; +use clap::{CommandFactory, Parser}; +use sqlx::sqlite::SqlitePool; +use tokio::net; + +use crate::{ + app::App, + boot, channel, clock, db, event, expire, invite, login, message, + setup::{self, middleware::setup_required}, + ui, +}; + +pub mod recanonicalize; + +/// Command-line entry point for running the `hi` server. +/// +/// This is intended to be used as a Clap [Parser], to capture command-line +/// arguments for the `hi` server: +/// +/// ```no_run +/// # use hi::cli::Error; +/// # +/// # #[tokio::main] +/// # async fn main() -> Result<(), Error> { +/// use clap::Parser; +/// use hi::cli::Args; +/// +/// let args = Args::parse(); +/// args.run().await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Parser)] +#[command( + version, + about = "Run the `hi` server.", + long_about = r#"Run the `hi` server. + +The database at `--database-url` will be created, or upgraded, automatically."# +)] +pub struct Args { + /// The network address `hi` should listen on + #[arg(short, long, env, default_value = "localhost")] + address: String, + + /// The network port `hi` should listen on + #[arg(short, long, env, default_value_t = 64209)] + port: u16, + + /// Sqlite URL or path for the `hi` database + #[arg(short, long, env, default_value = "sqlite://.hi")] + database_url: String, + + /// Sqlite URL or path for a backup of the `hi` database during upgrades + #[arg(short = 'D', long, env, default_value = "sqlite://.hi.backup")] + backup_database_url: String, +} + +impl Args { + /// Runs the `hi` server, using the parsed configuation in `self`. + /// + /// This will perform the following tasks: + /// + /// * Migrate the `hi` database (at `--database-url`). + /// * Start an HTTP server (on the interface and port controlled by + /// `--address` and `--port`). + /// * Print a status message. + /// * Wait for that server to shut down. + /// + /// # Errors + /// + /// Will return `Err` if the server is unable to start, or terminates + /// prematurely. The specific [`Error`] variant will expose the cause + /// of the failure. + pub async fn run(self) -> Result<(), Error> { + let pool = self.pool().await?; + + let app = App::from(pool); + let app = routers(&app) + .route_layer(middleware::from_fn_with_state( + app.clone(), + expire::middleware, + )) + .route_layer(middleware::from_fn(clock::middleware)) + .route_layer(middleware::map_response(Self::server_info())) + .with_state(app); + + let listener = self.listener().await?; + let started_msg = started_msg(&listener)?; + + let serve = axum::serve(listener, app); + println!("{started_msg}"); + serve.await?; + + Ok(()) + } + + async fn listener(&self) -> io::Result<net::TcpListener> { + let listen_addr = self.listen_addr(); + let listener = tokio::net::TcpListener::bind(listen_addr).await?; + Ok(listener) + } + + fn listen_addr(&self) -> impl net::ToSocketAddrs + '_ { + (self.address.as_str(), self.port) + } + + async fn pool(&self) -> Result<SqlitePool, db::Error> { + db::prepare(&self.database_url, &self.backup_database_url).await + } + + fn server_info() -> impl Clone + Fn(Response) -> future::Ready<Response> { + let command = Self::command(); + let name = command.get_name(); + let version = command.get_version().unwrap_or("unknown version"); + let version = format!("{name}/{version}"); + move |resp| { + let response = ([(header::SERVER, &version)], resp).into_response(); + future::ready(response) + } + } +} + +fn routers(app: &App) -> Router<App> { + [ + [ + // API endpoints that require setup to function + boot::router(), + channel::router(), + event::router(), + invite::router(), + login::router(), + message::router(), + ] + .into_iter() + .fold(Router::default(), Router::merge) + .route_layer(middleware::from_fn_with_state(app.clone(), setup_required)), + // API endpoints that handle setup + setup::router(), + // The UI (handles setup state itself) + ui::router(app), + ] + .into_iter() + .fold(Router::default(), Router::merge) +} + +fn started_msg(listener: &net::TcpListener) -> io::Result<String> { + let local_addr = listener.local_addr()?; + Ok(format!("listening on http://{local_addr}/")) +} + +/// Errors that can be raised by [`Args::run`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + /// Failure due to `io::Error`. See [`io::Error`]. + Io(#[from] io::Error), + /// Failure due to a database initialization error. See [`db::Error`]. + Database(#[from] db::Error), +} |
