summaryrefslogtreecommitdiff
path: root/src/db/backup.rs
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-10-05 00:15:45 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-10-05 00:27:29 -0400
commite1551113323d5a496b826d7b0265b1be6235f45c (patch)
tree08f09cac579c954c782e39d5cd02c7ae72f86374 /src/db/backup.rs
parentb422be184e01b4cc35b9c9a6921379080c24edb3 (diff)
Make a backup of the `.hi` database before applying migrations.
This was motivated by Kit and I both independently discovering that sqlite3 will happily partially apply migrations, leaving the DB in a broken state.
Diffstat (limited to 'src/db/backup.rs')
-rw-r--r--src/db/backup.rs136
1 files changed, 136 insertions, 0 deletions
diff --git a/src/db/backup.rs b/src/db/backup.rs
new file mode 100644
index 0000000..e34df9f
--- /dev/null
+++ b/src/db/backup.rs
@@ -0,0 +1,136 @@
+use std::{
+ ffi::{c_int, CStr, CString},
+ ptr::NonNull,
+ str::from_utf8_unchecked,
+};
+
+use libsqlite3_sys::{
+ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step,
+ sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK,
+};
+use sqlx::{
+ pool::PoolConnection,
+ sqlite::{Sqlite, SqlitePool},
+};
+
+pub struct Builder {
+ from: PoolConnection<Sqlite>,
+}
+
+impl Builder {
+ pub async fn to(self, to: &SqlitePool) -> sqlx::Result<Backup> {
+ Ok(Backup {
+ from: self.from,
+ to: to.acquire().await?,
+ })
+ }
+}
+
+impl Backup {
+ pub async fn from(from: &SqlitePool) -> sqlx::Result<Builder> {
+ Ok(Builder {
+ from: from.acquire().await?,
+ })
+ }
+}
+
+pub struct Backup {
+ from: PoolConnection<Sqlite>,
+ to: PoolConnection<Sqlite>,
+}
+
+impl Backup {
+ pub async fn backup(&mut self) -> Result<(), Error> {
+ let mut to = self.to.lock_handle().await?;
+ let mut from = self.from.lock_handle().await?;
+
+ let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?;
+ let step_result = Self::step(handle, -1);
+ Self::finish(to.as_raw_handle(), handle)?;
+
+ step_result
+ }
+
+ fn start(to: NonNull<sqlite3>, from: NonNull<sqlite3>) -> Result<*mut sqlite3_backup, Error> {
+ let name = CString::new("main").expect("static constant is a valid C string");
+ unsafe {
+ // Invariants:
+ //
+ // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx)
+ // * `zDestName` and `zSourceName` must be valid C strings (see above)
+ //
+ // Never evaluates to null (even though `sqlite3_backup_init` can).
+ let handle =
+ sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr());
+ if handle.is_null() {
+ Err(Error::Backup {
+ code: Error::code_for(to),
+ message: Error::message_for(to),
+ })?;
+ }
+ Ok(handle)
+ }
+ }
+
+ fn step(handle: *mut sqlite3_backup, pages: c_int) -> Result<(), Error> {
+ let step = unsafe {
+ // Invariants:
+ //
+ // * `handle` must be a valid backup handle (see above).
+ sqlite3_backup_step(handle, pages)
+ };
+ if SQLITE_BUSY == step {
+ Err(Error::Backup {
+ code: step,
+ message: String::from("database busy"),
+ })
+ } else if SQLITE_LOCKED == step {
+ Err(Error::Backup {
+ code: step,
+ message: String::from("database locked"),
+ })
+ } else {
+ Ok(())
+ }
+ }
+
+ fn finish(to: NonNull<sqlite3>, handle: *mut sqlite3_backup) -> Result<(), Error> {
+ let finished = unsafe {
+ // Invariants:
+ //
+ // * `handle` must be a valid backup handle (see above).
+ sqlite3_backup_finish(handle)
+ };
+ if finished == SQLITE_OK {
+ Ok(())
+ } else {
+ Err(Error::Backup {
+ code: finished,
+ message: Error::message_for(to),
+ })
+ }
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Sqlx(#[from] sqlx::Error),
+ #[error("backup failed: {message} (code={code})")]
+ Backup { code: c_int, message: String },
+}
+
+impl Error {
+ fn code_for(handle: NonNull<sqlite3>) -> c_int {
+ unsafe { sqlite3_extended_errcode(handle.as_ptr()) }
+ }
+
+ fn message_for(handle: NonNull<sqlite3>) -> String {
+ unsafe {
+ let msg = sqlite3_errmsg(handle.as_ptr());
+ debug_assert!(!msg.is_null());
+
+ from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned()
+ }
+ }
+}