mod backup; use std::str::FromStr; use hex_literal::hex; use sqlx::{ error::{DatabaseError, ErrorKind}, migrate::{Migrate as _, MigrateDatabase as _}, sqlite::{Sqlite, SqliteConnectOptions, SqlitePool, SqlitePoolOptions}, }; pub async fn prepare(url: &str, backup_url: &str) -> Result { if backup_url != "sqlite::memory:" && Sqlite::database_exists(backup_url).await? { return Err(Error::BackupExists(backup_url.into())); } let pool = create(url).await?; // First migration of original migration series, from commit // 9bd6d9862b1c243def02200bca2cfbf578ad2a2f or earlier. reject_migration(&pool, "20240831024047", "login", &hex!("9949D238C4099295EC4BEE734BFDA8D87513B2973DFB895352A11AB01DD46CB95314B7F1B3431B77E3444A165FE3DC28")).await?; // Original version of this migration was buggy, but didn't require a // database reset to fix. migration_replaced( &pool, "20241009031441", &hex!("4B5873397C8BA9CFAF49172EE6DE455CD643A27BD71032ECD8EFA7684362FE620A8F6B27D493AF8D9A570C38CC1A6416"), &hex!("E5CDEDA38F2BCE4C24A45E58D3BDE3FF2C30B1431C3B01870BB9DEB142E5A200B9C850C3C72A45D352C15D8DB51B8467"), ).await?; let backup_pool = create(backup_url).await?; backup::Backup::from(&pool) .to(&backup_pool) .backup() .await?; if let Err(migrate_error) = sqlx::migrate!().run(&pool).await { if let Err(restore_error) = backup::Backup::from(&backup_pool).to(&pool).backup().await { Err(Error::Restore(restore_error, migrate_error))?; } else if let Err(drop_error) = Sqlite::drop_database(backup_url).await { Err(Error::Drop(drop_error, migrate_error))?; } else { Err(migrate_error)?; }; } Sqlite::drop_database(backup_url).await?; Ok(pool) } async fn create(database_url: &str) -> sqlx::Result { let options = SqliteConnectOptions::from_str(database_url)? .create_if_missing(true) .optimize_on_close(true, /* analysis_limit */ None); let pool = SqlitePoolOptions::new().connect_with(options).await?; Ok(pool) } async fn reject_migration( pool: &SqlitePool, version: &str, description: &str, checksum: &[u8], ) -> Result<(), Error> { let mut conn = pool.acquire().await?; conn.ensure_migrations_table().await?; let applied = conn.list_applied_migrations().await?; for migration in applied { if migration.checksum == checksum { return Err(Error::Rejected(version.into(), description.into())); } } Ok(()) } async fn migration_replaced( pool: &SqlitePool, version: &str, original: &[u8], replacement: &[u8], ) -> Result<(), sqlx::Error> { let mut conn = pool.acquire().await?; conn.ensure_migrations_table().await?; sqlx::query!( r#" update _sqlx_migrations set checksum = $1 where version = $2 and checksum = $3 "#, replacement, version, original, ) .execute(&mut *conn) .await?; Ok(()) } /// Errors occurring during database setup. #[derive(Debug, thiserror::Error)] pub enum Error { /// Failure due to a database error. See [`sqlx::Error`]. #[error(transparent)] Database(#[from] sqlx::Error), /// Failure because an existing database backup already exists. #[error("backup from a previous failed migration already exists: {0}")] BackupExists(String), /// Failure due to a database backup error. See [`backup::Error`]. #[error(transparent)] Backup(#[from] backup::Error), #[error("migration failed: {1}\nrestoring backup failed: {0}")] Restore(backup::Error, sqlx::migrate::MigrateError), #[error( "migration failed: {1}\nrestoring from backup succeeded, but deleting backup failed: {0}" )] Drop(sqlx::Error, sqlx::migrate::MigrateError), /// Failure due to a database migration error. See /// [`sqlx::migrate::MigrateError`]. #[error(transparent)] Migration(#[from] sqlx::migrate::MigrateError), /// Failure because the database contains a migration from an unsupported /// schema version. #[error("database contains rejected migration {0}:{1}, move it aside")] Rejected(String, String), } pub trait NotFound: Sized { type Ok; type Error; fn not_found(self, map: F) -> Result where E: From, F: FnOnce() -> E, { self.optional()?.ok_or_else(map) } fn optional(self) -> Result, Self::Error>; } impl NotFound for Result { type Ok = T; type Error = sqlx::Error; fn optional(self) -> Result, sqlx::Error> { match self { Ok(value) => Ok(Some(value)), Err(sqlx::Error::RowNotFound) => Ok(None), Err(other) => Err(other), } } } pub trait Duplicate { type Ok; type Error; fn duplicate(self, map: F) -> Result where E: From, F: FnOnce() -> E; } impl Duplicate for Result { type Ok = T; type Error = sqlx::Error; fn duplicate(self, map: F) -> Result where E: From, F: FnOnce() -> E, { match self { Ok(value) => Ok(value), Err(error) => match error.as_database_error().map(DatabaseError::kind) { Some(ErrorKind::UniqueViolation) => Err(map()), _ => Err(error.into()), }, } } }