mod backup; use std::str::FromStr; use sqlx::{ error::{DatabaseError, ErrorKind}, migrate::MigrateDatabase as _, sqlite::{Sqlite, SqliteConnectOptions, SqlitePool, SqlitePoolOptions}, }; pub async fn prepare(url: &str, backup_url: &str) -> Result { if backup_url != "sqlite::memory:" && Sqlite::database_exists(backup_url).await? { return Err(Error::BackupExists(backup_url.into())); } let pool = create(url).await?; let backup_pool = create(backup_url).await?; backup::Backup::from(&pool) .to(&backup_pool) .backup() .await?; if let Err(migrate_error) = sqlx::migrate!().run(&pool).await { if let Err(restore_error) = backup::Backup::from(&backup_pool).to(&pool).backup().await { Err(Error::Restore(restore_error, migrate_error))?; } else if let Err(drop_error) = Sqlite::drop_database(backup_url).await { Err(Error::Drop(drop_error, migrate_error))?; } else { Err(migrate_error)?; } } Sqlite::drop_database(backup_url).await?; Ok(pool) } async fn create(database_url: &str) -> sqlx::Result { let options = SqliteConnectOptions::from_str(database_url)? .create_if_missing(true) .optimize_on_close(true, /* analysis_limit */ None); let pool = SqlitePoolOptions::new().connect_with(options).await?; Ok(pool) } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Database(#[from] sqlx::Error), #[error("backup from a previous failed migration already exists: {0}")] BackupExists(String), #[error(transparent)] Backup(#[from] backup::Error), #[error("migration failed: {1}\nrestoring backup failed: {0}")] Restore(backup::Error, sqlx::migrate::MigrateError), #[error( "migration failed: {1}\nrestoring from backup succeeded, but deleting backup failed: {0}" )] Drop(sqlx::Error, sqlx::migrate::MigrateError), #[error(transparent)] Migration(#[from] sqlx::migrate::MigrateError), } pub trait NotFound: Sized { type Ok; type Error; fn not_found(self, map: F) -> Result where E: From, F: FnOnce() -> E, { self.optional()?.ok_or_else(map) } fn optional(self) -> Result, Self::Error>; } impl NotFound for Result { type Ok = T; type Error = sqlx::Error; fn optional(self) -> Result, sqlx::Error> { match self { Ok(value) => Ok(Some(value)), Err(sqlx::Error::RowNotFound) => Ok(None), Err(other) => Err(other), } } } pub trait Duplicate { type Ok; type Error; fn duplicate(self, map: F) -> Result where E: From, F: FnOnce() -> E; } impl Duplicate for Result { type Ok = T; type Error = sqlx::Error; fn duplicate(self, map: F) -> Result where E: From, F: FnOnce() -> E, { match self { Ok(value) => Ok(value), Err(error) => match error.as_database_error().map(DatabaseError::kind) { Some(ErrorKind::UniqueViolation) => Err(map()), _ => Err(error.into()), }, } } }