summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml6
-rw-r--r--src/cli.rs6
-rw-r--r--src/db/backup.rs136
-rw-r--r--src/db/mod.rs (renamed from src/db.rs)45
-rw-r--r--src/test/fixtures/mod.rs2
6 files changed, 190 insertions, 6 deletions
diff --git a/Cargo.lock b/Cargo.lock
index b1f0582..9bfb430 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -788,6 +788,7 @@ dependencies = [
"futures",
"headers",
"itertools",
+ "libsqlite3-sys",
"password-hash",
"rand",
"rand_core",
diff --git a/Cargo.toml b/Cargo.toml
index 2b2e774..28f4747 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,11 @@ version = "0.1.0"
edition = "2021"
[dependencies]
+# Pinned to keep sqlx and libsqlite3 in lockstep. See also:
+# <https://docs.rs/sqlx/latest/sqlx/sqlite/index.html>
+sqlx = { version = "=0.8.2", features = ["chrono", "runtime-tokio", "sqlite"] }
+libsqlite3-sys = { version = "=0.30.1", features = ["bundled"] }
+
argon2 = "0.5.3"
async-trait = "0.1.83"
axum = { version = "0.7.6", features = ["form"] }
@@ -18,7 +23,6 @@ rand = "0.8.5"
rand_core = { version = "0.6.4", features = ["getrandom"] }
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
-sqlx = { version = "0.8.2", features = ["chrono", "runtime-tokio", "sqlite"] }
thiserror = "1.0.64"
tokio = { version = "1.40.0", features = ["rt", "macros", "rt-multi-thread"] }
tokio-stream = { version = "0.1.16", features = ["sync"] }
diff --git a/src/cli.rs b/src/cli.rs
index d88916a..31dd4ce 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -49,6 +49,10 @@ pub struct Args {
/// Sqlite URL or path for the `hi` database
#[arg(short, long, env, default_value = "sqlite://.hi")]
database_url: String,
+
+ /// Sqlite URL or path for a backup of the `hi` database during upgrades
+ #[arg(short = 'D', long, env, default_value = "sqlite://.hi.backup")]
+ backup_database_url: String,
}
impl Args {
@@ -100,7 +104,7 @@ impl Args {
}
async fn pool(&self) -> Result<SqlitePool, db::Error> {
- db::prepare(&self.database_url).await
+ db::prepare(&self.database_url, &self.backup_database_url).await
}
}
diff --git a/src/db/backup.rs b/src/db/backup.rs
new file mode 100644
index 0000000..e34df9f
--- /dev/null
+++ b/src/db/backup.rs
@@ -0,0 +1,136 @@
+use std::{
+ ffi::{c_int, CStr, CString},
+ ptr::NonNull,
+ str::from_utf8_unchecked,
+};
+
+use libsqlite3_sys::{
+ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step,
+ sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK,
+};
+use sqlx::{
+ pool::PoolConnection,
+ sqlite::{Sqlite, SqlitePool},
+};
+
+pub struct Builder {
+ from: PoolConnection<Sqlite>,
+}
+
+impl Builder {
+ pub async fn to(self, to: &SqlitePool) -> sqlx::Result<Backup> {
+ Ok(Backup {
+ from: self.from,
+ to: to.acquire().await?,
+ })
+ }
+}
+
+impl Backup {
+ pub async fn from(from: &SqlitePool) -> sqlx::Result<Builder> {
+ Ok(Builder {
+ from: from.acquire().await?,
+ })
+ }
+}
+
+pub struct Backup {
+ from: PoolConnection<Sqlite>,
+ to: PoolConnection<Sqlite>,
+}
+
+impl Backup {
+ pub async fn backup(&mut self) -> Result<(), Error> {
+ let mut to = self.to.lock_handle().await?;
+ let mut from = self.from.lock_handle().await?;
+
+ let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?;
+ let step_result = Self::step(handle, -1);
+ Self::finish(to.as_raw_handle(), handle)?;
+
+ step_result
+ }
+
+ fn start(to: NonNull<sqlite3>, from: NonNull<sqlite3>) -> Result<*mut sqlite3_backup, Error> {
+ let name = CString::new("main").expect("static constant is a valid C string");
+ unsafe {
+ // Invariants:
+ //
+ // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx)
+ // * `zDestName` and `zSourceName` must be valid C strings (see above)
+ //
+ // Never evaluates to null (even though `sqlite3_backup_init` can).
+ let handle =
+ sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr());
+ if handle.is_null() {
+ Err(Error::Backup {
+ code: Error::code_for(to),
+ message: Error::message_for(to),
+ })?;
+ }
+ Ok(handle)
+ }
+ }
+
+ fn step(handle: *mut sqlite3_backup, pages: c_int) -> Result<(), Error> {
+ let step = unsafe {
+ // Invariants:
+ //
+ // * `handle` must be a valid backup handle (see above).
+ sqlite3_backup_step(handle, pages)
+ };
+ if SQLITE_BUSY == step {
+ Err(Error::Backup {
+ code: step,
+ message: String::from("database busy"),
+ })
+ } else if SQLITE_LOCKED == step {
+ Err(Error::Backup {
+ code: step,
+ message: String::from("database locked"),
+ })
+ } else {
+ Ok(())
+ }
+ }
+
+ fn finish(to: NonNull<sqlite3>, handle: *mut sqlite3_backup) -> Result<(), Error> {
+ let finished = unsafe {
+ // Invariants:
+ //
+ // * `handle` must be a valid backup handle (see above).
+ sqlite3_backup_finish(handle)
+ };
+ if finished == SQLITE_OK {
+ Ok(())
+ } else {
+ Err(Error::Backup {
+ code: finished,
+ message: Error::message_for(to),
+ })
+ }
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Sqlx(#[from] sqlx::Error),
+ #[error("backup failed: {message} (code={code})")]
+ Backup { code: c_int, message: String },
+}
+
+impl Error {
+ fn code_for(handle: NonNull<sqlite3>) -> c_int {
+ unsafe { sqlite3_extended_errcode(handle.as_ptr()) }
+ }
+
+ fn message_for(handle: NonNull<sqlite3>) -> String {
+ unsafe {
+ let msg = sqlite3_errmsg(handle.as_ptr());
+ debug_assert!(!msg.is_null());
+
+ from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned()
+ }
+ }
+}
diff --git a/src/db.rs b/src/db/mod.rs
index e09b0ba..61d5c18 100644
--- a/src/db.rs
+++ b/src/db/mod.rs
@@ -1,15 +1,46 @@
+mod backup;
+
use std::str::FromStr;
-use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
+use sqlx::{
+ migrate::MigrateDatabase as _,
+ sqlite::{Sqlite, SqliteConnectOptions, SqlitePool, SqlitePoolOptions},
+};
+
+pub async fn prepare(url: &str, backup_url: &str) -> Result<SqlitePool, Error> {
+ if backup_url != "sqlite::memory:" && Sqlite::database_exists(backup_url).await? {
+ return Err(Error::BackupExists(backup_url.into()));
+ }
-pub async fn prepare(url: &str) -> Result<SqlitePool, Error> {
let pool = create(url).await?;
// First migration of original migration series, from commit
// 9bd6d9862b1c243def02200bca2cfbf578ad2a2f or earlier.
reject_migration(&pool, "20240831024047", "login", "9949D238C4099295EC4BEE734BFDA8D87513B2973DFB895352A11AB01DD46CB95314B7F1B3431B77E3444A165FE3DC28").await?;
- sqlx::migrate!().run(&pool).await?;
+ let backup_pool = create(backup_url).await?;
+ backup::Backup::from(&pool)
+ .await?
+ .to(&backup_pool)
+ .await?
+ .backup()
+ .await?;
+
+ if let Err(migrate_error) = sqlx::migrate!().run(&pool).await {
+ if let Err(restore_error) = backup::Backup::from(&backup_pool)
+ .await?
+ .to(&pool)
+ .await?
+ .backup()
+ .await
+ {
+ Err(Error::Restore(restore_error, migrate_error))?;
+ } else {
+ Err(migrate_error)?;
+ };
+ }
+
+ Sqlite::drop_database(backup_url).await?;
Ok(pool)
}
@@ -70,6 +101,14 @@ pub enum Error {
/// Failure due to a database error. See [`sqlx::Error`].
#[error(transparent)]
Database(#[from] sqlx::Error),
+ /// Failure because an existing database backup already exists.
+ #[error("backup from a previous failed migration already exists: {0}")]
+ BackupExists(String),
+ /// Failure due to a database backup error. See [`backup::Error`].
+ #[error(transparent)]
+ Backup(#[from] backup::Error),
+ #[error("backing out failed migration also failed: {0} ({1})")]
+ Restore(backup::Error, sqlx::migrate::MigrateError),
/// Failure due to a database migration error. See
/// [`sqlx::migrate::MigrateError`].
#[error(transparent)]
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index c5efa9b..41f7e13 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -11,7 +11,7 @@ pub mod login;
pub mod message;
pub async fn scratch_app() -> App {
- let pool = db::prepare("sqlite::memory:")
+ let pool = db::prepare("sqlite::memory:", "sqlite::memory:")
.await
.expect("setting up in-memory sqlite database");
App::from(pool)