From 2f621416f69e522f412e34ebbf29e655541414bd Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Sat, 5 Oct 2024 01:15:41 -0400 Subject: Use the right functions for determining error messages. --- src/db/backup.rs | 57 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 23 deletions(-) (limited to 'src/db/backup.rs') diff --git a/src/db/backup.rs b/src/db/backup.rs index a6fe917..212fa4e 100644 --- a/src/db/backup.rs +++ b/src/db/backup.rs @@ -6,7 +6,8 @@ use std::{ use libsqlite3_sys::{ sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step, - sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK, + sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, + SQLITE_OK, }; use sqlx::sqlite::SqlitePool; @@ -43,7 +44,7 @@ impl<'p> Backup<'p> { let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?; let step_result = Self::step(handle, -1); - Self::finish(to.as_raw_handle(), handle)?; + Self::finish(handle)?; step_result } @@ -63,10 +64,7 @@ impl<'p> Backup<'p> { sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr()) }; if handle.is_null() { - Err(Error::Backup { - code: Error::code_for(to), - message: Error::message_for(to), - })?; + Err(Error::from_handle(to))?; } // Having proven that `handle` is not null, we could use new_unchecked here. // Choosing not to so that any mistakes are caught, rather than causing @@ -76,30 +74,19 @@ impl<'p> Backup<'p> { fn step(handle: NonNull, pages: c_int) -> Result<(), Error> { let step = unsafe { sqlite3_backup_step(handle.as_ptr(), pages) }; - if SQLITE_BUSY == step { - Err(Error::Backup { - code: step, - message: String::from("database busy"), - }) - } else if SQLITE_LOCKED == step { - Err(Error::Backup { - code: step, - message: String::from("database locked"), - }) + if [SQLITE_BUSY, SQLITE_LOCKED].contains(&step) { + Err(Error::from_code(step)) } else { Ok(()) } } - fn finish(to: NonNull, handle: NonNull) -> Result<(), Error> { + fn finish(handle: NonNull) -> Result<(), Error> { let finished = unsafe { sqlite3_backup_finish(handle.as_ptr()) }; - if finished == SQLITE_OK { + if SQLITE_OK == finished { Ok(()) } else { - Err(Error::Backup { - code: finished, - message: Error::message_for(to), - }) + Err(Error::from_code(finished)) } } } @@ -113,14 +100,38 @@ pub enum Error { } impl Error { + pub fn from_handle(handle: NonNull) -> Self { + Self::Backup { + code: Self::code_for(handle), + message: Self::message_for(handle), + } + } + + pub fn from_code(code: c_int) -> Self { + Self::Backup { + code, + message: Self::message_from_code(code), + } + } + fn code_for(handle: NonNull) -> c_int { unsafe { sqlite3_extended_errcode(handle.as_ptr()) } } fn message_for(handle: NonNull) -> String { - let msg = unsafe { sqlite3_errmsg(handle.as_ptr()) }; + Self::message_from(|| unsafe { sqlite3_errmsg(handle.as_ptr()) }) + } + + fn message_from_code(code: c_int) -> String { + Self::message_from(|| unsafe { sqlite3_errstr(code) }) + } + + fn message_from(f: impl FnOnce() -> *const i8) -> String { + let msg = f(); debug_assert!(!msg.is_null()); from_utf8(unsafe { CStr::from_ptr(msg) }.to_bytes()) + // This is actually promised in the Sqlite3 docs, but we check anyways to catch + // mistakes. See . .expect("error messages from sqlite are always utf-8") .to_owned() } -- cgit v1.2.3