diff options
Diffstat (limited to 'src/db/backup.rs')
| -rw-r--r-- | src/db/backup.rs | 57 |
1 files changed, 34 insertions, 23 deletions
diff --git a/src/db/backup.rs b/src/db/backup.rs index a6fe917..212fa4e 100644 --- a/src/db/backup.rs +++ b/src/db/backup.rs @@ -6,7 +6,8 @@ use std::{ use libsqlite3_sys::{ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step, - sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK, + sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, + SQLITE_OK, }; use sqlx::sqlite::SqlitePool; @@ -43,7 +44,7 @@ impl<'p> Backup<'p> { let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?; let step_result = Self::step(handle, -1); - Self::finish(to.as_raw_handle(), handle)?; + Self::finish(handle)?; step_result } @@ -63,10 +64,7 @@ impl<'p> Backup<'p> { sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr()) }; if handle.is_null() { - Err(Error::Backup { - code: Error::code_for(to), - message: Error::message_for(to), - })?; + Err(Error::from_handle(to))?; } // Having proven that `handle` is not null, we could use new_unchecked here. // Choosing not to so that any mistakes are caught, rather than causing @@ -76,30 +74,19 @@ impl<'p> Backup<'p> { fn step(handle: NonNull<sqlite3_backup>, pages: c_int) -> Result<(), Error> { let step = unsafe { sqlite3_backup_step(handle.as_ptr(), pages) }; - if SQLITE_BUSY == step { - Err(Error::Backup { - code: step, - message: String::from("database busy"), - }) - } else if SQLITE_LOCKED == step { - Err(Error::Backup { - code: step, - message: String::from("database locked"), - }) + if [SQLITE_BUSY, SQLITE_LOCKED].contains(&step) { + Err(Error::from_code(step)) } else { Ok(()) } } - fn finish(to: NonNull<sqlite3>, handle: NonNull<sqlite3_backup>) -> Result<(), Error> { + fn finish(handle: NonNull<sqlite3_backup>) -> Result<(), Error> { let finished = unsafe { sqlite3_backup_finish(handle.as_ptr()) }; - if finished == SQLITE_OK { + if SQLITE_OK == finished { Ok(()) } else { - Err(Error::Backup { - code: finished, - message: Error::message_for(to), - }) + Err(Error::from_code(finished)) } } } @@ -113,14 +100,38 @@ pub enum Error { } impl Error { + pub fn from_handle(handle: NonNull<sqlite3>) -> Self { + Self::Backup { + code: Self::code_for(handle), + message: Self::message_for(handle), + } + } + + pub fn from_code(code: c_int) -> Self { + Self::Backup { + code, + message: Self::message_from_code(code), + } + } + fn code_for(handle: NonNull<sqlite3>) -> c_int { unsafe { sqlite3_extended_errcode(handle.as_ptr()) } } fn message_for(handle: NonNull<sqlite3>) -> String { - let msg = unsafe { sqlite3_errmsg(handle.as_ptr()) }; + Self::message_from(|| unsafe { sqlite3_errmsg(handle.as_ptr()) }) + } + + fn message_from_code(code: c_int) -> String { + Self::message_from(|| unsafe { sqlite3_errstr(code) }) + } + + fn message_from(f: impl FnOnce() -> *const i8) -> String { + let msg = f(); debug_assert!(!msg.is_null()); from_utf8(unsafe { CStr::from_ptr(msg) }.to_bytes()) + // This is actually promised in the Sqlite3 docs, but we check anyways to catch + // mistakes. See <https://www.sqlite.org/c3ref/errcode.html>. .expect("error messages from sqlite are always utf-8") .to_owned() } |
