summaryrefslogtreecommitdiff
path: root/src/db/backup.rs
blob: a6fe9174c66931aac69babff442b9ba4bfc9c652 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::{
    ffi::{c_int, CStr, CString},
    ptr::NonNull,
    str::from_utf8,
};

use libsqlite3_sys::{
    sqlite3, sqlite3_backup, sqlite3_backup_finish, sqlite3_backup_init, sqlite3_backup_step,
    sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_BUSY, SQLITE_LOCKED, SQLITE_OK,
};
use sqlx::sqlite::SqlitePool;

/// First stage of a [`Backup`]: holds the source pool until a destination
/// pool is supplied with [`Builder::to`].
///
/// Created by [`Backup::from`].
#[must_use = "a backup builder does nothing until `to` is called and the backup is run"]
pub struct Builder<'p> {
    /// Pool whose database will be copied.
    from: &'p SqlitePool,
}

impl<'pool> Builder<'pool> {
    /// Supplies the destination pool, turning this builder into a runnable
    /// [`Backup`] from the previously given source pool to `to`.
    pub fn to(self, to: &'pool SqlitePool) -> Backup<'pool> {
        let Self { from } = self;
        Backup { from, to }
    }
}

/// A pending copy of one SQLite pool's `main` database into another pool's
/// `main` database.
///
/// Build one with `Backup::from(source).to(destination)` and run it with
/// [`Backup::backup`].
#[must_use = "a backup does nothing until `backup` is called and awaited"]
pub struct Backup<'p> {
    /// Pool whose database is read.
    from: &'p SqlitePool,
    /// Pool whose database is written.
    to: &'p SqlitePool,
}

impl<'pool> Backup<'pool> {
    /// Starts building a backup whose source is the pool `from`.
    ///
    /// Complete the builder with [`Builder::to`] to name the destination pool.
    pub fn from(from: &'pool SqlitePool) -> Builder<'pool> {
        Builder { from }
    }
}

impl<'p> Backup<'p> {
    pub async fn backup(&mut self) -> Result<(), Error> {
        let mut to = self.to.acquire().await?;
        let mut to = to.lock_handle().await?;
        let mut from = self.from.acquire().await?;
        let mut from = from.lock_handle().await?;

        let handle = Self::start(to.as_raw_handle(), from.as_raw_handle())?;
        let step_result = Self::step(handle, -1);
        Self::finish(to.as_raw_handle(), handle)?;

        step_result
    }

    fn start(
        to: NonNull<sqlite3>,
        from: NonNull<sqlite3>,
    ) -> Result<NonNull<sqlite3_backup>, Error> {
        let name = CString::new("main").expect("static constant is a valid C string");
        // Invariants:
        //
        // * `to` and `from` must be valid `sqlite3` pointers (guaranteed by sqlx)
        // * `zDestName` and `zSourceName` must be valid C strings (see above)
        //
        // Never evaluates to null (even though `sqlite3_backup_init` can).
        let handle = unsafe {
            sqlite3_backup_init(to.as_ptr(), name.as_ptr(), from.as_ptr(), name.as_ptr())
        };
        if handle.is_null() {
            Err(Error::Backup {
                code: Error::code_for(to),
                message: Error::message_for(to),
            })?;
        }
        // Having proven that `handle` is not null, we could use new_unchecked here.
        // Choosing not to so that any mistakes are caught, rather than causing
        // undefined behaviour later on.
        Ok(NonNull::new(handle).expect("backup handle is non-null"))
    }

    fn step(handle: NonNull<sqlite3_backup>, pages: c_int) -> Result<(), Error> {
        let step = unsafe { sqlite3_backup_step(handle.as_ptr(), pages) };
        if SQLITE_BUSY == step {
            Err(Error::Backup {
                code: step,
                message: String::from("database busy"),
            })
        } else if SQLITE_LOCKED == step {
            Err(Error::Backup {
                code: step,
                message: String::from("database locked"),
            })
        } else {
            Ok(())
        }
    }

    fn finish(to: NonNull<sqlite3>, handle: NonNull<sqlite3_backup>) -> Result<(), Error> {
        let finished = unsafe { sqlite3_backup_finish(handle.as_ptr()) };
        if finished == SQLITE_OK {
            Ok(())
        } else {
            Err(Error::Backup {
                code: finished,
                message: Error::message_for(to),
            })
        }
    }
}

/// Errors produced while backing up one SQLite pool into another.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error from sqlx, e.g. when acquiring a pooled connection or locking
    /// its raw handle.
    #[error(transparent)]
    Sqlx(#[from] sqlx::Error),

    /// SQLite's online backup API reported a failure.
    #[error("backup failed: {message} (code={code})")]
    Backup {
        /// SQLite result code describing the failure.
        code: c_int,
        /// Human-readable description of the failure.
        message: String,
    },
}

impl Error {
    /// Returns the extended error code most recently recorded on the
    /// connection `handle`.
    fn code_for(handle: NonNull<sqlite3>) -> c_int {
        // SAFETY: callers pass a live `sqlite3` connection handle.
        unsafe { sqlite3_extended_errcode(handle.as_ptr()) }
    }

    /// Returns, as an owned string, the error message most recently recorded
    /// on the connection `handle`.
    ///
    /// SQLite does not validate that text embedded in its messages (identifiers,
    /// file names, ...) is UTF-8. Invalid sequences are therefore replaced with
    /// U+FFFD instead of panicking while an error is being reported. A null
    /// message pointer is not expected from SQLite. It trips the debug assertion
    /// and, in release builds, yields a placeholder instead of undefined
    /// behaviour.
    fn message_for(handle: NonNull<sqlite3>) -> String {
        // SAFETY: callers pass a live `sqlite3` connection handle.
        let msg = unsafe { sqlite3_errmsg(handle.as_ptr()) };
        debug_assert!(!msg.is_null());
        if msg.is_null() {
            return String::from("unknown error (sqlite3_errmsg returned null)");
        }
        // SAFETY: `msg` is non-null and points at a NUL-terminated string owned by
        // SQLite. It is copied into an owned `String` before any further call on
        // `handle` could invalidate it.
        let bytes = unsafe { CStr::from_ptr(msg) }.to_bytes();
        match from_utf8(bytes) {
            Ok(text) => text.to_owned(),
            Err(_) => String::from_utf8_lossy(bytes).into_owned(),
        }
    }
}