use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, repo::{ channel::Channel, login::{self, Login}, message, }, }; pub trait Provider { fn broadcast(&mut self) -> Broadcast; } impl<'c> Provider for Transaction<'c, Sqlite> { fn broadcast(&mut self) -> Broadcast { Broadcast(self) } } pub struct Broadcast<'t>(&'t mut SqliteConnection); #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Message { pub id: message::Id, pub sequence: Sequence, pub sender: Login, pub body: String, pub sent_at: DateTime, } impl<'c> Broadcast<'c> { pub async fn create( &mut self, sender: &Login, channel: &Channel, body: &str, sent_at: &DateTime, ) -> Result { let sequence = self.next_sequence_for(channel).await?; let id = message::Id::generate(); let message = sqlx::query!( r#" insert into message (id, channel, sequence, sender, body, sent_at) values ($1, $2, $3, $4, $5, $6) returning id as "id: message::Id", sequence as "sequence: Sequence", sender as "sender: login::Id", body, sent_at as "sent_at: DateTime" "#, id, channel.id, sequence, sender.id, body, sent_at, ) .map(|row| Message { id: row.id, sequence: row.sequence, sender: sender.clone(), body: row.body, sent_at: row.sent_at, }) .fetch_one(&mut *self.0) .await?; Ok(message) } async fn next_sequence_for(&mut self, channel: &Channel) -> Result { let Sequence(current) = sqlx::query_scalar!( r#" -- `max` never returns null, but sqlx can't detect that select max(sequence) as "sequence!: Sequence" from message where channel = $1 "#, channel.id, ) .fetch_one(&mut *self.0) .await?; Ok(Sequence(current + 1)) } pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { sqlx::query!( r#" delete from message where sent_at < $1 "#, expire_at, ) .execute(&mut *self.0) .await?; Ok(()) } pub async fn replay( &mut self, channel: &Channel, resume_at: Option, ) -> Result, sqlx::Error> { let messages = sqlx::query!( r#" select message.id as "id: message::Id", sequence as "sequence: Sequence", login.id as "sender_id: login::Id", login.name as sender_name, message.body, message.sent_at as "sent_at: DateTime" from message join login on message.sender = login.id where channel = $1 and coalesce(sequence > $2, true) order by sequence asc "#, channel.id, resume_at, ) .map(|row| Message { id: row.id, sequence: row.sequence, sender: Login { id: row.sender_id, name: row.sender_name, }, body: row.body, sent_at: row.sent_at, }) .fetch_all(&mut *self.0) .await?; Ok(messages) } } #[derive( Debug, Eq, Ord, PartialEq, PartialOrd, Clone, Copy, serde::Serialize, serde::Deserialize, sqlx::Type, )] #[serde(transparent)] #[sqlx(transparent)] pub struct Sequence(i64);