summaryrefslogtreecommitdiff
path: root/src/cli.rs
blob: 729a7918b692b4359720132eff4f1d0830fc7bb0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use std::io;

use axum::{middleware, Router};
use clap::Parser;
use sqlx::sqlite::SqlitePool;
use tokio::net;

use crate::{app::App, channel, clock, events, login, repo::pool};

pub type Result<T> = std::result::Result<T, Error>;

/// Serve the HTTP API backed by a SQLite database.
// Doc comments in this struct are picked up by clap's derive as the
// `about` / per-option `help` text shown by `--help`.
#[derive(Debug, Parser)]
pub struct Args {
    /// Host name or IP address to listen on
    #[arg(short, long, env, default_value = "localhost")]
    address: String,

    /// TCP port to listen on
    #[arg(short, long, env, default_value_t = 64209)]
    port: u16,

    /// SQLite database URL
    #[arg(short, long, env, default_value = "sqlite://.hi")]
    database_url: String,
}

impl Args {
    /// Opens the database, builds the application router, binds the
    /// listening socket, prints the startup banner, and serves requests
    /// until `axum::serve` returns.
    ///
    /// # Errors
    ///
    /// Fails if the database pool cannot be prepared, if [`App`] cannot be
    /// built from it, if the socket cannot be bound or its local address
    /// cannot be read, or if the server itself returns an I/O error.
    pub async fn run(self) -> Result<()> {
        // Storage comes up before the socket is bound.
        let state = App::from(self.pool().await?).await?;
        let router = routers()
            .route_layer(middleware::from_fn(clock::middleware))
            .with_state(state);

        let listener = self.listener().await?;
        // Read the bound address before `listener` is moved into the server.
        let banner = started_msg(&listener)?;

        let server = axum::serve(listener, router);
        println!("{banner}");
        server.await?;
        Ok(())
    }

    /// Binds a TCP listener on the configured address and port.
    async fn listener(&self) -> io::Result<net::TcpListener> {
        net::TcpListener::bind(self.listen_addr()).await
    }

    /// The `(host, port)` pair handed to `bind`. The host is passed as a
    /// string, so names such as `localhost` are resolved at bind time.
    fn listen_addr(&self) -> impl net::ToSocketAddrs + '_ {
        let host: &str = &self.address;
        (host, self.port)
    }

    /// Builds the SQLite connection pool for `database_url` via `pool::prepare`.
    async fn pool(&self) -> sqlx::Result<SqlitePool> {
        let url = &self.database_url;
        pool::prepare(url).await
    }
}

/// Merges the `channel`, `events`, and `login` routers, in that order,
/// into a single router whose state is [`App`].
fn routers() -> Router<App> {
    Router::default()
        .merge(channel::router())
        .merge(events::router())
        .merge(login::router())
}

/// Formats the startup banner from the address `listener` is actually bound
/// to, so an OS-assigned port is reported correctly.
///
/// # Errors
///
/// Propagates the error from `local_addr` if the bound address cannot be read.
fn started_msg(listener: &net::TcpListener) -> io::Result<String> {
    listener
        .local_addr()
        .map(|addr| format!("listening on http://{addr}/"))
}

/// Error type returned by [`Args::run`].
///
/// Every variant is transparent: `Display` and `source` forward straight to
/// the wrapped error, and `#[from]` lets `?` convert the wrapped error into
/// this type.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Wraps an I/O error, e.g. from binding the listener or serving requests.
    #[error(transparent)]
    IoError(#[from] io::Error),
    /// Wraps a `sqlx` database error, e.g. from preparing the connection pool.
    #[error(transparent)]
    DatabaseError(#[from] sqlx::Error),
    /// Wraps a `sqlx` migration error.
    #[error(transparent)]
    MigrateError(#[from] sqlx::migrate::MigrateError),
}