use std::{ ffi::{c_int, CStr, CString}, ptr::NonNull, str::from_utf8, }; use libsqlite3_sys::{ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step, sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_DONE, SQLITE_OK, }; use sqlx::sqlite::SqlitePool; pub struct Builder<'p> { from: &'p SqlitePool, } impl<'p> Builder<'p> { pub fn to(self, to: &'p SqlitePool) -> Backup<'p> { Backup { from: self.from, to, } } } pub struct Backup<'p> { from: &'p SqlitePool, to: &'p SqlitePool, } impl<'p> Backup<'p> { pub fn from(from: &'p SqlitePool) -> Builder<'p> { Builder { from } } } impl<'p> Backup<'p> { pub async fn backup(&mut self) -> Result<(), Error> { let mut to = self.to.acquire().await?; let mut to = to.lock_handle().await?; let mut from = self.from.acquire().await?; let mut from = from.lock_handle().await?; let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?; let step_result = loop { match Self::step(handle, -1) { Err(error) => break Err(error), Ok(SQLITE_DONE) => break Ok(()), Ok(SQLITE_OK) => (), // keep pumping the backup step function Ok(other) => panic!("unexpected step result: {other}"), } }; Self::finish(handle)?; step_result } fn start( to: NonNull, from: NonNull, ) -> Result, Error> { let name = CString::new("main").expect("static constant is a valid C string"); // Invariants: // // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx) // * `zDestName` and `zSourceName` must be valid C strings (see above) // // Never evaluates to null (even though `sqlite3_backup_init` can). let handle = unsafe { sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr()) }; if handle.is_null() { Err(Error::from_handle(to))?; } // Having proven that `handle` is not null, we could use new_unchecked here. // Choosing not to so that any mistakes are caught, rather than causing // undefined behaviour later on. Ok(NonNull::new(handle).expect("backup handle is non-null")) } fn step(handle: NonNull, pages: c_int) -> Result { let step = unsafe { sqlite3_backup_step(handle.as_ptr(), pages) }; if [SQLITE_DONE, SQLITE_OK].contains(&step) { Ok(step) } else { Err(Error::from_code(step)) } } fn finish(handle: NonNull) -> Result<(), Error> { let finished = unsafe { sqlite3_backup_finish(handle.as_ptr()) }; if SQLITE_OK == finished { Ok(()) } else { Err(Error::from_code(finished)) } } } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Sqlx(#[from] sqlx::Error), #[error("backup failed: {message} (code={code})")] Backup { code: c_int, message: String }, } impl Error { pub fn from_handle(handle: NonNull) -> Self { Self::Backup { code: Self::code_for(handle), message: Self::message_for(handle), } } pub fn from_code(code: c_int) -> Self { Self::Backup { code, message: Self::message_from_code(code), } } fn code_for(handle: NonNull) -> c_int { unsafe { sqlite3_extended_errcode(handle.as_ptr()) } } fn message_for(handle: NonNull) -> String { Self::message_from(|| unsafe { sqlite3_errmsg(handle.as_ptr()) }) } fn message_from_code(code: c_int) -> String { Self::message_from(|| unsafe { sqlite3_errstr(code) }) } fn message_from(f: impl FnOnce() -> *const i8) -> String { let msg = f(); debug_assert!(!msg.is_null()); from_utf8(unsafe { CStr::from_ptr(msg) }.to_bytes()) // This is actually promised in the Sqlite3 docs, but we check anyways to catch // mistakes. See . .expect("error messages from sqlite are always utf-8") .to_owned() } }