use std::{ ffi::{c_int, CStr, CString}, ptr::NonNull, str::from_utf8, }; use libsqlite3_sys::{ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step, sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK, }; use sqlx::{ pool::PoolConnection, sqlite::{Sqlite, SqlitePool}, }; pub struct Builder { from: PoolConnection, } impl Builder { pub async fn to(self, to: &SqlitePool) -> sqlx::Result { Ok(Backup { from: self.from, to: to.acquire().await?, }) } } impl Backup { pub async fn from(from: &SqlitePool) -> sqlx::Result { Ok(Builder { from: from.acquire().await?, }) } } pub struct Backup { from: PoolConnection, to: PoolConnection, } impl Backup { pub async fn backup(&mut self) -> Result<(), Error> { let mut to = self.to.lock_handle().await?; let mut from = self.from.lock_handle().await?; let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?; let step_result = Self::step(handle, -1); Self::finish(to.as_raw_handle(), handle)?; step_result } fn start( to: NonNull, from: NonNull, ) -> Result, Error> { let name = CString::new("main").expect("static constant is a valid C string"); // Invariants: // // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx) // * `zDestName` and `zSourceName` must be valid C strings (see above) // // Never evaluates to null (even though `sqlite3_backup_init` can). let handle = unsafe { sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr()) }; if handle.is_null() { Err(Error::Backup { code: Error::code_for(to), message: Error::message_for(to), })?; } // Having proven that `handle` is not null, we could use new_unchecked here. // Choosing not to so that any mistakes are caught, rather than causing // undefined behaviour later on. Ok(NonNull::new(handle).expect("backup handle is non-null")) } fn step(handle: NonNull, pages: c_int) -> Result<(), Error> { let step = unsafe { sqlite3_backup_step(handle.as_ptr(), pages) }; if SQLITE_BUSY == step { Err(Error::Backup { code: step, message: String::from("database busy"), }) } else if SQLITE_LOCKED == step { Err(Error::Backup { code: step, message: String::from("database locked"), }) } else { Ok(()) } } fn finish(to: NonNull, handle: NonNull) -> Result<(), Error> { let finished = unsafe { sqlite3_backup_finish(handle.as_ptr()) }; if finished == SQLITE_OK { Ok(()) } else { Err(Error::Backup { code: finished, message: Error::message_for(to), }) } } } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Sqlx(#[from] sqlx::Error), #[error("backup failed: {message} (code={code})")] Backup { code: c_int, message: String }, } impl Error { fn code_for(handle: NonNull) -> c_int { unsafe { sqlite3_extended_errcode(handle.as_ptr()) } } fn message_for(handle: NonNull) -> String { unsafe { let msg = sqlite3_errmsg(handle.as_ptr()); debug_assert!(!msg.is_null()); from_utf8(CStr::from_ptr(msg).to_bytes()) .expect("error messages from sqlite are always utf-8") .to_owned() } } }