summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/app.rs9
-rw-r--r--src/bin/hi-recanonicalize.rs9
-rw-r--r--src/bin/hi.rs (renamed from src/main.rs)0
-rw-r--r--src/channel/app.rs8
-rw-r--r--src/channel/repo.rs32
-rw-r--r--src/cli/mod.rs (renamed from src/cli.rs)2
-rw-r--r--src/cli/recanonicalize.rs86
-rw-r--r--src/login/app.rs22
-rw-r--r--src/login/mod.rs1
-rw-r--r--src/login/repo.rs33
10 files changed, 197 insertions, 5 deletions
diff --git a/src/app.rs b/src/app.rs
index cb05061..6d71259 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -5,14 +5,12 @@ use crate::{
channel::app::Channels,
event::{self, app::Events},
invite::app::Invites,
+ login::app::Logins,
message::app::Messages,
setup::app::Setup,
token::{self, app::Tokens},
};
-#[cfg(test)]
-use crate::login::app::Logins;
-
#[derive(Clone)]
pub struct App {
db: SqlitePool,
@@ -49,6 +47,11 @@ impl App {
Invites::new(&self.db)
}
+ #[cfg(not(test))]
+ pub const fn logins(&self) -> Logins {
+ Logins::new(&self.db)
+ }
+
#[cfg(test)]
pub const fn logins(&self) -> Logins {
Logins::new(&self.db, &self.events)
diff --git a/src/bin/hi-recanonicalize.rs b/src/bin/hi-recanonicalize.rs
new file mode 100644
index 0000000..4081276
--- /dev/null
+++ b/src/bin/hi-recanonicalize.rs
@@ -0,0 +1,9 @@
+use clap::Parser;
+
+use hi::cli;
+
+#[tokio::main]
+async fn main() -> Result<(), cli::recanonicalize::Error> {
+ let args = cli::recanonicalize::Args::parse();
+ args.run().await
+}
diff --git a/src/main.rs b/src/bin/hi.rs
index d0830ff..d0830ff 100644
--- a/src/main.rs
+++ b/src/bin/hi.rs
diff --git a/src/channel/app.rs b/src/channel/app.rs
index b8ceeb0..7bfa0f7 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -133,6 +133,14 @@ impl<'a> Channels<'a> {
Ok(())
}
+
+ pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> {
+ let mut tx = self.db.begin().await?;
+ tx.channels().recanonicalize().await?;
+ tx.commit().await?;
+
+ Ok(())
+ }
}
#[derive(Debug, thiserror::Error)]
diff --git a/src/channel/repo.rs b/src/channel/repo.rs
index 4baa95b..e26ac2b 100644
--- a/src/channel/repo.rs
+++ b/src/channel/repo.rs
@@ -300,6 +300,38 @@ impl<'c> Channels<'c> {
Ok(channels)
}
+
+ pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> {
+ let channels = sqlx::query!(
+ r#"
+ select
+ id as "id: Id",
+ display_name as "display_name: String"
+ from channel_name
+ "#,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ for channel in channels {
+ let name = Name::from(channel.display_name);
+ let canonical_name = name.canonical();
+
+ sqlx::query!(
+ r#"
+ update channel_name
+ set canonical_name = $1
+ where id = $2
+ "#,
+ canonical_name,
+ channel.id,
+ )
+ .execute(&mut *self.0)
+ .await?;
+ }
+
+ Ok(())
+ }
}
#[derive(Debug, thiserror::Error)]
diff --git a/src/cli.rs b/src/cli/mod.rs
index 0659851..c75ce2b 100644
--- a/src/cli.rs
+++ b/src/cli/mod.rs
@@ -22,6 +22,8 @@ use crate::{
ui,
};
+pub mod recanonicalize;
+
/// Command-line entry point for running the `hi` server.
///
/// This is intended to be used as a Clap [Parser], to capture command-line
diff --git a/src/cli/recanonicalize.rs b/src/cli/recanonicalize.rs
new file mode 100644
index 0000000..5f8a1db
--- /dev/null
+++ b/src/cli/recanonicalize.rs
@@ -0,0 +1,86 @@
+use sqlx::sqlite::SqlitePool;
+
+use crate::{app::App, db};
+
+/// Command-line entry point for repairing canonical names in the `hi` database.
+/// This command may be necessary after an upgrade, if the canonical forms of
+/// names has changed. It will re-calculate the canonical form of each name in
+/// the database, based on its display form, and store the results back to the
+/// database.
+///
+/// This is intended to be used as a Clap [Parser], to capture command-line
+/// arguments for the `hi-recanonicalize` command:
+///
+/// ```no_run
+/// # use hi::recanonicalize::cli::Error;
+/// #
+/// # #[tokio::main]
+/// # async fn main() -> Result<(), Error> {
+/// use clap::Parser;
+/// use hi::cli::recanonicalize::Args;
+///
+/// let args = Args::parse();
+/// args.run().await?;
+/// # Ok(())
+/// # }
+/// ```
+#[derive(clap::Parser)]
+#[command(
+ version,
+ about = "Recanonicalize names in the `hi` database.",
+ long_about = r#"Recanonicalize names in the `hi` database.
+
+The `hi` server must not be running while this command is run.
+
+The database at `--database-url` will also be created, or upgraded, automatically."#
+)]
+pub struct Args {
+ /// Sqlite URL or path for the `hi` database
+ #[arg(short, long, env, default_value = "sqlite://.hi")]
+ database_url: String,
+
+ /// Sqlite URL or path for a backup of the `hi` database during upgrades
+ #[arg(short = 'D', long, env, default_value = "sqlite://.hi.backup")]
+ backup_database_url: String,
+}
+
+impl Args {
+ /// Recanonicalizes the `hi` database, using the parsed configuation in
+ /// `self`.
+ ///
+ /// This will perform the following tasks:
+ ///
+ /// * Migrate the `hi` database (at `--database-url`).
+ /// * Recanonicalize names in the `login` and `channel` tables.
+ ///
+ /// # Errors
+ ///
+ /// Will return `Err` if the canonicalization or database upgrade processes
+ /// fail. The specific [`Error`] variant will expose the cause
+ /// of the failure.
+ pub async fn run(self) -> Result<(), Error> {
+ let pool = self.pool().await?;
+
+ let app = App::from(pool);
+ app.logins().recanonicalize().await?;
+ app.channels().recanonicalize().await?;
+
+ Ok(())
+ }
+
+ async fn pool(&self) -> Result<SqlitePool, db::Error> {
+ db::prepare(&self.database_url, &self.backup_database_url).await
+ }
+}
+
+/// Errors that can be raised by [`Args::run`].
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum Error {
+ // /// Failure due to `io::Error`. See [`io::Error`].
+ // Io(#[from] io::Error),
+ /// Failure due to a database initialization error. See [`db::Error`].
+ Database(#[from] db::Error),
+ /// Failure due to a data manipulation error. See [`sqlx::Error`].
+ Sqlx(#[from] sqlx::Error),
+}
diff --git a/src/login/app.rs b/src/login/app.rs
index 37f1249..2f5896f 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -1,6 +1,10 @@
use sqlx::sqlite::SqlitePool;
-use super::{repo::Provider as _, Login, Password};
+use super::repo::Provider as _;
+
+#[cfg(test)]
+use super::{Login, Password};
+#[cfg(test)]
use crate::{
clock::DateTime,
event::{repo::Provider as _, Broadcaster, Event},
@@ -9,14 +13,22 @@ use crate::{
pub struct Logins<'a> {
db: &'a SqlitePool,
+ #[cfg(test)]
events: &'a Broadcaster,
}
impl<'a> Logins<'a> {
+ #[cfg(not(test))]
+ pub const fn new(db: &'a SqlitePool) -> Self {
+ Self { db }
+ }
+
+ #[cfg(test)]
pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
Self { db, events }
}
+ #[cfg(test)]
pub async fn create(
&self,
name: &Name,
@@ -35,6 +47,14 @@ impl<'a> Logins<'a> {
Ok(login.as_created())
}
+
+ pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> {
+ let mut tx = self.db.begin().await?;
+ tx.logins().recanonicalize().await?;
+ tx.commit().await?;
+
+ Ok(())
+ }
}
#[derive(Debug, thiserror::Error)]
diff --git a/src/login/mod.rs b/src/login/mod.rs
index 98cc3d7..64a3698 100644
--- a/src/login/mod.rs
+++ b/src/login/mod.rs
@@ -1,4 +1,3 @@
-#[cfg(test)]
pub mod app;
pub mod event;
pub mod extract;
diff --git a/src/login/repo.rs b/src/login/repo.rs
index 6021f26..c6bc734 100644
--- a/src/login/repo.rs
+++ b/src/login/repo.rs
@@ -89,6 +89,7 @@ impl<'c> Logins<'c> {
Ok(logins)
}
+
pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> {
let logins = sqlx::query!(
r#"
@@ -119,6 +120,38 @@ impl<'c> Logins<'c> {
Ok(logins)
}
+
+ pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> {
+ let logins = sqlx::query!(
+ r#"
+ select
+ id as "id: Id",
+ display_name as "display_name: String"
+ from login
+ "#,
+ )
+ .fetch_all(&mut *self.0)
+ .await?;
+
+ for login in logins {
+ let name = Name::from(login.display_name);
+ let canonical_name = name.canonical();
+
+ sqlx::query!(
+ r#"
+ update login
+ set canonical_name = $1
+ where id = $2
+ "#,
+ canonical_name,
+ login.id,
+ )
+ .execute(&mut *self.0)
+ .await?;
+ }
+
+ Ok(())
+ }
}
#[derive(Debug, thiserror::Error)]