use std::{ ffi::{c_int, CStr, CString}, ptr::NonNull, str::from_utf8_unchecked, }; use libsqlite3_sys::{ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step, sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK, }; use sqlx::{ pool::PoolConnection, sqlite::{Sqlite, SqlitePool}, }; pub struct Builder { from: PoolConnection, } impl Builder { pub async fn to(self, to: &SqlitePool) -> sqlx::Result { Ok(Backup { from: self.from, to: to.acquire().await?, }) } } impl Backup { pub async fn from(from: &SqlitePool) -> sqlx::Result { Ok(Builder { from: from.acquire().await?, }) } } pub struct Backup { from: PoolConnection, to: PoolConnection, } impl Backup { pub async fn backup(&mut self) -> Result<(), Error> { let mut to = self.to.lock_handle().await?; let mut from = self.from.lock_handle().await?; let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?; let step_result = Self::step(handle, -1); Self::finish(to.as_raw_handle(), handle)?; step_result } fn start(to: NonNull, from: NonNull) -> Result<*mut sqlite3_backup, Error> { let name = CString::new("main").expect("static constant is a valid C string"); unsafe { // Invariants: // // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx) // * `zDestName` and `zSourceName` must be valid C strings (see above) // // Never evaluates to null (even though `sqlite3_backup_init` can). let handle = sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr()); if handle.is_null() { Err(Error::Backup { code: Error::code_for(to), message: Error::message_for(to), })?; } Ok(handle) } } fn step(handle: *mut sqlite3_backup, pages: c_int) -> Result<(), Error> { let step = unsafe { // Invariants: // // * `handle` must be a valid backup handle (see above). sqlite3_backup_step(handle, pages) }; if SQLITE_BUSY == step { Err(Error::Backup { code: step, message: String::from("database busy"), }) } else if SQLITE_LOCKED == step { Err(Error::Backup { code: step, message: String::from("database locked"), }) } else { Ok(()) } } fn finish(to: NonNull, handle: *mut sqlite3_backup) -> Result<(), Error> { let finished = unsafe { // Invariants: // // * `handle` must be a valid backup handle (see above). sqlite3_backup_finish(handle) }; if finished == SQLITE_OK { Ok(()) } else { Err(Error::Backup { code: finished, message: Error::message_for(to), }) } } } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Sqlx(#[from] sqlx::Error), #[error("backup failed: {message} (code={code})")] Backup { code: c_int, message: String }, } impl Error { fn code_for(handle: NonNull) -> c_int { unsafe { sqlite3_extended_errcode(handle.as_ptr()) } } fn message_for(handle: NonNull) -> String { unsafe { let msg = sqlite3_errmsg(handle.as_ptr()); debug_assert!(!msg.is_null()); from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned() } } }